diff --git a/src/model_ledger/sdk/ledger.py b/src/model_ledger/sdk/ledger.py
index 42a03f3..6208cfa 100644
--- a/src/model_ledger/sdk/ledger.py
+++ b/src/model_ledger/sdk/ledger.py
@@ -8,6 +8,7 @@
 
 from model_ledger.backends.ledger_memory import InMemoryLedgerBackend
 from model_ledger.backends.ledger_protocol import LedgerBackend
+from model_ledger.core.enums import ModelStatus
 from model_ledger.core.exceptions import ModelNotFoundError
 from model_ledger.core.ledger_models import ModelRef, Snapshot, Tag
 
@@ -29,6 +30,21 @@ class ConnectResult(TypedDict):
     links_skipped: int
 
 
+def _normalize_status(raw: object) -> str | None:
+    """Coerce a connector-discovered status to its canonical ModelStatus value.
+
+    Matching is case-insensitive. Returns None for absent or unrecognized
+    values — "no opinion" — so callers leave the stored status untouched: a
+    connector that stops reporting status must never regress a set status.
+    """
+    if not isinstance(raw, str):
+        return None
+    try:
+        return ModelStatus(raw.lower()).value
+    except ValueError:
+        return None
+
+
 # Events that are internal ledger bookkeeping or governance actions on the
 # composite itself. These are NOT propagated as member_changed to parent
 # composites — only real domain events on member models should surface there.
@@ -347,6 +363,15 @@ def add(self, nodes: DataNode | builtins.list[DataNode]) -> AddResult:
         Skips writing if the discovered payload is identical to the last
         snapshot (content-hash dedup). Preloads existing models in bulk
         to avoid per-node queries.
+
+        Recognized ``node.metadata`` keys map onto the model row: ``owner``,
+        ``model_type``/``node_type``/``type``, ``tier``, ``purpose``,
+        ``model_origin``, and ``status``. A discovered ``status`` (any
+        ``ModelStatus`` value, case-insensitive) is propagated to the model —
+        including already-registered models — so lifecycle changes detected at
+        the source (e.g. ``deprecated`` for an entity deleted upstream) reach
+        the model row on the next sync. An absent or unrecognized status leaves
+        the stored status unchanged.
         """
         import hashlib
         import json
@@ -377,6 +402,7 @@ def add(self, nodes: DataNode | builtins.list[DataNode]) -> AddResult:
         added = 0
         skipped = 0
         for node in nodes:
+            node_status = _normalize_status(node.metadata.get("status"))
             ref = self.register(
                 name=node.name,
                 owner=node.metadata.get("owner") or "unknown",
@@ -387,6 +413,7 @@ def add(self, nodes: DataNode | builtins.list[DataNode]) -> AddResult:
                 tier=node.metadata.get("tier") or "unclassified",
                 purpose=node.metadata.get("purpose") or "",
                 model_origin=node.metadata.get("model_origin") or "internal",
+                status=node_status or "active",
                 actor=f"connector:{node.platform}" if node.platform else "system",
             )
             payload = {
@@ -425,6 +452,13 @@ def add(self, nodes: DataNode | builtins.list[DataNode]) -> AddResult:
             ).hexdigest()
             # Update last_seen on every run, even if unchanged
             ref.last_seen = datetime.now(timezone.utc)
+            # Propagate a discovered status onto the model row. register()
+            # returns existing refs unchanged, so a connector-derived status
+            # must be applied here for update_model() to persist it. This runs
+            # before the dedup check so the row self-corrects even when the
+            # snapshot is skipped as unchanged.
+            if node_status is not None and node_status != ref.status:
+                ref.status = node_status
             self._backend.update_model(ref)
 
             if existing_hashes.get(ref.model_hash) == content_hash:
diff --git a/tests/test_backends/test_snowflake_ledger.py b/tests/test_backends/test_snowflake_ledger.py
index 5de29f9..e76fcde 100644
--- a/tests/test_backends/test_snowflake_ledger.py
+++ b/tests/test_backends/test_snowflake_ledger.py
@@ -32,9 +32,11 @@ def sql(self, query: str, params: Any = None) -> MockCollectResult:
             return MockCollectResult([])
 
         if "MERGE INTO" in upper and ".MODELS " in upper:
-            # Handle batched MERGE with UNION ALL
+            # Handle batched MERGE with UNION ALL. The MATCHED branch of the
+            # real MERGE rewrites STATUS, so the mock applies last-write-wins
+            # per model_hash — same observable behavior.
             for m in re.finditer(
-                r"SELECT\s+'([^']+)'\s+AS\s+model_hash,\s+'([^']+)'\s+AS\s+name,\s+'([^']+)'\s+AS\s+owner,\s+'([^']+)'\s+AS\s+model_type,\s+'([^']+)'\s+AS\s+model_origin,\s+'([^']+)'\s+AS\s+tier",
+                r"SELECT\s+'([^']+)'\s+AS\s+model_hash,\s+'([^']+)'\s+AS\s+name,\s+'([^']+)'\s+AS\s+owner,\s+'([^']+)'\s+AS\s+model_type,\s+'([^']+)'\s+AS\s+model_origin,\s+'([^']+)'\s+AS\s+tier,\s+'([^']*)'\s+AS\s+purpose,\s+'([^']+)'\s+AS\s+status",
                 query,
                 re.DOTALL,
             ):
@@ -45,8 +47,8 @@ def sql(self, query: str, params: Any = None) -> MockCollectResult:
                     "MODEL_TYPE": m.group(4),
                     "MODEL_ORIGIN": m.group(5),
                     "TIER": m.group(6),
-                    "PURPOSE": "",
-                    "STATUS": "active",
+                    "PURPOSE": m.group(7),
+                    "STATUS": m.group(8),
                     "CREATED_AT": datetime(2025, 1, 1, tzinfo=timezone.utc),
                 }
             return MockCollectResult([])
@@ -324,6 +326,101 @@ def sql(self, query: str, params: Any = None) -> MockCollectResult:
         )
 
 
+class TestStatusPropagationSQL:
+    """Connector-discovered status must land in the MODELS table via the MERGE.
+
+    Both flush MERGE paths SET STATUS on match, so once Ledger.add() assigns
+    ref.status, existing rows self-correct on the next sync. These tests drive
+    Ledger.add() end-to-end through the SQL MERGE path and assert the stored
+    row — not just the snapshot payload — carries the discovered status.
+    """
+
+    def _ledger(self, session):
+        from model_ledger.backends.snowflake import SnowflakeLedgerBackend
+        from model_ledger.sdk.ledger import Ledger
+
+        backend = SnowflakeLedgerBackend(schema="TEST_SCHEMA", connection=session)
+        return Ledger(backend), backend
+
+    def test_new_model_status_reaches_models_table(self):
+        from model_ledger.graph.models import DataNode
+
+        session = MockLedgerSession()
+        ledger, backend = self._ledger(session)
+        ledger.add(
+            DataNode(
+                "fraud_scorer",
+                platform="ml_platform",
+                outputs=["scores"],
+                metadata={"status": "deprecated"},
+            )
+        )
+        backend.flush()
+        ref = backend.get_model_by_name("fraud_scorer")
+        assert ref is not None
+        assert ref.status == "deprecated"
+
+    def test_existing_row_status_flip_rewrites_via_merge(self):
+        from model_ledger.graph.models import DataNode
+
+        session = MockLedgerSession()
+        ledger, backend = self._ledger(session)
+        ledger.add(DataNode("fraud_scorer", platform="ml_platform", outputs=["scores"]))
+        backend.flush()
+        assert backend.get_model_by_name("fraud_scorer").status == "active"
+
+        # A later sync (fresh SDK cache, rows re-read from the table) discovers
+        # the entity was deleted at the source and derives status=deprecated.
+        ledger2, backend2 = self._ledger(session)
+        ledger2.add(
+            DataNode(
+                "fraud_scorer",
+                platform="ml_platform",
+                outputs=["scores"],
+                metadata={"status": "deprecated"},
+            )
+        )
+        backend2.flush()
+        ref = backend2.get_model_by_name("fraud_scorer")
+        assert ref is not None
+        assert ref.status == "deprecated"
+
+    def test_status_parity_with_in_memory_backend(self):
+        """The same discovery sequence yields the same final status whether it
+        runs through the in-memory backend or the Snowflake SQL MERGE path."""
+        from model_ledger.backends.ledger_memory import InMemoryLedgerBackend
+        from model_ledger.backends.snowflake import SnowflakeLedgerBackend
+        from model_ledger.graph.models import DataNode
+        from model_ledger.sdk.ledger import Ledger
+
+        # absent -> deprecated -> unknown (ignored) -> absent (kept)
+        sequence = [None, "deprecated", "not-a-status", None]
+
+        def final_status(backend):
+            for status in sequence:
+                metadata = {"status": status} if status is not None else {}
+                ledger = Ledger(backend)  # fresh SDK cache per sync
+                ledger.add(
+                    DataNode(
+                        "fraud_scorer",
+                        platform="ml_platform",
+                        outputs=["scores"],
+                        metadata=metadata,
+                    )
+                )
+                if hasattr(backend, "flush"):
+                    backend.flush()
+            ref = backend.get_model_by_name("fraud_scorer")
+            assert ref is not None
+            return ref.status
+
+        in_memory = final_status(InMemoryLedgerBackend())
+        snowflake = final_status(
+            SnowflakeLedgerBackend(schema="TEST_SCHEMA", connection=MockLedgerSession())
+        )
+        assert in_memory == snowflake == "deprecated"
+
+
 class FakeCompositeSummarySession:
     """Captures every SQL statement; returns canned rows for the summary query."""
 
diff --git a/tests/test_graph/test_ledger_graph.py b/tests/test_graph/test_ledger_graph.py
index 4782c1e..5f7613b 100644
--- a/tests/test_graph/test_ledger_graph.py
+++ b/tests/test_graph/test_ledger_graph.py
@@ -83,6 +83,107 @@ def test_add_omits_change_occurred_when_absent(self, ledger):
         snap = [s for s in ledger.history("scorer") if s.event_type == "discovered"][0]
         assert "change_occurred" not in snap.payload
 
+    def test_add_new_model_with_status(self, ledger):
+        ledger.add(
+            DataNode(
+                "scorer",
+                platform="ml_platform",
+                outputs=["scores"],
+                metadata={"status": "deprecated"},
+            )
+        )
+        assert ledger.get("scorer").status == "deprecated"
+
+    def test_add_defaults_status_to_active(self, ledger):
+        ledger.add(DataNode("scorer", platform="ml_platform", outputs=["scores"]))
+        assert ledger.get("scorer").status == "active"
+
+    def test_add_flips_status_on_existing_model(self, ledger):
+        ledger.add(DataNode("scorer", platform="ml_platform", outputs=["scores"]))
+        assert ledger.get("scorer").status == "active"
+        ledger.add(
+            DataNode(
+                "scorer",
+                platform="ml_platform",
+                outputs=["scores"],
+                metadata={"status": "deprecated"},
+            )
+        )
+        assert ledger.get("scorer").status == "deprecated"
+
+    def test_add_status_flip_records_new_discovered_snapshot(self, ledger):
+        ledger.add(DataNode("scorer", platform="ml_platform", outputs=["scores"]))
+        result = ledger.add(
+            DataNode(
+                "scorer",
+                platform="ml_platform",
+                outputs=["scores"],
+                metadata={"status": "deprecated"},
+            )
+        )
+        assert result["added"] == 1
+        assert result["skipped"] == 0
+
+    def test_add_status_absent_keeps_existing_status(self, ledger):
+        ledger.add(
+            DataNode(
+                "scorer",
+                platform="ml_platform",
+                outputs=["scores"],
+                metadata={"status": "deprecated"},
+            )
+        )
+        ledger.add(DataNode("scorer", platform="ml_platform", outputs=["scores"]))
+        assert ledger.get("scorer").status == "deprecated"
+
+    def test_add_unknown_status_ignored(self, ledger):
+        ledger.add(
+            DataNode(
+                "scorer",
+                platform="ml_platform",
+                outputs=["scores"],
+                metadata={"status": "deprecated"},
+            )
+        )
+        ledger.add(
+            DataNode(
+                "scorer",
+                platform="ml_platform",
+                outputs=["scores"],
+                metadata={"status": "not-a-status"},
+            )
+        )
+        assert ledger.get("scorer").status == "deprecated"
+
+    def test_add_status_case_insensitive_normalized(self, ledger):
+        ledger.add(
+            DataNode(
+                "scorer",
+                platform="ml_platform",
+                outputs=["scores"],
+                metadata={"status": "DEPRECATED"},
+            )
+        )
+        assert ledger.get("scorer").status == "deprecated"
+
+    def test_add_status_applied_even_when_snapshot_deduped(self, ledger):
+        """A status flip self-corrects the model row even if the discovered
+        payload matches an earlier snapshot (content-hash dedup skip)."""
+        node_meta = {"status": "deprecated"}
+        ledger.add(
+            DataNode("scorer", platform="ml_platform", outputs=["scores"], metadata=node_meta)
+        )
+        # Manually regress the stored status, simulating drift in the model row.
+        ref = ledger.get("scorer")
+        ref.status = "active"
+        ledger._backend.update_model(ref)
+        # Same payload again: snapshot dedup skips, but the status still lands.
+        result = ledger.add(
+            DataNode("scorer", platform="ml_platform", outputs=["scores"], metadata=node_meta)
+        )
+        assert result["skipped"] == 1
+        assert ledger.get("scorer").status == "deprecated"
+
     def test_add_source_updated_at_does_not_affect_dedup(self, ledger):
         node1 = DataNode(
             "scorer",
diff --git a/tests/test_invariants.py b/tests/test_invariants.py
index a865bcc..93de1d8 100644
--- a/tests/test_invariants.py
+++ b/tests/test_invariants.py
@@ -11,6 +11,8 @@
      differing content yields a differing hash (tamper-evident).
   3. Ordered history — history() returns snapshots newest-first by timestamp.
   4. Point-in-time — inventory_at(t) reflects only models that existed at t.
+  5. Status propagation — a model's status equals the last valid status discovered
+     by add(); absent or unknown statuses never regress it.
 """
 
 from __future__ import annotations
@@ -24,6 +26,7 @@
 from model_ledger import Ledger
 from model_ledger.backends.ledger_memory import InMemoryLedgerBackend
 from model_ledger.core.ledger_models import Snapshot
+from model_ledger.graph.models import DataNode
 
 # Safe alphabets keep the focus on the invariants, not unicode-encoding edge cases.
 _TOKEN = st.text(alphabet="abcdefghijklmnopqrstuvwxyz0123456789_-", min_size=1, max_size=12)
@@ -90,6 +93,34 @@ def test_history_is_ordered_newest_first(seq: list[tuple[str, dict]]) -> None:
     assert timestamps == sorted(timestamps, reverse=True)
 
 
+# Independent oracle for status propagation: the canonical form of every valid
+# discovered status. Anything else (unknown strings, empty, absent) is "no opinion".
+_CANONICAL_STATUS = {s: s for s in ("development", "review", "active", "deprecated", "retired")} | {
+    "ACTIVE": "active",
+    "Deprecated": "deprecated",
+    "RETIRED": "retired",
+}
+_DISCOVERED_STATUSES = st.lists(
+    st.one_of(st.none(), st.sampled_from([*_CANONICAL_STATUS, "", "not-a-status", "unknown"])),
+    max_size=8,
+)
+
+
+@settings(deadline=None, max_examples=40)
+@given(statuses=_DISCOVERED_STATUSES)
+def test_status_equals_last_valid_discovered_status(statuses: list[str | None]) -> None:
+    """add() propagates each valid discovered status to the model; absent or
+    unrecognized statuses leave it untouched (never regressing to the default)."""
+    ledger = _ledger()
+    expected = "active"
+    for raw in statuses:
+        metadata = {"status": raw} if raw is not None else {}
+        ledger.add(DataNode("m", platform="p", outputs=["t"], metadata=metadata))
+        expected = _CANONICAL_STATUS.get(raw, expected) if raw is not None else expected
+    if statuses:  # add() was called at least once, so the model exists
+        assert ledger.get("m").status == expected
+
+
 @settings(deadline=None, max_examples=30)
 @given(names=st.lists(_TOKEN, unique=True, max_size=5))
 def test_point_in_time_reflects_only_models_that_existed(names: list[str]) -> None: