Skip to content

Commit a90cda2

Browse files
committed
feat: boost ranking for tested functions and enable reference graph
- Add existing_unit_test_count() with parametrized test deduplication
- Stable-sort ranked functions so tested ones come first
- Enable reference graph resolver (was disabled) for non-CI runs
- Add per-function logging with ref count and test count
- Auto-upgrade top N functions to high effort when user hasn't set --effort
- Add CallGraph model with traversal (BFS, topological sort, subgraph)
- Add get_call_graph() to DependencyResolver protocol and ReferenceGraph
- Refactor get_callees() to delegate through get_call_graph()

CF-1660
1 parent 5cef345 commit a90cda2

10 files changed

Lines changed: 1031 additions & 71 deletions

File tree

codeflash/code_utils/config_consts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050

5151
MAX_CONTEXT_LEN_REVIEW = 1000
5252

53+
HIGH_EFFORT_TOP_N = 15
54+
5355

5456
class EffortLevel(str, Enum):
5557
LOW = "low"

codeflash/discovery/discover_unit_tests.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@
3535
from codeflash.verification.verification_utils import TestConfig
3636

3737

38+
def existing_unit_test_count(
    func: FunctionToOptimize, project_root: Path, function_to_tests: dict[str, set[FunctionCalledInTest]]
) -> int:
    """Count the distinct existing unit tests recorded as calling ``func``.

    ``func`` is looked up in ``function_to_tests`` by its module-qualified name
    relative to ``project_root``. Only entries whose test type is
    ``TestType.EXISTING_UNIT_TEST`` are counted.

    Parametrized variants of one test (``test_x[a]``, ``test_x[b]``) count once:
    everything from the first ``[`` in the test function name is dropped, and
    tests are identified by ``(test file, test class, base function name)``.

    Args:
        func: The function whose tests are being counted.
        project_root: Root used to build the lookup key for ``func``.
        function_to_tests: Mapping from qualified function name to the tests
            that call it.

    Returns:
        The number of unique existing unit tests; 0 when none are recorded.
    """
    key = func.qualified_name_with_modules_from_root(project_root)
    # An empty tuple is enough as the "no tests" default; nothing mutates it.
    unique_tests: set[tuple[Path, str | None, str]] = {
        (tif.test_file, tif.test_class, tif.test_function.split("[", 1)[0])
        for tif in (called.tests_in_file for called in function_to_tests.get(key, ()))
        if tif.test_type == TestType.EXISTING_UNIT_TEST
    }
    return len(unique_tests)
51+
52+
3853
@final
3954
class PytestExitCode(enum.IntEnum): # don't need to import entire pytest just for this
4055
#: Tests passed.

codeflash/languages/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pathlib import Path
1919

2020
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
21+
from codeflash.models.call_graph import CallGraph
2122
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode
2223
from codeflash.verification.verification_utils import TestConfig
2324

@@ -250,6 +251,12 @@ def count_callees_per_function(
250251
"""Return the number of callees for each (file_path, qualified_name) pair."""
251252
...
252253

254+
def get_call_graph(
    self, file_path_to_qualified_names: dict[Path, set[str]], *, include_metadata: bool = False
) -> CallGraph:
    """Return a CallGraph with full caller→callee edges for the given functions.

    Args:
        file_path_to_qualified_names: Maps each source file to the qualified
            names of the caller functions of interest in that file.
        include_metadata: Keyword-only flag, False by default.
            NOTE(review): the ReferenceGraph implementation attaches callee
            metadata to each edge when this is true — confirm other
            implementations follow the same contract.

    Returns:
        The call graph for the requested callers.
    """
    ...
259+
253260
def close(self) -> None:
254261
"""Release resources (e.g. database connections)."""
255262
...

codeflash/languages/function_optimizer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ def __init__(
425425
args: Namespace | None = None,
426426
replay_tests_dir: Path | None = None,
427427
call_graph: DependencyResolver | None = None,
428+
effort_override: str | None = None,
428429
) -> None:
429430
self.project_root = test_cfg.project_root_path.resolve()
430431
self.test_cfg = test_cfg
@@ -451,7 +452,8 @@ def __init__(
451452
self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
452453
self.test_files = TestFiles(test_files=[])
453454

454-
self.effort = getattr(args, "effort", EffortLevel.MEDIUM.value) if args else EffortLevel.MEDIUM.value
455+
default_effort = getattr(args, "effort", EffortLevel.MEDIUM.value) if args else EffortLevel.MEDIUM.value
456+
self.effort = effort_override or default_effort
455457

456458
self.args = args # Check defaults for these
457459
self.function_trace_id: str = str(uuid.uuid4())

codeflash/languages/python/reference_graph.py

Lines changed: 91 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,21 @@
33
import hashlib
44
import os
55
import sqlite3
6-
from collections import defaultdict
76
from pathlib import Path
87
from typing import TYPE_CHECKING
98

109
from codeflash.cli_cmds.console import logger
1110
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
1211
from codeflash.languages.base import IndexResult
13-
from codeflash.models.models import FunctionSource
1412

1513
if TYPE_CHECKING:
1614
from collections.abc import Callable, Iterable
1715

1816
from jedi.api.classes import Name
1917

18+
from codeflash.models.call_graph import CallGraph
19+
from codeflash.models.models import FunctionSource
20+
2021

2122
# ---------------------------------------------------------------------------
2223
# Module-level helpers (must be top-level for ProcessPoolExecutor pickling)
@@ -262,49 +263,10 @@ def _init_schema(self) -> None:
262263
def get_callees(
263264
self, file_path_to_qualified_names: dict[Path, set[str]]
264265
) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]:
265-
file_path_to_function_source: dict[Path, set[FunctionSource]] = defaultdict(set)
266-
function_source_list: list[FunctionSource] = []
267-
268-
all_caller_keys: list[tuple[str, str]] = []
269-
for file_path, qualified_names in file_path_to_qualified_names.items():
270-
resolved = str(file_path.resolve())
271-
self.ensure_file_indexed(file_path, resolved)
272-
all_caller_keys.extend((resolved, qn) for qn in qualified_names)
266+
from codeflash.models.call_graph import callees_from_graph
273267

274-
if not all_caller_keys:
275-
return file_path_to_function_source, function_source_list
276-
277-
cur = self.conn.cursor()
278-
cur.execute("CREATE TEMP TABLE IF NOT EXISTS _caller_keys (caller_file TEXT, caller_qualified_name TEXT)")
279-
cur.execute("DELETE FROM _caller_keys")
280-
cur.executemany("INSERT INTO _caller_keys VALUES (?, ?)", all_caller_keys)
281-
282-
rows = cur.execute(
283-
"""
284-
SELECT ce.callee_file, ce.callee_qualified_name, ce.callee_fully_qualified_name,
285-
ce.callee_only_function_name, ce.callee_definition_type, ce.callee_source_line
286-
FROM call_edges ce
287-
INNER JOIN _caller_keys ck
288-
ON ce.caller_file = ck.caller_file AND ce.caller_qualified_name = ck.caller_qualified_name
289-
WHERE ce.project_root = ? AND ce.language = ?
290-
""",
291-
(self.project_root_str, self.language),
292-
).fetchall()
293-
294-
for callee_file, callee_qn, callee_fqn, callee_name, callee_type, callee_src in rows:
295-
callee_path = Path(callee_file)
296-
fs = FunctionSource(
297-
file_path=callee_path,
298-
qualified_name=callee_qn,
299-
fully_qualified_name=callee_fqn,
300-
only_function_name=callee_name,
301-
source_code=callee_src,
302-
definition_type=callee_type,
303-
)
304-
file_path_to_function_source[callee_path].add(fs)
305-
function_source_list.append(fs)
306-
307-
return file_path_to_function_source, function_source_list
268+
graph = self.get_call_graph(file_path_to_qualified_names, include_metadata=True)
269+
return callees_from_graph(graph)
308270

309271
def count_callees_per_function(
310272
self, file_path_to_qualified_names: dict[Path, set[str]]
@@ -540,5 +502,90 @@ def _fallback_sequential_index(
540502
result = self.index_file(file_path, file_hash, resolved)
541503
self._report_progress(on_progress, result)
542504

505+
def get_call_graph(
    self, file_path_to_qualified_names: dict[Path, set[str]], *, include_metadata: bool = False
) -> CallGraph:
    """Build a CallGraph of the outgoing call edges of the requested callers.

    Every file in ``file_path_to_qualified_names`` is first passed to
    ``ensure_file_indexed``. Edges are then read from the ``call_edges`` table,
    restricted to this graph's project root and language, for exactly the
    requested ``(resolved file, qualified name)`` caller pairs.

    Args:
        file_path_to_qualified_names: Maps each source file to the qualified
            names of the caller functions to look up in it.
        include_metadata: When true, each edge also carries a ``CalleeMetadata``
            (fully qualified name, bare function name, definition type and
            source line of the callee). When false those columns are not
            queried at all.

    Returns:
        A ``CallGraph`` holding one ``CallEdge`` per matching row; a graph with
        no edges when no caller was requested.
    """
    from codeflash.models.call_graph import CallEdge, CalleeMetadata, CallGraph, FunctionNode

    caller_keys: list[tuple[str, str]] = []
    for file_path, qualified_names in file_path_to_qualified_names.items():
        resolved = str(file_path.resolve())
        self.ensure_file_indexed(file_path, resolved)
        caller_keys.extend((resolved, qn) for qn in qualified_names)

    if not caller_keys:
        return CallGraph(edges=[])

    # Stage the requested callers in a temp table so a single JOIN selects all of their edges.
    cur = self.conn.cursor()
    cur.execute("CREATE TEMP TABLE IF NOT EXISTS _graph_keys (caller_file TEXT, caller_qualified_name TEXT)")
    # The table is reused across calls (IF NOT EXISTS above), so clear any earlier keys.
    cur.execute("DELETE FROM _graph_keys")
    cur.executemany("INSERT INTO _graph_keys VALUES (?, ?)", caller_keys)

    edges: list[CallEdge]
    if include_metadata:
        rows = cur.execute(
            """
            SELECT ce.caller_file, ce.caller_qualified_name,
                   ce.callee_file, ce.callee_qualified_name,
                   ce.callee_fully_qualified_name, ce.callee_only_function_name,
                   ce.callee_definition_type, ce.callee_source_line
            FROM call_edges ce
            INNER JOIN _graph_keys gk
                ON ce.caller_file = gk.caller_file AND ce.caller_qualified_name = gk.caller_qualified_name
            WHERE ce.project_root = ? AND ce.language = ?
            """,
            (self.project_root_str, self.language),
        ).fetchall()

        edges = [
            CallEdge(
                caller=FunctionNode(file_path=Path(caller_file), qualified_name=caller_qn),
                callee=FunctionNode(file_path=Path(callee_file), qualified_name=callee_qn),
                # Compares the file strings exactly as stored in the index.
                is_cross_file=caller_file != callee_file,
                callee_metadata=CalleeMetadata(
                    fully_qualified_name=callee_fqn,
                    only_function_name=callee_name,
                    definition_type=callee_type,
                    source_line=callee_src,
                ),
            )
            for (
                caller_file,
                caller_qn,
                callee_file,
                callee_qn,
                callee_fqn,
                callee_name,
                callee_type,
                callee_src,
            ) in rows
        ]
    else:
        rows = cur.execute(
            """
            SELECT ce.caller_file, ce.caller_qualified_name,
                   ce.callee_file, ce.callee_qualified_name
            FROM call_edges ce
            INNER JOIN _graph_keys gk
                ON ce.caller_file = gk.caller_file AND ce.caller_qualified_name = gk.caller_qualified_name
            WHERE ce.project_root = ? AND ce.language = ?
            """,
            (self.project_root_str, self.language),
        ).fetchall()

        edges = [
            CallEdge(
                caller=FunctionNode(file_path=Path(caller_file), qualified_name=caller_qn),
                callee=FunctionNode(file_path=Path(callee_file), qualified_name=callee_qn),
                # Compares the file strings exactly as stored in the index.
                is_cross_file=caller_file != callee_file,
            )
            for caller_file, caller_qn, callee_file, callee_qn in rows
        ]

    return CallGraph(edges=edges)
589+
543590
def close(self) -> None:
544591
self.conn.close()

codeflash/languages/python/support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,7 @@ def create_dependency_resolver(self, project_root: Path) -> DependencyResolver |
957957
try:
958958
return ReferenceGraph(project_root, language=self.language.value)
959959
except Exception:
960-
logger.debug("Failed to initialize ReferenceGraph, falling back to per-function Jedi analysis")
960+
logger.info("Failed to initialize ReferenceGraph, falling back to per-function Jedi analysis")
961961
return None
962962

963963
def instrument_existing_test(

0 commit comments

Comments
 (0)