Skip to content

Commit 10afac1

Browse files
committed
Add Ray-parallel PyCG shard processing
When --ray and --pycg-shard are both active, PyCG shards are submitted as Ray remote tasks simultaneously instead of running sequentially. Per-shard timeout is enforced via ray.wait(timeout=N) + ray.cancel at the orchestrator level. Key changes: - _pycg_shard_worker: picklable module-level function that runs PyCG in a Ray worker and returns (src, dst, weight) tuples - PyCG._build_sharded_ray: submits all eligible shards as ray.remote tasks, collects results with ray.wait(num_returns=N, timeout=T), cancels and logs stragglers, then runs the same dedup/merge as the sequential path - PyCG.__init__: new using_ray parameter (default False) - core._get_pycg_call_graph: passes using_ray=self.using_ray to PyCG Signed-off-by: Saurabh Sinha <sinha108@gmail.com>
1 parent 234b7ee commit 10afac1

2 files changed

Lines changed: 152 additions & 0 deletions

File tree

codeanalyzer/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,7 @@ def _get_pycg_call_graph(
636636
shard=self.options.pycg_shard,
637637
shard_ceiling=self.options.pycg_shard_ceiling,
638638
shard_timeout=self.options.pycg_shard_timeout,
639+
using_ray=self.using_ray,
639640
)
640641
return pycg.build_call_graph_edges(symbol_table)
641642
except PyCGExceptions.PyCGImportError as exc:

codeanalyzer/semantic_analysis/pycg/pycg_analysis.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,62 @@ def _handler(signum: int, frame: object) -> None:
8181
from codeanalyzer.utils import logger
8282

8383

84+
def _pycg_shard_worker(
85+
entry_points: List[str],
86+
package_dir: str,
87+
prefix: str,
88+
) -> List[tuple]:
89+
"""Run PyCG on one shard; called in a Ray worker process.
90+
91+
Returns a list of ``(source, target, weight)`` tuples that the caller
92+
converts to :class:`PyCallEdge` objects. This function is a plain
93+
module-level callable so it can be pickled by Ray without capturing any
94+
class-level state.
95+
"""
96+
import importlib
97+
import sys
98+
99+
# Python 3.13 compatibility pre-imports (mirroring the top-level block).
100+
import importlib.metadata # noqa: F401
101+
import importlib.util # noqa: F401
102+
import json # noqa: F401
103+
from collections import Counter as _WorkerCounter
104+
105+
CallGraphGenerator = None
106+
for pkg_name in ("pycg", "PyCG"):
107+
try:
108+
mod = importlib.import_module(pkg_name)
109+
sys.modules.setdefault("pycg", mod)
110+
sys.modules.setdefault("PyCG", mod)
111+
pycg_mod = importlib.import_module(f"{pkg_name}.pycg")
112+
CallGraphGenerator = pycg_mod.CallGraphGenerator
113+
break
114+
except ImportError:
115+
continue
116+
117+
if CallGraphGenerator is None:
118+
raise RuntimeError("pycg is not installed in Ray worker — run `pip install pycg`")
119+
120+
_apply_pycg_posonly_patch()
121+
122+
cg = CallGraphGenerator(
123+
entry_points=entry_points,
124+
package=package_dir,
125+
max_iter=-1,
126+
operation="call-graph",
127+
)
128+
cg.analyze()
129+
130+
edge_counts = _WorkerCounter()
131+
for src, dst in cg.output_edges():
132+
if prefix:
133+
src = f"{prefix}.{src}"
134+
dst = f"{prefix}.{dst}"
135+
edge_counts[(src, dst)] += 1
136+
137+
return [(src, dst, count) for (src, dst), count in edge_counts.items()]
138+
139+
84140
def _apply_pycg_posonly_patch() -> None:
85141
"""Monkey-patch PyCG's PreProcessor to handle Python 3.8+ positional-only params.
86142
@@ -256,6 +312,7 @@ def __init__(
256312
shard: bool = False,
257313
shard_ceiling: Optional[int] = None,
258314
shard_timeout: Optional[int] = None,
315+
using_ray: bool = False,
259316
) -> None:
260317
self.project_dir = Path(project_dir).resolve()
261318
self.skip_tests = skip_tests
@@ -266,6 +323,7 @@ def __init__(
266323
self.shard_timeout = (
267324
shard_timeout if shard_timeout is not None else self._PYCG_SHARD_TIMEOUT
268325
)
326+
self.using_ray = using_ray
269327
self._CallGraphGenerator: Optional[Any] = None
270328

271329
# ------------------------------------------------------------------
@@ -423,6 +481,9 @@ def _build_sharded(
423481
len(entry_points), len(shards),
424482
)
425483

484+
if self.using_ray:
485+
return self._build_sharded_ray(shards)
486+
426487
all_edges: List[PyCallEdge] = []
427488
skipped = 0
428489
for pkg_root, files in shards.items():
@@ -483,6 +544,96 @@ def _build_sharded(
483544
)
484545
return result
485546

547+
def _build_sharded_ray(self, shards: Dict[Path, List[str]]) -> List[PyCallEdge]:
548+
"""Ray-parallel variant of the sequential shard loop.
549+
550+
All eligible shards are submitted as Ray remote tasks simultaneously.
551+
``ray.wait(timeout=shard_timeout)`` is used to collect results and
552+
cancel stragglers — Ray workers cannot use SIGALRM, so the timeout is
553+
enforced at the orchestrator level instead.
554+
"""
555+
import ray
556+
557+
remote_fn = ray.remote(_pycg_shard_worker)
558+
futures: List[Any] = []
559+
meta: Dict[Any, tuple] = {} # ObjectRef -> (pkg_label, n_files)
560+
skipped = 0
561+
562+
for pkg_root, files in shards.items():
563+
n = len(files)
564+
pkg_label = str(pkg_root.relative_to(self.project_dir)) or "."
565+
if n > self.shard_ceiling:
566+
logger.warning(
567+
"PyCG shard '%s': %d files exceeds shard ceiling of %d — skipped",
568+
pkg_label, n, self.shard_ceiling,
569+
)
570+
skipped += 1
571+
continue
572+
prefix = self._package_prefix(pkg_root, self.project_dir)
573+
fut = remote_fn.remote(files, str(pkg_root), prefix)
574+
futures.append(fut)
575+
meta[fut] = (pkg_label, n)
576+
577+
all_edges: List[PyCallEdge] = []
578+
if futures:
579+
timeout = float(self.shard_timeout) if self.shard_timeout > 0 else None
580+
ready, timed_out = ray.wait(futures, num_returns=len(futures), timeout=timeout)
581+
582+
for fut in ready:
583+
pkg_label, n = meta[fut]
584+
try:
585+
triples = ray.get(fut)
586+
edges = [
587+
PyCallEdge(source=s, target=t, weight=w, provenance=["pycg"])
588+
for s, t, w in triples
589+
]
590+
all_edges.extend(edges)
591+
logger.debug(
592+
"PyCG shard '%s': %d edges from %d files (Ray)",
593+
pkg_label, len(edges), n,
594+
)
595+
except Exception as exc:
596+
logger.warning("PyCG shard '%s' failed — skipped: %s", pkg_label, exc)
597+
skipped += 1
598+
599+
for fut in timed_out:
600+
pkg_label, _ = meta[fut]
601+
logger.warning(
602+
"PyCG shard '%s' timed out after %ds — skipped",
603+
pkg_label, self.shard_timeout,
604+
)
605+
ray.cancel(fut, force=True)
606+
skipped += 1
607+
608+
if skipped:
609+
logger.warning(
610+
"PyCG: %d shard(s) were skipped (exceeded %d-file ceiling, "
611+
"%ds timeout, or failed)",
612+
skipped, self.shard_ceiling, self.shard_timeout,
613+
)
614+
615+
merged: Dict[tuple, PyCallEdge] = {}
616+
for edge in all_edges:
617+
key = (edge.source, edge.target)
618+
if key in merged:
619+
existing = merged[key]
620+
merged[key] = PyCallEdge(
621+
source=existing.source,
622+
target=existing.target,
623+
weight=existing.weight + edge.weight,
624+
provenance=existing.provenance,
625+
)
626+
else:
627+
merged[key] = edge
628+
629+
result = list(merged.values())
630+
logger.debug(
631+
"PyCG: Ray-parallel sharding produced %d edges (%d before dedup) "
632+
"from %d/%d shard(s)",
633+
len(result), len(all_edges), len(shards) - skipped, len(shards),
634+
)
635+
return result
636+
486637
# ------------------------------------------------------------------
487638
# Public API
488639
# ------------------------------------------------------------------

0 commit comments

Comments
 (0)