Skip to content

Commit 6916884

Browse files
committed
feat(pycg): coupling-aware shard planning from the Jedi module graph
Sharding lets PyCG (level 2) scale past its ~500-file ceiling by analysing the project in independent pieces. The existing scheme shards one-per-package with a flat file-count ceiling, which is blind to call coupling: it severs heavily-interacting modules (their cross-shard edges become ghost nodes PyCG never resolves) and drops oversized packages wholesale. Add a coupling-aware planner that partitions the module-dependency graph *derived from the Jedi call graph already computed at level 1*: 1. project Jedi callable->callable edges onto a weighted module DiGraph; 2. condense strongly-connected components (import cycles become atomic and are never split across shards); 3. cluster with Louvain so tightly-coupled modules co-compute; 4. enforce the per-shard file budget (re-partition oversized communities, then merge/first-fit-pack the remainder to recover edges and cut count). The reported cut_ratio (fraction of Jedi edge weight crossing shard boundaries) is an upper bound on PyCG edges lost to sharding; on a synthetic worst case it drops from 0.55 (per-package) to 0.03. Wire it into PyCG behind --pycg-shard-strategy {jedi,package} (default jedi). Because planner shards are arbitrary file sets rather than directories, each runs through a temporary symlink mini-project (_shard_symlink_root) so PyCG's own package-root bound confines analysis to the shard and emits project-relative edge names with no prefix rewrite. Thread the level-1 Jedi edges through core -> _get_pycg_call_graph -> build_call_graph_edges to feed the planner. Ray parallelism falls back to sequential under the jedi strategy for now. Add test/test_shard_planner.py (graph projection, SCC atomicity, budget, single-assignment, cut-ratio vs naive, determinism).
1 parent 0799828 commit 6916884

7 files changed

Lines changed: 774 additions & 14 deletions

File tree

codeanalyzer/__main__.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from codeanalyzer.utils import _set_log_level, logger
88
from codeanalyzer.config import OutputFormat
99
from codeanalyzer.schema import model_dump_json
10-
from codeanalyzer.options import AnalysisOptions, EmitTarget
10+
from codeanalyzer.options import AnalysisOptions, EmitTarget, ShardStrategy
1111

1212

1313
def main(
@@ -186,6 +186,20 @@ def main(
186186
min=0,
187187
),
188188
] = 120,
189+
pycg_shard_strategy: Annotated[
190+
ShardStrategy,
191+
typer.Option(
192+
"--pycg-shard-strategy",
193+
help=(
194+
"How --pycg-shard groups files (level 2 only). 'jedi' (default) "
195+
"partitions the Jedi module-dependency graph (SCC + Louvain) so "
196+
"tightly-coupled modules co-compute and few call edges are "
197+
"severed between shards; import cycles are never split. "
198+
"'package' uses the legacy one-shard-per-package-directory "
199+
"grouping."
200+
),
201+
),
202+
] = ShardStrategy.JEDI,
189203
):
190204
options = AnalysisOptions(
191205
input=input,
@@ -209,6 +223,7 @@ def main(
209223
pycg_shard=pycg_shard,
210224
pycg_shard_ceiling=pycg_shard_ceiling,
211225
pycg_shard_timeout=pycg_shard_timeout,
226+
pycg_shard_strategy=pycg_shard_strategy,
212227
)
213228

214229
_set_log_level(options.verbosity)

codeanalyzer/core.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,9 @@ def analyze(self) -> PyApplication:
433433
logger.info("✅ Jedi: %d edges in %.1fs", len(call_graph), time.perf_counter() - t0_jedi)
434434

435435
if self.analysis_level >= 2:
436-
# Level 2: also add PyCG edges.
437-
pycg_edges = self._get_pycg_call_graph(symbol_table)
436+
# Level 2: also add PyCG edges. The Jedi edges double as the
437+
# coupling graph that drives coupling-aware PyCG sharding.
438+
pycg_edges = self._get_pycg_call_graph(symbol_table, jedi_edges)
438439
call_graph = merge_edges(call_graph, pycg_edges)
439440

440441
call_graph = filter_external_edges(call_graph, symbol_table)
@@ -661,13 +662,18 @@ def _build_symbol_table(self, cached_symbol_table: Optional[Dict[str, PyModule]]
661662
def _get_pycg_call_graph(
662663
self,
663664
symbol_table: Dict[str, PyModule],
665+
jedi_edges: List[PyCallEdge],
664666
) -> List[PyCallEdge]:
665667
"""Build PyCG-resolved call edges.
666668
667669
Runs PyCG's iterative name-pointer analysis over the whole project
668670
and returns edges with ``provenance=["pycg"]``. Falls back to an
669671
empty list and logs a warning on any failure so the caller can
670672
continue with Jedi-only edges.
673+
674+
*jedi_edges* are the level-1 call edges; under the ``jedi`` shard
675+
strategy they drive coupling-aware partitioning (see
676+
:func:`shard_planner.plan_shards`).
671677
"""
672678
try:
673679
pycg = PyCG(
@@ -676,9 +682,10 @@ def _get_pycg_call_graph(
676682
shard=self.options.pycg_shard,
677683
shard_ceiling=self.options.pycg_shard_ceiling,
678684
shard_timeout=self.options.pycg_shard_timeout,
685+
shard_strategy=self.options.pycg_shard_strategy,
679686
using_ray=self.using_ray,
680687
)
681-
return pycg.build_call_graph_edges(symbol_table)
688+
return pycg.build_call_graph_edges(symbol_table, jedi_edges=jedi_edges)
682689
except PyCGExceptions.PyCGImportError as exc:
683690
logger.warning(f"PyCG not installed — level 2 edges will be Jedi-only: {exc}")
684691
return []

codeanalyzer/options/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .options import AnalysisOptions, EmitTarget, OutputFormat
1+
from .options import AnalysisOptions, EmitTarget, OutputFormat, ShardStrategy
22

3-
__all__ = ["AnalysisOptions", "EmitTarget", "OutputFormat"]
3+
__all__ = ["AnalysisOptions", "EmitTarget", "OutputFormat", "ShardStrategy"]

codeanalyzer/options/options.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@ class EmitTarget(str, Enum):
2323
SCHEMA = "schema"
2424

2525

26+
class ShardStrategy(str, Enum):
27+
"""How ``--pycg-shard`` groups files into shards (level 2 only).
28+
29+
- ``jedi`` : partition the Jedi module-dependency graph (strongly-
30+
connected-component condensation + Louvain) so tightly-
31+
coupled modules co-compute and few call edges are severed
32+
between shards. Import cycles are never split.
33+
- ``package`` : legacy one-shard-per-package-directory grouping.
34+
"""
35+
36+
JEDI = "jedi"
37+
PACKAGE = "package"
38+
39+
2640
@dataclass
2741
class AnalysisOptions:
2842
input: Path
@@ -46,3 +60,4 @@ class AnalysisOptions:
4660
pycg_shard: bool = False
4761
pycg_shard_ceiling: int = 100
4862
pycg_shard_timeout: int = 120
63+
pycg_shard_strategy: ShardStrategy = ShardStrategy.JEDI

codeanalyzer/semantic_analysis/pycg/pycg_analysis.py

Lines changed: 181 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,14 @@
4343
import importlib.util # noqa: F401
4444
import contextlib
4545
import json # noqa: F401
46+
import shutil
4647
import signal
48+
import tempfile
4749
import time
4850

4951
from collections import Counter, defaultdict
5052
from pathlib import Path
51-
from typing import Any, Dict, Generator, List, Optional, Set, Union
53+
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
5254

5355

5456
@contextlib.contextmanager
@@ -79,9 +81,64 @@ def _handler(signum: int, frame: object) -> None:
7981
from codeanalyzer.schema.py_schema import PyCallEdge, PyModule
8082
from codeanalyzer.semantic_analysis.call_graph import iter_callables_in_symbol_table
8183
from codeanalyzer.semantic_analysis.pycg.pycg_exceptions import PyCGExceptions
84+
from codeanalyzer.semantic_analysis.pycg.shard_planner import plan_shards
8285
from codeanalyzer.utils import ProgressBar, logger
8386

8487

88+
@contextlib.contextmanager
89+
def _shard_symlink_root(
90+
files: List[str],
91+
project_dir: Path,
92+
) -> Generator[Tuple[Path, List[str]], None, None]:
93+
"""Materialise a shard's files as a temporary mini-project.
94+
95+
PyCG bounds its import-following to the ``package`` directory — only
96+
modules whose resolved file lives under that root are followed; everything
97+
else becomes a ghost node (``ImportManager``: ``if self.mod_dir not in
98+
mod.__file__: return``). A coupling-derived shard is an arbitrary set of
99+
files that need not form a directory, so we mirror the project layout into
100+
a temp dir holding symlinks to exactly the shard's files plus the
101+
``__init__.py`` chain each needs for package resolution. Running PyCG with
102+
this mirror as the package root confines analysis to the shard while
103+
emitting project-relative edge names (so ``prefix=""`` — no rename needed).
104+
105+
Yields ``(root, entry_points)`` where *entry_points* are the symlinked
106+
paths inside *root*. The temp tree is removed on exit.
107+
"""
108+
root = Path(tempfile.mkdtemp(prefix="canpy_pycg_shard_"))
109+
entry_points: List[str] = []
110+
linked_inits: Set[Path] = set()
111+
try:
112+
for f in files:
113+
src = Path(f).resolve()
114+
try:
115+
rel = src.relative_to(project_dir)
116+
except ValueError:
117+
continue # defensively skip files outside the project
118+
dst = root / rel
119+
dst.parent.mkdir(parents=True, exist_ok=True)
120+
if not dst.exists():
121+
dst.symlink_to(src)
122+
entry_points.append(str(dst))
123+
124+
# Symlink the __init__.py chain from project root down to this
125+
# file's package so PyCG/importlib can resolve the dotted module
126+
# name. These add ~0 analysis cost (usually empty) and keep
127+
# out-of-shard siblings unresolved → ghost nodes.
128+
for i in range(len(rel.parent.parts) + 1):
129+
pkg_rel = Path(*rel.parent.parts[:i])
130+
real_init = project_dir / pkg_rel / "__init__.py"
131+
link_init = root / pkg_rel / "__init__.py"
132+
if real_init.exists() and link_init not in linked_inits:
133+
link_init.parent.mkdir(parents=True, exist_ok=True)
134+
if not link_init.exists():
135+
link_init.symlink_to(real_init.resolve())
136+
linked_inits.add(link_init)
137+
yield root, entry_points
138+
finally:
139+
shutil.rmtree(root, ignore_errors=True)
140+
141+
85142
def _pycg_shard_worker(
86143
entry_points: List[str],
87144
package_dir: str,
@@ -313,6 +370,7 @@ def __init__(
313370
shard: bool = False,
314371
shard_ceiling: Optional[int] = None,
315372
shard_timeout: Optional[int] = None,
373+
shard_strategy: str = "jedi",
316374
using_ray: bool = False,
317375
) -> None:
318376
self.project_dir = Path(project_dir).resolve()
@@ -324,9 +382,31 @@ def __init__(
324382
self.shard_timeout = (
325383
shard_timeout if shard_timeout is not None else self._PYCG_SHARD_TIMEOUT
326384
)
385+
# "jedi": partition the Jedi module graph (SCC + Louvain) so coupled
386+
# modules co-compute and few edges are severed (see shard_planner).
387+
# "package": legacy one-shard-per-package-directory grouping.
388+
self.shard_strategy = shard_strategy
327389
self.using_ray = using_ray
328390
self._CallGraphGenerator: Optional[Any] = None
329391

392+
@staticmethod
393+
def _coalesce_edges(edges: List[PyCallEdge]) -> List[PyCallEdge]:
394+
"""Sum weights of duplicate ``(source, target)`` pairs across shards."""
395+
merged: Dict[tuple, PyCallEdge] = {}
396+
for edge in edges:
397+
key = (edge.source, edge.target)
398+
if key in merged:
399+
existing = merged[key]
400+
merged[key] = PyCallEdge(
401+
source=existing.source,
402+
target=existing.target,
403+
weight=existing.weight + edge.weight,
404+
provenance=existing.provenance,
405+
)
406+
else:
407+
merged[key] = edge
408+
return list(merged.values())
409+
330410
# ------------------------------------------------------------------
331411
# Entry-point collection
332412
# ------------------------------------------------------------------
@@ -455,6 +535,88 @@ def _run_pycg_batch(
455535
# Sharded analysis
456536
# ------------------------------------------------------------------
457537

538+
def _build_sharded_planned(
539+
self,
540+
jedi_edges: List[PyCallEdge],
541+
symbol_table: Dict[str, PyModule],
542+
resolver: "_PyCGCallableResolver",
543+
) -> List[PyCallEdge]:
544+
"""Coupling-aware sharding driven by the Jedi module graph.
545+
546+
Unlike :meth:`_build_sharded` (one shard per package directory), the
547+
shards here are chosen to *minimise the call edges severed between
548+
shards*: :func:`shard_planner.plan_shards` condenses the Jedi call
549+
graph by strongly-connected component (so import cycles never split)
550+
and clusters it with Louvain so tightly-coupled modules land together.
551+
Each shard — an arbitrary set of files — is run through PyCG via a
552+
symlinked mini-project (:func:`_shard_symlink_root`) that bounds PyCG
553+
to exactly those files.
554+
555+
Reported ``cut_ratio`` is the fraction of Jedi edge weight crossing
556+
shard boundaries — an upper bound on the PyCG edges lost to sharding.
557+
"""
558+
plan = plan_shards(
559+
symbol_table, jedi_edges, budget=self.shard_ceiling, merge_small=True
560+
)
561+
m = plan.metrics
562+
logger.info(
563+
"PyCG: planned %d shard(s) from Jedi module graph "
564+
"(cut_ratio=%.3f, max_shard=%d files, %d modules)",
565+
int(m["num_shards"]), m["cut_ratio"],
566+
int(m["max_shard_files"]), int(m["modules"]),
567+
)
568+
if m["oversized_shards"]:
569+
logger.warning(
570+
"PyCG: %d shard(s) exceed the %d-file ceiling — skipped "
571+
"(atomic import cycles larger than the budget)",
572+
int(m["oversized_shards"]), self.shard_ceiling,
573+
)
574+
575+
if self.using_ray:
576+
logger.info(
577+
"PyCG: Ray parallelism is not yet wired for the 'jedi' shard "
578+
"strategy — running shards sequentially."
579+
)
580+
581+
all_edges: List[PyCallEdge] = []
582+
skipped = 0
583+
with ProgressBar(len(plan.shards), "Building call graph shards", item_label="shards") as progress:
584+
for idx, files in enumerate(plan.shards):
585+
n = len(files)
586+
if n > self.shard_ceiling:
587+
skipped += 1
588+
progress.advance()
589+
continue
590+
try:
591+
with _shard_symlink_root(files, self.project_dir) as (root, eps):
592+
with _shard_timeout(self.shard_timeout):
593+
edges = self._run_pycg_batch(eps, root, resolver, prefix="")
594+
all_edges.extend(edges)
595+
logger.debug("PyCG shard %d: %d edges from %d files", idx, len(edges), n)
596+
except TimeoutError:
597+
logger.warning(
598+
"PyCG shard %d timed out after %ds — skipped",
599+
idx, self.shard_timeout,
600+
)
601+
skipped += 1
602+
except PyCGExceptions.PyCGAnalysisError as exc:
603+
logger.warning("PyCG shard %d failed — skipped: %s", idx, exc)
604+
skipped += 1
605+
progress.advance()
606+
607+
if skipped:
608+
logger.warning(
609+
"PyCG: %d/%d shard(s) skipped (ceiling, %ds timeout, or failure)",
610+
skipped, len(plan.shards), self.shard_timeout,
611+
)
612+
613+
result = self._coalesce_edges(all_edges)
614+
logger.info(
615+
"PyCG: %d edges from %d/%d shard(s) (%d before dedup, Jedi-planned)",
616+
len(result), len(plan.shards) - skipped, len(plan.shards), len(all_edges),
617+
)
618+
return result
619+
458620
def _build_sharded(
459621
self,
460622
entry_points: List[str],
@@ -668,7 +830,9 @@ def _build_sharded_ray(self, shards: Dict[Path, List[str]]) -> List[PyCallEdge]:
668830
# ------------------------------------------------------------------
669831

670832
def build_call_graph_edges(
671-
self, symbol_table: Dict[str, PyModule]
833+
self,
834+
symbol_table: Dict[str, PyModule],
835+
jedi_edges: Optional[List[PyCallEdge]] = None,
672836
) -> List[PyCallEdge]:
673837
"""Run PyCG and return ``PyCallEdge`` entries with ``provenance=["pycg"]``.
674838
@@ -701,12 +865,21 @@ def build_call_graph_edges(
701865

702866
if n_files > self._PYCG_FILE_CEILING:
703867
if self.shard:
704-
mode = "Ray-parallel" if self.using_ray else "sequential"
705-
logger.info(
706-
"PyCG: starting sharded call graph analysis (%d files, %s)",
707-
n_files, mode,
708-
)
709-
edges = self._build_sharded(entry_points, resolver)
868+
if self.shard_strategy == "jedi" and jedi_edges is not None:
869+
logger.info(
870+
"PyCG: starting Jedi-planned sharded analysis (%d files)",
871+
n_files,
872+
)
873+
edges = self._build_sharded_planned(
874+
jedi_edges, symbol_table, resolver
875+
)
876+
else:
877+
mode = "Ray-parallel" if self.using_ray else "sequential"
878+
logger.info(
879+
"PyCG: starting per-package sharded analysis (%d files, %s)",
880+
n_files, mode,
881+
)
882+
edges = self._build_sharded(entry_points, resolver)
710883
else:
711884
logger.warning(
712885
"PyCG: %d entry points exceeds ceiling of %d — "

0 commit comments

Comments
 (0)