Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/model_ledger/sdk/ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from model_ledger.backends.ledger_memory import InMemoryLedgerBackend
from model_ledger.backends.ledger_protocol import LedgerBackend
from model_ledger.core.enums import ModelStatus
from model_ledger.core.exceptions import ModelNotFoundError
from model_ledger.core.ledger_models import ModelRef, Snapshot, Tag

Expand All @@ -29,6 +30,21 @@ class ConnectResult(TypedDict):
links_skipped: int


def _normalize_status(raw: object) -> str | None:
    """Coerce a connector-discovered status to its canonical ModelStatus value.

    Matching is case-insensitive and ignores surrounding whitespace. The raw
    value is tried first and then its stripped, lowercased form, so the
    case-insensitive contract does not depend on how the enum itself resolves
    lookups.

    Args:
        raw: The value found under the ``status`` metadata key. May be of any
            type, or None when the connector reported nothing.

    Returns:
        The canonical ``ModelStatus`` value, or None for absent or
        unrecognized values — "no opinion" — so callers leave the stored
        status untouched. A connector that stops reporting status must never
        regress an explicitly set status back to the default.
    """
    if not isinstance(raw, str):
        return None
    # NOTE(review): the fallback assumes canonical ModelStatus values are
    # lowercase (e.g. "deprecated") — confirm against the enum definition.
    # dict.fromkeys de-duplicates while keeping order, so an already-canonical
    # value is only looked up once.
    for candidate in dict.fromkeys((raw, raw.strip().lower())):
        try:
            return ModelStatus(candidate).value
        except ValueError:
            continue
    return None


# Events that are internal ledger bookkeeping or governance actions on the
# composite itself. These are NOT propagated as member_changed to parent
# composites — only real domain events on member models should surface there.
Expand Down Expand Up @@ -347,6 +363,15 @@ def add(self, nodes: DataNode | builtins.list[DataNode]) -> AddResult:

Skips writing if the discovered payload is identical to the last snapshot
(content-hash dedup). Preloads existing models in bulk to avoid per-node queries.

Recognized ``node.metadata`` keys map onto the model row: ``owner``,
``model_type``/``node_type``/``type``, ``tier``, ``purpose``,
``model_origin``, and ``status``. A discovered ``status`` (any
``ModelStatus`` value, case-insensitive) is propagated to the model —
including already-registered models — so lifecycle changes detected at
the source (e.g. ``deprecated`` for an entity deleted upstream) reach
the model row on the next sync. An absent or unrecognized status leaves
the stored status unchanged.
"""
import hashlib
import json
Expand Down Expand Up @@ -377,6 +402,7 @@ def add(self, nodes: DataNode | builtins.list[DataNode]) -> AddResult:
added = 0
skipped = 0
for node in nodes:
node_status = _normalize_status(node.metadata.get("status"))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Accept status metadata case-insensitively

When status comes from sql_connector without an explicit metadata_columns mapping and the DB driver returns uppercase column names (for example Snowflake-style rows with STATUS), the connector preserves the original key in node.metadata (src/model_ledger/connectors/sql.py lines 237-240). This lookup only checks lowercase status, so the new propagation path is skipped and the model row remains active even though the discovered snapshot carries the source status. Consider normalizing metadata keys or looking up status case-insensitively here.

Useful? React with 👍 / 👎.

ref = self.register(
name=node.name,
owner=node.metadata.get("owner") or "unknown",
Expand All @@ -387,6 +413,7 @@ def add(self, nodes: DataNode | builtins.list[DataNode]) -> AddResult:
tier=node.metadata.get("tier") or "unclassified",
purpose=node.metadata.get("purpose") or "",
model_origin=node.metadata.get("model_origin") or "internal",
status=node_status or "active",
actor=f"connector:{node.platform}" if node.platform else "system",
)
payload = {
Expand Down Expand Up @@ -425,6 +452,13 @@ def add(self, nodes: DataNode | builtins.list[DataNode]) -> AddResult:
).hexdigest()
# Update last_seen on every run, even if unchanged
ref.last_seen = datetime.now(timezone.utc)
# Propagate a discovered status onto the model row. register()
# returns existing refs unchanged, so a connector-derived status
# must be applied here for update_model() to persist it. This runs
# before the dedup check so the row self-corrects even when the
# snapshot is skipped as unchanged.
if node_status is not None and node_status != ref.status:
ref.status = node_status
self._backend.update_model(ref)

if existing_hashes.get(ref.model_hash) == content_hash:
Expand Down
105 changes: 101 additions & 4 deletions tests/test_backends/test_snowflake_ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ def sql(self, query: str, params: Any = None) -> MockCollectResult:
return MockCollectResult([])

if "MERGE INTO" in upper and ".MODELS " in upper:
# Handle batched MERGE with UNION ALL
# Handle batched MERGE with UNION ALL. The MATCHED branch of the
# real MERGE rewrites STATUS, so the mock applies last-write-wins
# per model_hash — same observable behavior.
for m in re.finditer(
r"SELECT\s+'([^']+)'\s+AS\s+model_hash,\s+'([^']+)'\s+AS\s+name,\s+'([^']+)'\s+AS\s+owner,\s+'([^']+)'\s+AS\s+model_type,\s+'([^']+)'\s+AS\s+model_origin,\s+'([^']+)'\s+AS\s+tier",
r"SELECT\s+'([^']+)'\s+AS\s+model_hash,\s+'([^']+)'\s+AS\s+name,\s+'([^']+)'\s+AS\s+owner,\s+'([^']+)'\s+AS\s+model_type,\s+'([^']+)'\s+AS\s+model_origin,\s+'([^']+)'\s+AS\s+tier,\s+'([^']*)'\s+AS\s+purpose,\s+'([^']+)'\s+AS\s+status",
query,
re.DOTALL,
):
Expand All @@ -45,8 +47,8 @@ def sql(self, query: str, params: Any = None) -> MockCollectResult:
"MODEL_TYPE": m.group(4),
"MODEL_ORIGIN": m.group(5),
"TIER": m.group(6),
"PURPOSE": "",
"STATUS": "active",
"PURPOSE": m.group(7),
"STATUS": m.group(8),
"CREATED_AT": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
return MockCollectResult([])
Expand Down Expand Up @@ -324,6 +326,101 @@ def sql(self, query: str, params: Any = None) -> MockCollectResult:
)


class TestStatusPropagationSQL:
    """A status discovered by a connector must end up in the MODELS table.

    The flush MERGE paths rewrite STATUS for matched rows, so as soon as
    Ledger.add() assigns ref.status, an existing row is corrected on the next
    sync. Each test runs Ledger.add() end-to-end through the SQL MERGE path
    and checks the stored row itself, not merely the snapshot payload.
    """

    def _ledger(self, session):
        """Return ``(ledger, backend)`` wired to the given mock session."""
        from model_ledger.backends.snowflake import SnowflakeLedgerBackend
        from model_ledger.sdk.ledger import Ledger

        sf_backend = SnowflakeLedgerBackend(schema="TEST_SCHEMA", connection=session)
        sdk = Ledger(sf_backend)
        return sdk, sf_backend

    def test_new_model_status_reaches_models_table(self):
        from model_ledger.graph.models import DataNode

        mock_session = MockLedgerSession()
        ledger, backend = self._ledger(mock_session)
        deprecated_node = DataNode(
            "fraud_scorer",
            platform="ml_platform",
            outputs=["scores"],
            metadata={"status": "deprecated"},
        )
        ledger.add(deprecated_node)
        backend.flush()

        stored = backend.get_model_by_name("fraud_scorer")
        assert stored is not None
        assert stored.status == "deprecated"

    def test_existing_row_status_flip_rewrites_via_merge(self):
        from model_ledger.graph.models import DataNode

        mock_session = MockLedgerSession()

        # First sync: no status reported, so the row is stored as active.
        first_ledger, first_backend = self._ledger(mock_session)
        first_ledger.add(DataNode("fraud_scorer", platform="ml_platform", outputs=["scores"]))
        first_backend.flush()
        assert first_backend.get_model_by_name("fraud_scorer").status == "active"

        # Second sync uses a fresh SDK cache over the same session (rows are
        # re-read from the table) and reports the entity as deprecated, as if
        # it had been deleted at the source.
        second_ledger, second_backend = self._ledger(mock_session)
        deprecated_node = DataNode(
            "fraud_scorer",
            platform="ml_platform",
            outputs=["scores"],
            metadata={"status": "deprecated"},
        )
        second_ledger.add(deprecated_node)
        second_backend.flush()

        stored = second_backend.get_model_by_name("fraud_scorer")
        assert stored is not None
        assert stored.status == "deprecated"

    def test_status_parity_with_in_memory_backend(self):
        """An identical discovery sequence ends in the same status on the
        in-memory backend and on the Snowflake SQL MERGE path."""
        from model_ledger.backends.ledger_memory import InMemoryLedgerBackend
        from model_ledger.backends.snowflake import SnowflakeLedgerBackend
        from model_ledger.graph.models import DataNode
        from model_ledger.sdk.ledger import Ledger

        # absent -> deprecated -> unknown (ignored) -> absent (kept)
        sequence = [None, "deprecated", "not-a-status", None]

        def final_status(backend):
            can_flush = None
            for discovered in sequence:
                if discovered is not None:
                    metadata = {"status": discovered}
                else:
                    metadata = {}
                # A new Ledger per sync gives each run a fresh SDK cache.
                sync_ledger = Ledger(backend)
                node = DataNode(
                    "fraud_scorer",
                    platform="ml_platform",
                    outputs=["scores"],
                    metadata=metadata,
                )
                sync_ledger.add(node)
                can_flush = hasattr(backend, "flush")
                if can_flush:
                    backend.flush()
            stored = backend.get_model_by_name("fraud_scorer")
            assert stored is not None
            return stored.status

        memory_result = final_status(InMemoryLedgerBackend())
        snowflake_backend = SnowflakeLedgerBackend(
            schema="TEST_SCHEMA", connection=MockLedgerSession()
        )
        snowflake_result = final_status(snowflake_backend)
        assert memory_result == snowflake_result == "deprecated"


class FakeCompositeSummarySession:
"""Captures every SQL statement; returns canned rows for the summary query."""

Expand Down
101 changes: 101 additions & 0 deletions tests/test_graph/test_ledger_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,107 @@ def test_add_omits_change_occurred_when_absent(self, ledger):
snap = [s for s in ledger.history("scorer") if s.event_type == "discovered"][0]
assert "change_occurred" not in snap.payload

def test_add_new_model_with_status(self, ledger):
    """A status supplied on first discovery is stored on the new model."""
    deprecated_node = DataNode(
        "scorer",
        platform="ml_platform",
        outputs=["scores"],
        metadata={"status": "deprecated"},
    )
    ledger.add(deprecated_node)
    stored = ledger.get("scorer")
    assert stored.status == "deprecated"

def test_add_defaults_status_to_active(self, ledger):
    """With no status in the metadata, a new model is stored as active."""
    plain_node = DataNode("scorer", platform="ml_platform", outputs=["scores"])
    ledger.add(plain_node)
    stored = ledger.get("scorer")
    assert stored.status == "active"

def test_add_flips_status_on_existing_model(self, ledger):
    """A later discovery with a status overwrites the existing model's status."""
    plain_node = DataNode("scorer", platform="ml_platform", outputs=["scores"])
    ledger.add(plain_node)
    assert ledger.get("scorer").status == "active"

    deprecated_node = DataNode(
        "scorer",
        platform="ml_platform",
        outputs=["scores"],
        metadata={"status": "deprecated"},
    )
    ledger.add(deprecated_node)
    assert ledger.get("scorer").status == "deprecated"

def test_add_status_flip_records_new_discovered_snapshot(self, ledger):
    """Re-adding a model with a new status counts as added, not skipped."""
    ledger.add(DataNode("scorer", platform="ml_platform", outputs=["scores"]))

    deprecated_node = DataNode(
        "scorer",
        platform="ml_platform",
        outputs=["scores"],
        metadata={"status": "deprecated"},
    )
    outcome = ledger.add(deprecated_node)

    assert outcome["added"] == 1
    assert outcome["skipped"] == 0

def test_add_status_absent_keeps_existing_status(self, ledger):
    """A later discovery without a status leaves the stored status alone."""
    deprecated_node = DataNode(
        "scorer",
        platform="ml_platform",
        outputs=["scores"],
        metadata={"status": "deprecated"},
    )
    ledger.add(deprecated_node)

    node_without_status = DataNode("scorer", platform="ml_platform", outputs=["scores"])
    ledger.add(node_without_status)

    assert ledger.get("scorer").status == "deprecated"

def test_add_unknown_status_ignored(self, ledger):
    """An unrecognized status string does not change the stored status."""
    deprecated_node = DataNode(
        "scorer",
        platform="ml_platform",
        outputs=["scores"],
        metadata={"status": "deprecated"},
    )
    ledger.add(deprecated_node)

    bogus_node = DataNode(
        "scorer",
        platform="ml_platform",
        outputs=["scores"],
        metadata={"status": "not-a-status"},
    )
    ledger.add(bogus_node)

    assert ledger.get("scorer").status == "deprecated"

def test_add_status_case_insensitive_normalized(self, ledger):
    """An upper-case status is stored in its canonical lower-case form."""
    shouting_node = DataNode(
        "scorer",
        platform="ml_platform",
        outputs=["scores"],
        metadata={"status": "DEPRECATED"},
    )
    ledger.add(shouting_node)
    stored = ledger.get("scorer")
    assert stored.status == "deprecated"

def test_add_status_applied_even_when_snapshot_deduped(self, ledger):
    """The model row's status is corrected even when the discovered payload
    matches an earlier snapshot and the snapshot write is skipped as a
    content-hash duplicate."""
    shared_meta = {"status": "deprecated"}
    first_node = DataNode(
        "scorer", platform="ml_platform", outputs=["scores"], metadata=shared_meta
    )
    ledger.add(first_node)

    # Simulate drift: push the stored row back to "active" behind add()'s back.
    drifted = ledger.get("scorer")
    drifted.status = "active"
    ledger._backend.update_model(drifted)

    # Identical payload: the snapshot is deduplicated, yet the status is restored.
    repeat_node = DataNode(
        "scorer", platform="ml_platform", outputs=["scores"], metadata=shared_meta
    )
    outcome = ledger.add(repeat_node)

    assert outcome["skipped"] == 1
    assert ledger.get("scorer").status == "deprecated"

def test_add_source_updated_at_does_not_affect_dedup(self, ledger):
node1 = DataNode(
"scorer",
Expand Down
31 changes: 31 additions & 0 deletions tests/test_invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
differing content yields a differing hash (tamper-evident).
3. Ordered history — history() returns snapshots newest-first by timestamp.
4. Point-in-time — inventory_at(t) reflects only models that existed at t.
5. Status propagation — a model's status equals the last valid status discovered
by add(); absent or unknown statuses never regress it.
"""

from __future__ import annotations
Expand All @@ -24,6 +26,7 @@
from model_ledger import Ledger
from model_ledger.backends.ledger_memory import InMemoryLedgerBackend
from model_ledger.core.ledger_models import Snapshot
from model_ledger.graph.models import DataNode

# Safe alphabets keep the focus on the invariants, not unicode-encoding edge cases.
_TOKEN = st.text(alphabet="abcdefghijklmnopqrstuvwxyz0123456789_-", min_size=1, max_size=12)
Expand Down Expand Up @@ -90,6 +93,34 @@ def test_history_is_ordered_newest_first(seq: list[tuple[str, dict]]) -> None:
assert timestamps == sorted(timestamps, reverse=True)


# Oracle for status propagation, kept independent of the code under test: maps
# every valid discovered spelling to its canonical status. Any other input
# (unknown strings, the empty string, an absent key) means "no opinion".
_CANONICAL_STATUS = {
    "development": "development",
    "review": "review",
    "active": "active",
    "deprecated": "deprecated",
    "retired": "retired",
    "ACTIVE": "active",
    "Deprecated": "deprecated",
    "RETIRED": "retired",
}
# Sequences of up to 8 discoveries: None (status key absent), a valid spelling,
# or one of three values that must be ignored.
_DISCOVERED_STATUSES = st.lists(
    st.one_of(
        st.none(),
        st.sampled_from(list(_CANONICAL_STATUS) + ["", "not-a-status", "unknown"]),
    ),
    max_size=8,
)


@settings(deadline=None, max_examples=40)
@given(statuses=_DISCOVERED_STATUSES)
def test_status_equals_last_valid_discovered_status(statuses: list[str | None]) -> None:
    """Every valid discovered status is propagated by add(); an absent or
    unrecognized one leaves the model as it was and never resets it to the
    default."""
    ledger = _ledger()
    expected = "active"
    for raw in statuses:
        if raw is None:
            # No status key at all: the oracle keeps its current expectation.
            node_metadata = {}
        else:
            node_metadata = {"status": raw}
            expected = _CANONICAL_STATUS.get(raw, expected)
        ledger.add(DataNode("m", platform="p", outputs=["t"], metadata=node_metadata))
    if not statuses:
        # add() was never called, so there is no model to inspect.
        return
    assert ledger.get("m").status == expected


@settings(deadline=None, max_examples=30)
@given(names=st.lists(_TOKEN, unique=True, max_size=5))
def test_point_in_time_reflects_only_models_that_existed(names: list[str]) -> None:
Expand Down
Loading