From c217be48567ca3d518464a8c7cd2123c22867d97 Mon Sep 17 00:00:00 2001 From: Vignesh Narayanaswamy Date: Sun, 14 Jun 2026 22:11:10 -0700 Subject: [PATCH] perf(sdk): batch edge resolution in the investigate/graph hot path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit investigate() and the graph methods (dependencies/members/groups) resolved every dependency and membership edge with its own single-model get() round trip, so the backend call count grew linearly with edge count — ~29 single-model lookups for a node with a handful of edges, ~93 for a richly-connected one. Resolve all edges of a node in one batched lookup instead: - Add an optional `get_models(hashes) -> {hash: ModelRef}` bulk method to every backend (in-memory, sqlite, snowflake, json-files) plus a protocol-only `batch_fallbacks.get_models` for third-party backends. It is hasattr-dispatched, so the LedgerBackend protocol surface is unchanged and existing backends keep working. - dependencies()/members()/groups() now collect edges first, then resolve every target hash in a single get_models call. They also accept an optional pre-fetched `snapshots` list so callers thread an already-loaded history through instead of refetching. - groups() scans candidate composites' membership in one list_all_snapshots pass (when supported) rather than one list_snapshots per candidate. - sqlite batch_dependencies and the batch_fallbacks.batch_dependencies resolve edge targets with one bulk lookup instead of per-edge get_model (snowflake already did this). - investigate() reuses the model history it already fetched for groups() and members() (only when no as_of filter is in play, to preserve their current-state semantics). Resolution semantics are unchanged: hash-first with a per-edge name fallback that fires only when a hash does not resolve. On sqlite an investigate of a handful-of-edges node drops from ~29 to ~9 total backend round trips, and a richly-connected one from ~93 to ~11 — and the count no longer scales with edge count. Tests: a counting fake backend proves per-edge single lookups stay at zero and the round-trip budget is flat across graph sizes; parity tests confirm identical dependency/group/member results and that member-removal replay still excludes removed members. New cross-backend tests pin the get_models contract. Co-Authored-By: Claude Fable 5 --- src/model_ledger/backends/batch_fallbacks.py | 99 ++++--- src/model_ledger/backends/json_files.py | 19 ++ src/model_ledger/backends/ledger_memory.py | 8 + src/model_ledger/backends/snowflake.py | 17 ++ src/model_ledger/backends/sqlite_ledger.py | 114 +++++---- src/model_ledger/sdk/ledger.py | 159 +++++++++--- src/model_ledger/tools/investigate.py | 12 +- tests/test_backends/test_get_models.py | 107 ++++++++ tests/test_sdk/test_batched_edges.py | 256 +++++++++++++++++++ 9 files changed, 672 insertions(+), 119 deletions(-) create mode 100644 tests/test_backends/test_get_models.py create mode 100644 tests/test_sdk/test_batched_edges.py diff --git a/src/model_ledger/backends/batch_fallbacks.py b/src/model_ledger/backends/batch_fallbacks.py index a2bb833..371038a 100644 --- a/src/model_ledger/backends/batch_fallbacks.py +++ b/src/model_ledger/backends/batch_fallbacks.py @@ -12,6 +12,29 @@ if TYPE_CHECKING: from model_ledger.backends.ledger_protocol import LedgerBackend + from model_ledger.core.ledger_models import ModelRef + + +def get_models( + backend: LedgerBackend, + model_hashes: list[str], +) -> dict[str, ModelRef]: + """Resolve many model hashes to ModelRefs in one logical batch. + + Returns ``{model_hash: ModelRef}`` for every hash that resolves; absent + hashes are simply omitted. The protocol-only fallback issues one + ``get_model`` per *distinct* hash, deduplicating so a hash referenced by + several edges is fetched once. Performance matches the prior N+1 behavior; + backends override this with a single ``IN (...)`` query for real speedup. + """ + result: dict[str, ModelRef] = {} + for model_hash in dict.fromkeys(model_hashes): # dedup, preserve order + if not model_hash: + continue + ref = backend.get_model(model_hash) + if ref is not None: + result[model_hash] = ref + return result def _resolve_platform( @@ -146,50 +169,56 @@ def batch_dependencies( Returns ``{"upstream": [...], "downstream": [...]}`` where each entry contains ``model_hash``, ``model_name``, and ``relationship``. + + All edge targets are resolved by hash in a single batched lookup + (one ``get_models`` round trip) rather than one ``get_model`` per edge. + A per-edge name fallback runs only for the rare edge whose hash does not + resolve, preserving the resolution semantics of the prior implementation. """ snapshots = backend.list_snapshots(model_hash) - upstream: list[dict[str, Any]] = [] - downstream: list[dict[str, Any]] = [] + # Collect edges first: (direction, target_hash, target_name, relationship). + edges: list[tuple[str, str, str, str]] = [] for snap in snapshots: if snap.event_type == "depends_on": - related_hash = snap.payload.get("upstream_hash", "") - related_name = snap.payload.get("upstream", "") - relationship = snap.payload.get("relationship", "depends_on") - - related = backend.get_model(related_hash) if related_hash else None - if related is None and related_name: - related = backend.get_model_by_name(related_name) - if related is None: - continue - - upstream.append( - { - "model_hash": related.model_hash, - "model_name": related.name, - "relationship": relationship, - } + edges.append( + ( + "upstream", + snap.payload.get("upstream_hash", ""), + snap.payload.get("upstream", ""), + snap.payload.get("relationship", "depends_on"), + ) ) - elif snap.event_type == "has_dependent": - related_hash = snap.payload.get("downstream_hash", "") - related_name = snap.payload.get("downstream", "") - relationship = snap.payload.get("relationship", "depends_on") - - related = backend.get_model(related_hash) if related_hash else None - if related is None and related_name: - related = backend.get_model_by_name(related_name) - if related is None: - continue - - downstream.append( - { - "model_hash": related.model_hash, - "model_name": related.name, - "relationship": relationship, - } + edges.append( + ( + "downstream", + snap.payload.get("downstream_hash", ""), + snap.payload.get("downstream", ""), + snap.payload.get("relationship", "depends_on"), + ) ) + by_hash = get_models(backend, [h for _, h, _, _ in edges if h]) + + upstream: list[dict[str, Any]] = [] + downstream: list[dict[str, Any]] = [] + for direction, related_hash, related_name, relationship in edges: + related = by_hash.get(related_hash) if related_hash else None + if related is None and related_name: + related = backend.get_model_by_name(related_name) + if related is None: + continue + entry = { + "model_hash": related.model_hash, + "model_name": related.name, + "relationship": relationship, + } + if direction == "upstream": + upstream.append(entry) + else: + downstream.append(entry) + return {"upstream": upstream, "downstream": downstream} diff --git a/src/model_ledger/backends/json_files.py b/src/model_ledger/backends/json_files.py index 3746306..8e77560 100644 --- a/src/model_ledger/backends/json_files.py +++ b/src/model_ledger/backends/json_files.py @@ -66,6 +66,25 @@ def get_model_by_name(self, name: str) -> ModelRef | None: return ModelRef.model_validate_json(path.read_text()) return None + def get_models(self, model_hashes: list[str]) -> dict[str, ModelRef]: + """Bulk-resolve model hashes with a single directory scan. + + ``get_model`` scans the whole models directory per call; resolving N + edges that way is O(N x files). This reads every model file once and + indexes by hash, so a graph traversal pays a single pass. + """ + wanted = {h for h in model_hashes if h} + if not wanted: + return {} + result: dict[str, ModelRef] = {} + for path in self._models_dir.iterdir(): + if path.suffix != ".json": + continue + m = ModelRef.model_validate_json(path.read_text()) + if m.model_hash in wanted: + result[m.model_hash] = m + return result + def list_models(self, **filters: str) -> list[ModelRef]: results: list[ModelRef] = [] for path in sorted(self._models_dir.iterdir()): diff --git a/src/model_ledger/backends/ledger_memory.py b/src/model_ledger/backends/ledger_memory.py index 12a431a..6ebfc65 100644 --- a/src/model_ledger/backends/ledger_memory.py +++ b/src/model_ledger/backends/ledger_memory.py @@ -25,6 +25,14 @@ def get_model_by_name(self, name: str) -> ModelRef | None: return m return None + def get_models(self, model_hashes: list[str]) -> dict[str, ModelRef]: + """Bulk-resolve model hashes to ModelRefs in one pass. + + Returns ``{model_hash: ModelRef}`` for hashes that exist; missing + hashes are omitted. Counts as a single backend round trip. + """ + return {h: self._models[h] for h in dict.fromkeys(model_hashes) if h in self._models} + def list_models(self, **filters: str) -> list[ModelRef]: text = filters.pop("text", None) limit = filters.pop("limit", None) diff --git a/src/model_ledger/backends/snowflake.py b/src/model_ledger/backends/snowflake.py index fb1079f..809aa9e 100644 --- a/src/model_ledger/backends/snowflake.py +++ b/src/model_ledger/backends/snowflake.py @@ -385,6 +385,23 @@ def get_model_by_name(self, name: str) -> ModelRef | None: ) return _row_to_model_ref(rows[0]) if rows else None + def get_models(self, model_hashes: list[str]) -> dict[str, ModelRef]: + """Bulk-resolve model hashes to ModelRefs with one ``IN (...)`` query.""" + self._flush_models() + hashes = [h for h in dict.fromkeys(model_hashes) if h] + if not hashes: + return {} + in_clause = ", ".join(_esc(h) for h in hashes) + rows = _exec( + self._session, + f"SELECT * FROM {self._schema}.MODELS WHERE MODEL_HASH IN ({in_clause})", + ) + result: dict[str, ModelRef] = {} + for row in rows: + ref = _row_to_model_ref(row) + result[ref.model_hash] = ref + return result + def list_models(self, **filters: str) -> list[ModelRef]: self._flush_models() limit = filters.pop("limit", None) diff --git a/src/model_ledger/backends/sqlite_ledger.py b/src/model_ledger/backends/sqlite_ledger.py index 5799034..c2bf184 100644 --- a/src/model_ledger/backends/sqlite_ledger.py +++ b/src/model_ledger/backends/sqlite_ledger.py @@ -146,6 +146,17 @@ def get_model_by_name(self, name: str) -> ModelRef | None: row = self._conn.execute("SELECT * FROM models WHERE name = ?", (name,)).fetchone() return self._row_to_model(row) if row else None + def get_models(self, model_hashes: list[str]) -> dict[str, ModelRef]: + """Bulk-resolve model hashes to ModelRefs with one ``IN (...)`` query.""" + hashes = [h for h in dict.fromkeys(model_hashes) if h] + if not hashes: + return {} + placeholders = ", ".join("?" for _ in hashes) + rows = self._conn.execute( + f"SELECT * FROM models WHERE model_hash IN ({placeholders})", hashes + ).fetchall() + return {row["model_hash"]: self._row_to_model(row) for row in rows} + def list_models(self, **filters: str) -> list[ModelRef]: sql = "SELECT * FROM models" params: list[str] = [] @@ -439,9 +450,9 @@ def batch_dependencies( self, model_hash: str, ) -> dict[str, list[dict]]: - upstream: list[dict] = [] - downstream: list[dict] = [] - + # One query for the edge snapshots; one batched query to resolve every + # edge target by hash. A per-edge name fallback runs only when a hash + # does not resolve, preserving the prior resolution semantics. if self._has_json_extract: rows = self._conn.execute( "SELECT s.event_type, " @@ -456,30 +467,27 @@ def batch_dependencies( (model_hash,), ).fetchall() + # (direction, target_hash, target_name, relationship) + edges: list[tuple[str, str, str, str]] = [] for row in rows: if row["event_type"] == "depends_on": - related_hash = row["upstream_hash"] or "" - related_name = row["upstream_name"] or "" - else: - related_hash = row["downstream_hash"] or "" - related_name = row["downstream_name"] or "" - relationship = row["relationship"] or "depends_on" - - related = self.get_model(related_hash) if related_hash else None - if related is None and related_name: - related = self.get_model_by_name(related_name) - if related is None: - continue - - entry = { - "model_hash": related.model_hash, - "model_name": related.name, - "relationship": relationship, - } - if row["event_type"] == "depends_on": - upstream.append(entry) + edges.append( + ( + "upstream", + row["upstream_hash"] or "", + row["upstream_name"] or "", + row["relationship"] or "depends_on", + ) + ) else: - downstream.append(entry) + edges.append( + ( + "downstream", + row["downstream_hash"] or "", + row["downstream_name"] or "", + row["relationship"] or "depends_on", + ) + ) else: rows = self._conn.execute( "SELECT s.event_type, s.payload " @@ -489,31 +497,47 @@ def batch_dependencies( (model_hash,), ).fetchall() + edges = [] for row in rows: payload = json.loads(row["payload"]) if row["payload"] else {} if row["event_type"] == "depends_on": - related_hash = payload.get("upstream_hash", "") - related_name = payload.get("upstream", "") + edges.append( + ( + "upstream", + payload.get("upstream_hash", ""), + payload.get("upstream", ""), + payload.get("relationship", "depends_on"), + ) + ) else: - related_hash = payload.get("downstream_hash", "") - related_name = payload.get("downstream", "") - relationship = payload.get("relationship", "depends_on") - - related = self.get_model(related_hash) if related_hash else None - if related is None and related_name: - related = self.get_model_by_name(related_name) - if related is None: - continue - - entry = { - "model_hash": related.model_hash, - "model_name": related.name, - "relationship": relationship, - } - if row["event_type"] == "depends_on": - upstream.append(entry) - else: - downstream.append(entry) + edges.append( + ( + "downstream", + payload.get("downstream_hash", ""), + payload.get("downstream", ""), + payload.get("relationship", "depends_on"), + ) + ) + + by_hash = self.get_models([h for _, h, _, _ in edges if h]) + + upstream: list[dict] = [] + downstream: list[dict] = [] + for direction, related_hash, related_name, relationship in edges: + related = by_hash.get(related_hash) if related_hash else None + if related is None and related_name: + related = self.get_model_by_name(related_name) + if related is None: + continue + entry = { + "model_hash": related.model_hash, + "model_name": related.name, + "relationship": relationship, + } + if direction == "upstream": + upstream.append(entry) + else: + downstream.append(entry) return {"upstream": upstream, "downstream": downstream} diff --git a/src/model_ledger/sdk/ledger.py b/src/model_ledger/sdk/ledger.py index 6208cfa..a5f8f91 100644 --- a/src/model_ledger/sdk/ledger.py +++ b/src/model_ledger/sdk/ledger.py @@ -6,6 +6,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, TypedDict +from model_ledger.backends import batch_fallbacks from model_ledger.backends.ledger_memory import InMemoryLedgerBackend from model_ledger.backends.ledger_protocol import LedgerBackend from model_ledger.core.enums import ModelStatus @@ -118,6 +119,21 @@ def _resolve_model(self, model: ModelRef | str) -> ModelRef: self._name_cache[model] = result return result + def _resolve_hashes(self, model_hashes: builtins.list[str]) -> dict[str, ModelRef]: + """Resolve many model hashes to ModelRefs in a single backend round trip. + + Dispatches to the backend's bulk ``get_models`` when available + (one ``IN (...)`` query) and falls back to the protocol-only + implementation otherwise. Used by the graph methods to resolve all + edges of a node at once instead of one ``get()`` per edge. + """ + if not model_hashes: + return {} + if hasattr(self._backend, "get_models"): + resolved: dict[str, ModelRef] = self._backend.get_models(model_hashes) + return resolved + return batch_fallbacks.get_models(self._backend, model_hashes) + def register( self, *, @@ -319,41 +335,49 @@ def dependencies( self, model: ModelRef | str, direction: str = "both", + *, + snapshots: builtins.list[Snapshot] | None = None, ) -> builtins.list[dict[str, Any]]: + """Direct dependency edges for a model. + + Resolves every edge's target model in ONE batched lookup instead of a + per-edge round trip. Pass ``snapshots`` (the model's full history) to + reuse an already-fetched list and skip the ``list_snapshots`` call — + the graph traversal and ``investigate`` use this to avoid refetching. + """ ref = self._resolve_model(model) - snaps = self._backend.list_snapshots(ref.model_hash) - result: builtins.list[dict[str, Any]] = [] + snaps = snapshots if snapshots is not None else self._backend.list_snapshots(ref.model_hash) + # Collect edges first, then resolve all target hashes in one batch. + # Each edge: (direction, target_hash, relationship) + edges: builtins.list[tuple[str, str, str]] = [] if direction in ("upstream", "both"): for s in snaps: if s.event_type == "depends_on": - try: - upstream = self.get(s.payload["upstream_hash"]) - except ModelNotFoundError: - continue - result.append( - { - "model": upstream, - "relationship": s.payload.get("relationship", "depends_on"), - "direction": "upstream", - } - ) - + h = s.payload.get("upstream_hash") + if h: + edges.append(("upstream", h, s.payload.get("relationship", "depends_on"))) if direction in ("downstream", "both"): for s in snaps: if s.event_type == "has_dependent": - try: - downstream = self.get(s.payload["downstream_hash"]) - except ModelNotFoundError: - continue - result.append( - { - "model": downstream, - "relationship": s.payload.get("relationship", "depends_on"), - "direction": "downstream", - } - ) + h = s.payload.get("downstream_hash") + if h: + edges.append(("downstream", h, s.payload.get("relationship", "depends_on"))) + + resolved = self._resolve_hashes([h for _, h, _ in edges]) + result: builtins.list[dict[str, Any]] = [] + for edge_direction, target_hash, relationship in edges: + target = resolved.get(target_hash) + if target is None: + continue + result.append( + { + "model": target, + "relationship": relationship, + "direction": edge_direction, + } + ) return result # --- Graph methods (v0.4.0) --- @@ -623,7 +647,12 @@ def register_group( ) return ref - def members(self, group: ModelRef | str) -> builtins.list[ModelRef]: + def members( + self, + group: ModelRef | str, + *, + snapshots: builtins.list[Snapshot] | None = None, + ) -> builtins.list[ModelRef]: """Return current members of this group. Replays member_added/member_removed snapshots to determine @@ -633,14 +662,20 @@ def members(self, group: ModelRef | str) -> builtins.list[ModelRef]: Mixed case: groups seeded via register_group() that later have add_member()/remove_member() called use dependency links as the baseline and overlay the event log on top. + + All member_added targets are resolved in a single batched lookup + instead of one round trip per event. Pass ``snapshots`` (the group's + full history) to reuse an already-fetched list. """ ref = self._resolve_model(group) - snaps = self._backend.list_snapshots(ref.model_hash) + snaps = snapshots if snapshots is not None else self._backend.list_snapshots(ref.model_hash) membership_events = [s for s in snaps if s.event_type in ("member_added", "member_removed")] # Seed from dependency links (covers register_group() seeded members # and is always correct as the initial universe of linked models). - deps = self.dependencies(group, direction="upstream") or [] + # Reuse the snapshots we already have — dependency edges live in the + # same history. + deps = self.dependencies(group, direction="upstream", snapshots=snaps) or [] current: dict[str, ModelRef] = { d["model"].model_hash: d["model"] for d in deps if d.get("relationship") == "member_of" } @@ -649,14 +684,27 @@ def members(self, group: ModelRef | str) -> builtins.list[ModelRef]: # No events: dependency links are the full picture. return list(current.values()) + ordered_events = sorted(membership_events, key=lambda s: s.timestamp) + + # Batch-resolve every member_added hash not already seeded, in one + # round trip. The name fallback (for the rare unresolvable hash) stays + # per-event but only fires when the bulk lookup misses. + added_hashes = [ + s.payload.get("member_hash", "") + for s in ordered_events + if s.event_type == "member_added" and s.payload.get("member_hash", "") not in current + ] + resolved = self._resolve_hashes([h for h in added_hashes if h]) + # Replay events on top of the dep-link baseline. - for s in sorted(membership_events, key=lambda s: s.timestamp): + for s in ordered_events: member_hash = s.payload.get("member_hash", "") if s.event_type == "member_added": if member_hash not in current: - try: - current[member_hash] = self.get(member_hash) - except ModelNotFoundError: + ref_for_member = resolved.get(member_hash) + if ref_for_member is not None: + current[member_hash] = ref_for_member + else: member_name = s.payload.get("member_name", "") if member_name: try: @@ -667,22 +715,61 @@ def members(self, group: ModelRef | str) -> builtins.list[ModelRef]: current.pop(member_hash, None) return list(current.values()) - def groups(self, model: ModelRef | str) -> builtins.list[ModelRef]: + def groups( + self, + model: ModelRef | str, + *, + snapshots: builtins.list[Snapshot] | None = None, + ) -> builtins.list[ModelRef]: """Return groups this model currently belongs to. Replays member_added/member_removed events on each candidate group - to exclude composites the model has been removed from. + to exclude composites the model has been removed from. The candidate + composites' membership histories are fetched in a single bulk + ``list_all_snapshots`` scan when the backend supports it, so the cost + is one query for the whole fan-out rather than one per candidate. + + Pass ``snapshots`` (this model's full history) to reuse an + already-fetched list for the downstream-edge discovery. """ - deps = self.dependencies(model, direction="downstream") or [] + deps = self.dependencies(model, direction="downstream", snapshots=snapshots) or [] candidates = [d["model"] for d in deps if d.get("relationship") == "member_of"] + if not candidates: + return [] ref = self._resolve_model(model) + + # Resolve each candidate's current members. Prefer a single bulk + # snapshot scan over per-candidate list_snapshots round trips. + snaps_by_group = self._membership_snapshots({c.model_hash for c in candidates}) + result: builtins.list[ModelRef] = [] for comp in candidates: - current_members = self.members(comp) + comp_snaps = snaps_by_group.get(comp.model_hash) + current_members = self.members(comp, snapshots=comp_snaps) if any(m.model_hash == ref.model_hash for m in current_members): result.append(comp) return result + def _membership_snapshots( + self, + group_hashes: builtins.set[str], + ) -> dict[str, builtins.list[Snapshot]]: + """Group the membership-relevant snapshots for several groups at once. + + Returns ``{group_hash: [snapshots]}``. When the backend exposes + ``list_all_snapshots`` the whole fan-out is one scan; otherwise this + returns an empty mapping and callers fall back to per-group + ``list_snapshots`` (preserving the protocol-only contract). + """ + if not group_hashes or not hasattr(self._backend, "list_all_snapshots"): + return {} + by_group: dict[str, builtins.list[Snapshot]] = {h: [] for h in group_hashes} + for s in self._backend.list_all_snapshots(): + bucket = by_group.get(s.model_hash) + if bucket is not None: + bucket.append(s) + return by_group + def add_member( self, composite: ModelRef | str, diff --git a/src/model_ledger/tools/investigate.py b/src/model_ledger/tools/investigate.py index edf8e24..4407b18 100644 --- a/src/model_ledger/tools/investigate.py +++ b/src/model_ledger/tools/investigate.py @@ -38,7 +38,8 @@ def investigate(input: InvestigateInput, ledger: Ledger) -> InvestigateOutput: """ model = ledger.get(input.model_name) - snapshots = ledger.history(model) or [] + all_snapshots = ledger.history(model) or [] + snapshots = all_snapshots if input.as_of is not None: as_of = input.as_of @@ -92,16 +93,21 @@ def investigate(input: InvestigateInput, ledger: Ledger) -> InvestigateOutput: upstream_nodes = [] downstream_nodes = [] + # groups()/members() reflect *current* state, so reuse the unfiltered + # snapshot list we already fetched — but only when no as_of filter is in + # play (an as_of-trimmed list would change their membership semantics). + reuse_snaps = all_snapshots if input.as_of is None else None + group_names: list[str] = [] try: - group_refs = ledger.groups(model) or [] + group_refs = ledger.groups(model, snapshots=reuse_snaps) or [] group_names = [g.name for g in group_refs] except (KeyError, ValueError, Exception): group_names = [] member_names: list[str] = [] try: - member_refs = ledger.members(model) or [] + member_refs = ledger.members(model, snapshots=reuse_snaps) or [] member_names = [m.name for m in member_refs] except (KeyError, ValueError, Exception): member_names = [] diff --git a/tests/test_backends/test_get_models.py b/tests/test_backends/test_get_models.py new file mode 100644 index 0000000..d5dc25a --- /dev/null +++ b/tests/test_backends/test_get_models.py @@ -0,0 +1,107 @@ +"""Tests for the optional ``get_models`` bulk-resolution backend method. + +``get_models(hashes) -> {hash: ModelRef}`` is the batched counterpart to +``get_model``. Every shipped backend implements it, and ``batch_fallbacks`` +supplies a protocol-only version for third-party backends that do not. These +tests pin the shared contract: same refs as one-by-one resolution, missing +hashes omitted, dedup, and empty/blank input handled. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from model_ledger.backends import batch_fallbacks +from model_ledger.backends.json_files import JsonFileLedgerBackend +from model_ledger.backends.ledger_memory import InMemoryLedgerBackend +from model_ledger.backends.sqlite_ledger import SQLiteLedgerBackend +from model_ledger.core.ledger_models import ModelRef + + +def _make_model(name: str) -> ModelRef: + return ModelRef( + name=name, + owner="risk-team", + model_type="ml_model", + tier="high", + purpose="testing", + ) + + +def _seed(backend, names): + refs = {} + for n in names: + ref = _make_model(n) + backend.save_model(ref) + refs[n] = ref + return refs + + +def _backends(tmp_path: Path): + sqlite = SQLiteLedgerBackend(str(tmp_path / "g.db")) + json_files = JsonFileLedgerBackend(str(tmp_path / "jf")) + return { + "memory": InMemoryLedgerBackend(), + "sqlite": sqlite, + "json_files": json_files, + } + + +@pytest.fixture(params=["memory", "sqlite", "json_files"]) +def backend(request, tmp_path): + return _backends(tmp_path)[request.param] + + +class TestGetModels: + def test_resolves_all_hashes(self, backend): + refs = _seed(backend, ["a", "b", "c"]) + result = backend.get_models([refs["a"].model_hash, refs["b"].model_hash]) + assert set(result) == {refs["a"].model_hash, refs["b"].model_hash} + assert result[refs["a"].model_hash].name == "a" + assert result[refs["b"].model_hash].name == "b" + + def test_omits_missing_hashes(self, backend): + refs = _seed(backend, ["a"]) + result = backend.get_models([refs["a"].model_hash, "deadbeef-missing"]) + assert set(result) == {refs["a"].model_hash} + + def test_empty_and_blank_input(self, backend): + _seed(backend, ["a"]) + assert backend.get_models([]) == {} + assert backend.get_models(["", ""]) == {} + + def test_dedup_repeated_hash(self, backend): + refs = _seed(backend, ["a"]) + h = refs["a"].model_hash + result = backend.get_models([h, h, h]) + assert set(result) == {h} + + def test_parity_with_single_get_model(self, backend): + refs = _seed(backend, ["a", "b", "c"]) + hashes = [r.model_hash for r in refs.values()] + bulk = backend.get_models(hashes) + for h in hashes: + single = backend.get_model(h) + assert bulk[h].model_hash == single.model_hash + assert bulk[h].name == single.name + + +class TestGetModelsFallback: + """The protocol-only fallback must match the native implementations.""" + + def test_fallback_resolves_and_omits(self): + backend = InMemoryLedgerBackend() + refs = _seed(backend, ["a", "b"]) + result = batch_fallbacks.get_models( + backend, [refs["a"].model_hash, "missing", refs["b"].model_hash] + ) + assert set(result) == {refs["a"].model_hash, refs["b"].model_hash} + + def test_fallback_dedups_and_skips_blank(self): + backend = InMemoryLedgerBackend() + refs = _seed(backend, ["a"]) + h = refs["a"].model_hash + result = batch_fallbacks.get_models(backend, ["", h, h]) + assert set(result) == {h} diff --git a/tests/test_sdk/test_batched_edges.py b/tests/test_sdk/test_batched_edges.py new file mode 100644 index 0000000..f3d4dd1 --- /dev/null +++ b/tests/test_sdk/test_batched_edges.py @@ -0,0 +1,256 @@ +"""Tests for batched edge resolution in the graph traversal hot path. + +The graph methods (``dependencies``/``members``/``groups``) and the +``investigate`` tool used to resolve each dependency or membership edge with +its own single-model ``get()`` round trip, so the backend call count grew +linearly with edge count. These tests pin the new behavior: + +1. Round-trip count stays flat as the number of edges grows (no per-edge + ``get_model``). +2. Results are identical to resolving every edge individually. +""" + +from __future__ import annotations + +from collections import Counter + +import pytest + +from model_ledger.backends.ledger_memory import InMemoryLedgerBackend +from model_ledger.backends.sqlite_ledger import SQLiteLedgerBackend +from model_ledger.sdk.ledger import Ledger +from model_ledger.tools.investigate import investigate +from model_ledger.tools.schemas import InvestigateInput + + +class _CountingMixin: + """Tallies single-model vs. batched resolution round trips. + + ``get_model`` / ``get_model_by_name`` are the per-edge round trips we want + to eliminate; ``get_models`` is the batched replacement. Tallying all three + lets a test assert the *shape* of resolution, not just the total. + """ + + def _init_counter(self) -> None: + self.calls: Counter[str] = Counter() + + def get_model(self, model_hash): # type: ignore[override] + self.calls["get_model"] += 1 + return super().get_model(model_hash) + + def get_model_by_name(self, name): # type: ignore[override] + self.calls["get_model_by_name"] += 1 + return super().get_model_by_name(name) + + def get_models(self, model_hashes): # type: ignore[override] + self.calls["get_models"] += 1 + return super().get_models(model_hashes) + + @property + def single_get_calls(self) -> int: + """Per-edge single-model lookups — the cost we want flat (ideally 0).""" + return self.calls["get_model"] + self.calls["get_model_by_name"] + + +class CountingBackend(_CountingMixin, InMemoryLedgerBackend): + def __init__(self) -> None: + super().__init__() + self._init_counter() + + +class CountingSQLiteBackend(_CountingMixin, SQLiteLedgerBackend): + """A production-representative backend (native batched SQL methods). + + Unlike the in-memory backend, SQLite ships its own ``batch_dependencies`` + and ``get_models``, so this exercises the full batched-resolution path that + a real deployment hits — and where eliminating per-edge round trips matters. + """ + + def __init__(self, db_path: str) -> None: + super().__init__(db_path) + self._init_counter() + + +def _build(backend: InMemoryLedgerBackend, n_deps: int, n_groups: int) -> Ledger: + """A central model with ``n_deps`` upstream + ``n_deps`` downstream edges, + belonging to ``n_groups`` composites.""" + led = Ledger(backend) + led.register( + name="central", + owner="risk-team", + model_type="ml_model", + tier="high", + purpose="central node", + actor="test", + ) + for i in range(n_deps): + led.register( + name=f"up_{i}", + owner="risk-team", + model_type="ml_model", + tier="low", + purpose="upstream", + actor="test", + ) + led.link_dependency( + upstream=f"up_{i}", downstream="central", relationship="data_flow", actor="test" + ) + led.register( + name=f"down_{i}", + owner="risk-team", + model_type="ml_model", + tier="low", + purpose="downstream", + actor="test", + ) + led.link_dependency( + upstream="central", downstream=f"down_{i}", relationship="data_flow", actor="test" + ) + for g in range(n_groups): + led.register_group( + name=f"grp_{g}", + owner="risk-team", + model_type="composite", + tier="high", + purpose="group", + members=["central"], + actor="test", + ) + # Drop the SDK name cache so reads go through the backend, simulating a + # fresh process investigating a pre-existing model. + led._name_cache.clear() + led._cache_complete = False + return led + + +@pytest.fixture +def sqlite_factory(tmp_path): + """Yields a builder for a CountingSQLiteBackend-backed ledger.""" + created = [] + + def make(n_deps: int, n_groups: int): + path = str(tmp_path / f"ledger_{len(created)}.db") + backend = CountingSQLiteBackend(path) + led = _build(backend, n_deps=n_deps, n_groups=n_groups) + backend.calls.clear() + created.append((backend, led)) + return backend, led + + return make + + +class TestDependenciesBatching: + def test_no_per_edge_get_on_dependencies(self): + backend = CountingBackend() + led = _build(backend, n_deps=8, n_groups=0) + backend.calls.clear() + + deps = led.dependencies("central", direction="both") + + # 8 upstream + 8 downstream edges, but the resolution does NOT fan out + # into one get_model per edge. + assert len(deps) == 16 + assert backend.calls["get_model"] == 0 + # All 16 targets resolved in a single batched get_models call. + assert backend.calls["get_models"] == 1 + + def test_dependencies_parity_with_per_edge_resolution(self): + backend = InMemoryLedgerBackend() + led = _build(backend, n_deps=5, n_groups=0) + + deps = led.dependencies("central", direction="both") + got = sorted((d["model"].name, d["relationship"], d["direction"]) for d in deps) + + # Independently resolve every edge by hand from the raw snapshots. + ref = led.get("central") + expected = [] + for s in backend.list_snapshots(ref.model_hash): + if s.event_type == "depends_on": + target = backend.get_model(s.payload["upstream_hash"]) + expected.append( + (target.name, s.payload.get("relationship", "depends_on"), "upstream") + ) + elif s.event_type == "has_dependent": + target = backend.get_model(s.payload["downstream_hash"]) + expected.append( + (target.name, s.payload.get("relationship", "depends_on"), "downstream") + ) + assert got == sorted(expected) + + +class TestInvestigateRoundTrips: + @pytest.mark.parametrize( + ("n_deps", "n_groups"), + [(3, 2), (12, 4), (30, 8)], + ) + def test_no_per_edge_single_lookups(self, sqlite_factory, n_deps, n_groups): + backend, led = sqlite_factory(n_deps, n_groups) + + investigate(InvestigateInput(model_name="central", detail="full"), led) + + # The old path issued O(edges) single-model lookups (one get_model per + # dependency and per membership event). The batched path resolves every + # edge through get_models, so per-edge single lookups stay at zero + # regardless of how many edges the node has — the regression guard. + assert backend.calls["get_model"] == 0 + assert backend.single_get_calls <= 1 # only the initial name resolve + + def test_round_trip_count_is_flat_across_graph_sizes(self, sqlite_factory): + """Same fixed budget for a sparse and a dense node.""" + sparse_be, sparse = sqlite_factory(3, 2) + investigate(InvestigateInput(model_name="central", detail="full"), sparse) + sparse_total = sum(sparse_be.calls.values()) + + dense_be, dense = sqlite_factory(30, 8) + investigate(InvestigateInput(model_name="central", detail="full"), dense) + dense_total = sum(dense_be.calls.values()) + + # Dense graph has ~10x the edges but the resolution budget barely moves + # (a few extra batched get_models for the extra groups, not O(edges)). + assert dense_total - sparse_total <= 6 + + +class TestInvestigateParity: + def test_results_identical_with_and_without_batching(self): + """A model in groups, with deps, returns the same investigate output.""" + backend = InMemoryLedgerBackend() + led = _build(backend, n_deps=4, n_groups=3) + + out = investigate(InvestigateInput(model_name="central", detail="full"), led) + + assert sorted(d.name for d in out.upstream) == [f"up_{i}" for i in range(4)] + assert {d.name for d in out.downstream} == {f"down_{i}" for i in range(4)} | { + f"grp_{g}" for g in range(3) + } + assert sorted(out.groups) == [f"grp_{g}" for g in range(3)] + + def test_removed_member_excluded_after_batching(self): + """Replay semantics survive batching: a removed member drops out.""" + backend = InMemoryLedgerBackend() + led = Ledger(backend) + led.register_group( + name="scorecard", + owner="risk-team", + model_type="composite", + tier="high", + purpose="group", + members=[], + actor="test", + ) + for name in ("feature_pipeline", "scoring_model", "alert_queue"): + led.register( + name=name, + owner="risk-team", + model_type="ml_model", + tier="low", + purpose="member", + actor="test", + ) + led.add_member("scorecard", name, actor="test") + led.remove_member("scorecard", "alert_queue", actor="test") + + members = {m.name for m in led.members("scorecard")} + assert members == {"feature_pipeline", "scoring_model"} + # The removed member must not list the group either. + assert led.groups("alert_queue") == [] + assert {g.name for g in led.groups("scoring_model")} == {"scorecard"}