Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 64 additions & 35 deletions src/model_ledger/backends/batch_fallbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,29 @@

if TYPE_CHECKING:
from model_ledger.backends.ledger_protocol import LedgerBackend
from model_ledger.core.ledger_models import ModelRef


def get_models(
    backend: LedgerBackend,
    model_hashes: list[str],
) -> dict[str, ModelRef]:
    """Resolve many model hashes to ModelRefs in one logical batch.

    Returns ``{model_hash: ModelRef}`` for every hash that resolves; absent
    hashes are simply omitted. Empty hashes are ignored and duplicates are
    collapsed (first-seen order) before any lookup, so a hash referenced by
    several edges is resolved once.

    If ``backend`` exposes a callable ``get_models`` attribute, the distinct
    hashes are handed to it in a single call, so every caller of this helper
    benefits from the backend's bulk resolver. Otherwise the protocol-only
    path issues one ``backend.get_model`` per distinct hash.

    NOTE(review): a backend's own ``get_models`` must not delegate back to
    this function, or the dispatch below would recurse — confirm no backend
    does so.
    """
    wanted = [h for h in dict.fromkeys(model_hashes) if h]  # dedup, keep order
    if not wanted:
        return {}

    # getattr rather than a direct call: the bulk method is optional.
    bulk_resolver = getattr(backend, "get_models", None)
    if callable(bulk_resolver):
        return bulk_resolver(wanted)

    result: dict[str, ModelRef] = {}
    for model_hash in wanted:
        ref = backend.get_model(model_hash)
        if ref is not None:
            result[model_hash] = ref
    return result


def _resolve_platform(
Expand Down Expand Up @@ -146,50 +169,56 @@ def batch_dependencies(

Returns ``{"upstream": [...], "downstream": [...]}`` where each entry
contains ``model_hash``, ``model_name``, and ``relationship``.

All edge targets are resolved by hash in a single batched lookup
(one ``get_models`` round trip) rather than one ``get_model`` per edge.
A per-edge name fallback runs only for the rare edge whose hash does not
resolve, preserving the resolution semantics of the prior implementation.
"""
snapshots = backend.list_snapshots(model_hash)
upstream: list[dict[str, Any]] = []
downstream: list[dict[str, Any]] = []

# Collect edges first: (direction, target_hash, target_name, relationship).
edges: list[tuple[str, str, str, str]] = []
for snap in snapshots:
if snap.event_type == "depends_on":
related_hash = snap.payload.get("upstream_hash", "")
related_name = snap.payload.get("upstream", "")
relationship = snap.payload.get("relationship", "depends_on")

related = backend.get_model(related_hash) if related_hash else None
if related is None and related_name:
related = backend.get_model_by_name(related_name)
if related is None:
continue

upstream.append(
{
"model_hash": related.model_hash,
"model_name": related.name,
"relationship": relationship,
}
edges.append(
(
"upstream",
snap.payload.get("upstream_hash", ""),
snap.payload.get("upstream", ""),
snap.payload.get("relationship", "depends_on"),
)
)

elif snap.event_type == "has_dependent":
related_hash = snap.payload.get("downstream_hash", "")
related_name = snap.payload.get("downstream", "")
relationship = snap.payload.get("relationship", "depends_on")

related = backend.get_model(related_hash) if related_hash else None
if related is None and related_name:
related = backend.get_model_by_name(related_name)
if related is None:
continue

downstream.append(
{
"model_hash": related.model_hash,
"model_name": related.name,
"relationship": relationship,
}
edges.append(
(
"downstream",
snap.payload.get("downstream_hash", ""),
snap.payload.get("downstream", ""),
snap.payload.get("relationship", "depends_on"),
)
)

by_hash = get_models(backend, [h for _, h, _, _ in edges if h])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Use backend get_models in fallback dependencies

When investigate() runs against a backend that has the new get_models method but no batch_dependencies implementation, such as JsonFileLedgerBackend, this fallback still calls the module-level get_models, which loops over backend.get_model once per distinct edge. For json-files that means a full models-directory scan per edge, so the dependency half of the hot path remains O(edges × files) despite the new bulk resolver; dispatch to backend.get_models here when it exists.

Useful? React with 👍 / 👎.


upstream: list[dict[str, Any]] = []
downstream: list[dict[str, Any]] = []
for direction, related_hash, related_name, relationship in edges:
related = by_hash.get(related_hash) if related_hash else None
if related is None and related_name:
related = backend.get_model_by_name(related_name)
if related is None:
continue
entry = {
"model_hash": related.model_hash,
"model_name": related.name,
"relationship": relationship,
}
if direction == "upstream":
upstream.append(entry)
else:
downstream.append(entry)

return {"upstream": upstream, "downstream": downstream}


Expand Down
19 changes: 19 additions & 0 deletions src/model_ledger/backends/json_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,25 @@ def get_model_by_name(self, name: str) -> ModelRef | None:
return ModelRef.model_validate_json(path.read_text())
return None

def get_models(self, model_hashes: list[str]) -> dict[str, ModelRef]:
    """Bulk-resolve model hashes with at most one directory scan.

    Walks ``self._models_dir`` once, parses each ``.json`` file as a
    ``ModelRef`` and keeps those whose ``model_hash`` was requested. The
    scan stops as soon as every requested hash has been found, so files
    beyond that point are neither read nor parsed.

    Args:
        model_hashes: Hashes to resolve; empty values are ignored and
            duplicates collapse.

    Returns:
        ``{model_hash: ModelRef}`` for the hashes found on disk; hashes with
        no matching file are omitted. An empty request returns ``{}``
        without touching the directory.
    """
    wanted = {h for h in model_hashes if h}
    if not wanted:
        return {}
    result: dict[str, ModelRef] = {}
    for path in self._models_dir.iterdir():
        if path.suffix != ".json":
            continue
        m = ModelRef.model_validate_json(path.read_text())
        if m.model_hash in wanted:
            result[m.model_hash] = m
            # Everything requested is resolved; skip the remaining files.
            if len(result) == len(wanted):
                break
    return result

def list_models(self, **filters: str) -> list[ModelRef]:
results: list[ModelRef] = []
for path in sorted(self._models_dir.iterdir()):
Expand Down
8 changes: 8 additions & 0 deletions src/model_ledger/backends/ledger_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ def get_model_by_name(self, name: str) -> ModelRef | None:
return m
return None

def get_models(self, model_hashes: list[str]) -> dict[str, ModelRef]:
    """Look up several model hashes against the in-memory store at once.

    The result maps each requested hash present in ``self._models`` to its
    stored value, in first-seen request order; repeated hashes collapse and
    unknown hashes are left out.
    """
    found: dict[str, ModelRef] = {}
    for candidate in dict.fromkeys(model_hashes):  # dedup, keep request order
        if candidate in self._models:
            found[candidate] = self._models[candidate]
    return found

def list_models(self, **filters: str) -> list[ModelRef]:
text = filters.pop("text", None)
limit = filters.pop("limit", None)
Expand Down
17 changes: 17 additions & 0 deletions src/model_ledger/backends/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,23 @@ def get_model_by_name(self, name: str) -> ModelRef | None:
)
return _row_to_model_ref(rows[0]) if rows else None

def get_models(self, model_hashes: list[str]) -> dict[str, ModelRef]:
    """Resolve a batch of model hashes through a single ``IN (...)`` query.

    ``self._flush_models()`` is always called first, even for an empty
    request. Empty hashes are dropped and duplicates collapsed; when nothing
    remains, ``{}`` is returned without issuing a query. Otherwise each hash
    is passed through ``_esc`` (presumably SQL-literal quoting — confirm
    against its definition) and the matching ``MODELS`` rows are converted
    with ``_row_to_model_ref`` and keyed by their ``model_hash``.
    """
    self._flush_models()
    wanted = [h for h in dict.fromkeys(model_hashes) if h]
    if not wanted:
        return {}
    literals = ", ".join(map(_esc, wanted))
    refs = (
        _row_to_model_ref(row)
        for row in _exec(
            self._session,
            f"SELECT * FROM {self._schema}.MODELS WHERE MODEL_HASH IN ({literals})",
        )
    )
    return {ref.model_hash: ref for ref in refs}

def list_models(self, **filters: str) -> list[ModelRef]:
self._flush_models()
limit = filters.pop("limit", None)
Expand Down
114 changes: 69 additions & 45 deletions src/model_ledger/backends/sqlite_ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,17 @@ def get_model_by_name(self, name: str) -> ModelRef | None:
row = self._conn.execute("SELECT * FROM models WHERE name = ?", (name,)).fetchone()
return self._row_to_model(row) if row else None

def get_models(self, model_hashes: list[str]) -> dict[str, ModelRef]:
    """Bulk-resolve model hashes to ModelRefs with batched ``IN (...)`` queries.

    Empty hashes are ignored and duplicates collapsed. The remaining hashes
    are looked up in chunks, one parameterized ``IN (...)`` query per chunk,
    because SQLite caps the number of bound parameters in a single statement
    (999 by default on older builds); an unbounded placeholder list would
    fail with "too many SQL variables" on large batches. Typical batches fit
    in one chunk and therefore still cost a single query.

    Args:
        model_hashes: Hashes to resolve.

    Returns:
        ``{model_hash: ModelRef}`` for every hash with a row in ``models``;
        unknown hashes are omitted. An empty request issues no query.
    """
    hashes = [h for h in dict.fromkeys(model_hashes) if h]
    resolved: dict[str, ModelRef] = {}
    chunk_size = 500  # comfortably below SQLite's historical 999-variable cap
    for start in range(0, len(hashes), chunk_size):
        chunk = hashes[start : start + chunk_size]
        placeholders = ", ".join("?" * len(chunk))
        rows = self._conn.execute(
            f"SELECT * FROM models WHERE model_hash IN ({placeholders})", chunk
        ).fetchall()
        for row in rows:
            resolved[row["model_hash"]] = self._row_to_model(row)
    return resolved

def list_models(self, **filters: str) -> list[ModelRef]:
sql = "SELECT * FROM models"
params: list[str] = []
Expand Down Expand Up @@ -439,9 +450,9 @@ def batch_dependencies(
self,
model_hash: str,
) -> dict[str, list[dict]]:
upstream: list[dict] = []
downstream: list[dict] = []

# One query for the edge snapshots; one batched query to resolve every
# edge target by hash. A per-edge name fallback runs only when a hash
# does not resolve, preserving the prior resolution semantics.
if self._has_json_extract:
rows = self._conn.execute(
"SELECT s.event_type, "
Expand All @@ -456,30 +467,27 @@ def batch_dependencies(
(model_hash,),
).fetchall()

# (direction, target_hash, target_name, relationship)
edges: list[tuple[str, str, str, str]] = []
for row in rows:
if row["event_type"] == "depends_on":
related_hash = row["upstream_hash"] or ""
related_name = row["upstream_name"] or ""
else:
related_hash = row["downstream_hash"] or ""
related_name = row["downstream_name"] or ""
relationship = row["relationship"] or "depends_on"

related = self.get_model(related_hash) if related_hash else None
if related is None and related_name:
related = self.get_model_by_name(related_name)
if related is None:
continue

entry = {
"model_hash": related.model_hash,
"model_name": related.name,
"relationship": relationship,
}
if row["event_type"] == "depends_on":
upstream.append(entry)
edges.append(
(
"upstream",
row["upstream_hash"] or "",
row["upstream_name"] or "",
row["relationship"] or "depends_on",
)
)
else:
downstream.append(entry)
edges.append(
(
"downstream",
row["downstream_hash"] or "",
row["downstream_name"] or "",
row["relationship"] or "depends_on",
)
)
else:
rows = self._conn.execute(
"SELECT s.event_type, s.payload "
Expand All @@ -489,31 +497,47 @@ def batch_dependencies(
(model_hash,),
).fetchall()

edges = []
for row in rows:
payload = json.loads(row["payload"]) if row["payload"] else {}
if row["event_type"] == "depends_on":
related_hash = payload.get("upstream_hash", "")
related_name = payload.get("upstream", "")
edges.append(
(
"upstream",
payload.get("upstream_hash", ""),
payload.get("upstream", ""),
payload.get("relationship", "depends_on"),
)
)
else:
related_hash = payload.get("downstream_hash", "")
related_name = payload.get("downstream", "")
relationship = payload.get("relationship", "depends_on")

related = self.get_model(related_hash) if related_hash else None
if related is None and related_name:
related = self.get_model_by_name(related_name)
if related is None:
continue

entry = {
"model_hash": related.model_hash,
"model_name": related.name,
"relationship": relationship,
}
if row["event_type"] == "depends_on":
upstream.append(entry)
else:
downstream.append(entry)
edges.append(
(
"downstream",
payload.get("downstream_hash", ""),
payload.get("downstream", ""),
payload.get("relationship", "depends_on"),
)
)

by_hash = self.get_models([h for _, h, _, _ in edges if h])

upstream: list[dict] = []
downstream: list[dict] = []
for direction, related_hash, related_name, relationship in edges:
related = by_hash.get(related_hash) if related_hash else None
if related is None and related_name:
related = self.get_model_by_name(related_name)
if related is None:
continue
entry = {
"model_hash": related.model_hash,
"model_name": related.name,
"relationship": relationship,
}
if direction == "upstream":
upstream.append(entry)
else:
downstream.append(entry)

return {"upstream": upstream, "downstream": downstream}

Expand Down
Loading
Loading