|
3 | 3 | import hashlib |
4 | 4 | import os |
5 | 5 | import sqlite3 |
6 | | -from collections import defaultdict |
7 | 6 | from pathlib import Path |
8 | 7 | from typing import TYPE_CHECKING |
9 | 8 |
|
10 | 9 | from codeflash.cli_cmds.console import logger |
11 | 10 | from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages |
12 | 11 | from codeflash.languages.base import IndexResult |
13 | | -from codeflash.models.models import FunctionSource |
14 | 12 |
|
15 | 13 | if TYPE_CHECKING: |
16 | 14 | from collections.abc import Callable, Iterable |
17 | 15 |
|
18 | 16 | from jedi.api.classes import Name |
19 | 17 |
|
| 18 | + from codeflash.models.call_graph import CallGraph |
| 19 | + from codeflash.models.models import FunctionSource |
| 20 | + |
20 | 21 |
|
21 | 22 | # --------------------------------------------------------------------------- |
22 | 23 | # Module-level helpers (must be top-level for ProcessPoolExecutor pickling) |
@@ -262,49 +263,10 @@ def _init_schema(self) -> None: |
262 | 263 | def get_callees( |
263 | 264 | self, file_path_to_qualified_names: dict[Path, set[str]] |
264 | 265 | ) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]: |
265 | | - file_path_to_function_source: dict[Path, set[FunctionSource]] = defaultdict(set) |
266 | | - function_source_list: list[FunctionSource] = [] |
267 | | - |
268 | | - all_caller_keys: list[tuple[str, str]] = [] |
269 | | - for file_path, qualified_names in file_path_to_qualified_names.items(): |
270 | | - resolved = str(file_path.resolve()) |
271 | | - self.ensure_file_indexed(file_path, resolved) |
272 | | - all_caller_keys.extend((resolved, qn) for qn in qualified_names) |
| 266 | + from codeflash.models.call_graph import callees_from_graph |
273 | 267 |
|
274 | | - if not all_caller_keys: |
275 | | - return file_path_to_function_source, function_source_list |
276 | | - |
277 | | - cur = self.conn.cursor() |
278 | | - cur.execute("CREATE TEMP TABLE IF NOT EXISTS _caller_keys (caller_file TEXT, caller_qualified_name TEXT)") |
279 | | - cur.execute("DELETE FROM _caller_keys") |
280 | | - cur.executemany("INSERT INTO _caller_keys VALUES (?, ?)", all_caller_keys) |
281 | | - |
282 | | - rows = cur.execute( |
283 | | - """ |
284 | | - SELECT ce.callee_file, ce.callee_qualified_name, ce.callee_fully_qualified_name, |
285 | | - ce.callee_only_function_name, ce.callee_definition_type, ce.callee_source_line |
286 | | - FROM call_edges ce |
287 | | - INNER JOIN _caller_keys ck |
288 | | - ON ce.caller_file = ck.caller_file AND ce.caller_qualified_name = ck.caller_qualified_name |
289 | | - WHERE ce.project_root = ? AND ce.language = ? |
290 | | - """, |
291 | | - (self.project_root_str, self.language), |
292 | | - ).fetchall() |
293 | | - |
294 | | - for callee_file, callee_qn, callee_fqn, callee_name, callee_type, callee_src in rows: |
295 | | - callee_path = Path(callee_file) |
296 | | - fs = FunctionSource( |
297 | | - file_path=callee_path, |
298 | | - qualified_name=callee_qn, |
299 | | - fully_qualified_name=callee_fqn, |
300 | | - only_function_name=callee_name, |
301 | | - source_code=callee_src, |
302 | | - definition_type=callee_type, |
303 | | - ) |
304 | | - file_path_to_function_source[callee_path].add(fs) |
305 | | - function_source_list.append(fs) |
306 | | - |
307 | | - return file_path_to_function_source, function_source_list |
| 268 | + graph = self.get_call_graph(file_path_to_qualified_names, include_metadata=True) |
| 269 | + return callees_from_graph(graph) |
308 | 270 |
|
309 | 271 | def count_callees_per_function( |
310 | 272 | self, file_path_to_qualified_names: dict[Path, set[str]] |
@@ -540,5 +502,90 @@ def _fallback_sequential_index( |
540 | 502 | result = self.index_file(file_path, file_hash, resolved) |
541 | 503 | self._report_progress(on_progress, result) |
542 | 504 |
|
| 505 | + def get_call_graph( |
| 506 | + self, file_path_to_qualified_names: dict[Path, set[str]], *, include_metadata: bool = False |
| 507 | + ) -> CallGraph: |
| 508 | + from codeflash.models.call_graph import CallEdge, CalleeMetadata, CallGraph, FunctionNode |
| 509 | + |
| 510 | + all_caller_keys: list[tuple[Path, str, str]] = [] |
| 511 | + for file_path, qualified_names in file_path_to_qualified_names.items(): |
| 512 | + resolved = str(file_path.resolve()) |
| 513 | + self.ensure_file_indexed(file_path, resolved) |
| 514 | + all_caller_keys.extend((file_path, resolved, qn) for qn in qualified_names) |
| 515 | + |
| 516 | + if not all_caller_keys: |
| 517 | + return CallGraph(edges=[]) |
| 518 | + |
| 519 | + cur = self.conn.cursor() |
| 520 | + cur.execute("CREATE TEMP TABLE IF NOT EXISTS _graph_keys (caller_file TEXT, caller_qualified_name TEXT)") |
| 521 | + cur.execute("DELETE FROM _graph_keys") |
| 522 | + cur.executemany( |
| 523 | + "INSERT INTO _graph_keys VALUES (?, ?)", [(resolved, qn) for _, resolved, qn in all_caller_keys] |
| 524 | + ) |
| 525 | + |
| 526 | + if include_metadata: |
| 527 | + rows = cur.execute( |
| 528 | + """ |
| 529 | + SELECT ce.caller_file, ce.caller_qualified_name, |
| 530 | + ce.callee_file, ce.callee_qualified_name, |
| 531 | + ce.callee_fully_qualified_name, ce.callee_only_function_name, |
| 532 | + ce.callee_definition_type, ce.callee_source_line |
| 533 | + FROM call_edges ce |
| 534 | + INNER JOIN _graph_keys gk |
| 535 | + ON ce.caller_file = gk.caller_file AND ce.caller_qualified_name = gk.caller_qualified_name |
| 536 | + WHERE ce.project_root = ? AND ce.language = ? |
| 537 | + """, |
| 538 | + (self.project_root_str, self.language), |
| 539 | + ).fetchall() |
| 540 | + |
| 541 | + edges: list[CallEdge] = [] |
| 542 | + for ( |
| 543 | + caller_file, |
| 544 | + caller_qn, |
| 545 | + callee_file, |
| 546 | + callee_qn, |
| 547 | + callee_fqn, |
| 548 | + callee_name, |
| 549 | + callee_type, |
| 550 | + callee_src, |
| 551 | + ) in rows: |
| 552 | + edges.append( |
| 553 | + CallEdge( |
| 554 | + caller=FunctionNode(file_path=Path(caller_file), qualified_name=caller_qn), |
| 555 | + callee=FunctionNode(file_path=Path(callee_file), qualified_name=callee_qn), |
| 556 | + is_cross_file=caller_file != callee_file, |
| 557 | + callee_metadata=CalleeMetadata( |
| 558 | + fully_qualified_name=callee_fqn, |
| 559 | + only_function_name=callee_name, |
| 560 | + definition_type=callee_type, |
| 561 | + source_line=callee_src, |
| 562 | + ), |
| 563 | + ) |
| 564 | + ) |
| 565 | + else: |
| 566 | + rows = cur.execute( |
| 567 | + """ |
| 568 | + SELECT ce.caller_file, ce.caller_qualified_name, |
| 569 | + ce.callee_file, ce.callee_qualified_name |
| 570 | + FROM call_edges ce |
| 571 | + INNER JOIN _graph_keys gk |
| 572 | + ON ce.caller_file = gk.caller_file AND ce.caller_qualified_name = gk.caller_qualified_name |
| 573 | + WHERE ce.project_root = ? AND ce.language = ? |
| 574 | + """, |
| 575 | + (self.project_root_str, self.language), |
| 576 | + ).fetchall() |
| 577 | + |
| 578 | + edges = [] |
| 579 | + for caller_file, caller_qn, callee_file, callee_qn in rows: |
| 580 | + edges.append( |
| 581 | + CallEdge( |
| 582 | + caller=FunctionNode(file_path=Path(caller_file), qualified_name=caller_qn), |
| 583 | + callee=FunctionNode(file_path=Path(callee_file), qualified_name=callee_qn), |
| 584 | + is_cross_file=caller_file != callee_file, |
| 585 | + ) |
| 586 | + ) |
| 587 | + |
| 588 | + return CallGraph(edges=edges) |
| 589 | + |
543 | 590 | def close(self) -> None: |
544 | 591 | self.conn.close() |
0 commit comments