# Snowflake error code raised when a session's auth token has idle-expired
# ("Authentication token has expired"). This is the precise signal we react to.
_AUTH_EXPIRED_ERRNO = 390114
_AUTH_EXPIRED_MESSAGE = "authentication token has expired"


def _is_auth_expiry_error(exc: BaseException) -> bool:
    """Return True only for the Snowflake "authentication token has expired" error.

    Detection is duck-typed so the optional driver never has to be imported:

    1. Numeric code (authoritative): ``exc.errno`` or ``exc.sql_error_code``
       equal to ``390114``.
    2. Message fallback: the code ``390114`` AND the canonical phrase both
       appear in ``exc.msg`` / ``str(exc)`` (case-insensitive). Both are
       required so that an unrelated error that merely mentions one of them
       is not misclassified.

    Deliberately narrow: an unrelated ``ProgrammingError`` (bad SQL, missing
    table, permission denied) must NOT look like an auth expiry, so the
    exception *type* is never consulted.

    Args:
        exc: The exception raised by a statement.

    Returns:
        True if ``exc`` carries the auth-token-expired signal, else False.
    """
    # ``errno`` is what the Python connector sets on its errors.
    # NOTE(review): ``sql_error_code`` is understood to be where Snowpark's
    # ``SnowparkSQLException`` exposes the same connector code — confirm
    # against the Snowpark version in use. Absent attributes read as None.
    for code_attr in ("errno", "sql_error_code"):
        if getattr(exc, code_attr, None) == _AUTH_EXPIRED_ERRNO:
            return True

    # Some driver/version combinations leave the numeric code unset but embed
    # the code and the canonical phrase in the message. Require BOTH.
    text = (str(getattr(exc, "msg", "") or "") + " " + str(exc)).lower()
    return str(_AUTH_EXPIRED_ERRNO) in text and _AUTH_EXPIRED_MESSAGE in text
When that session's auth + token idle-expires, every subsequent statement fails with + ``ProgrammingError`` errno ``390114`` ("Authentication token has expired") + until the process restarts. Passing ``connection_factory`` lets the backend + self-heal: on a detected auth-expiry error it calls the factory to obtain a + fresh connection, swaps it in, and retries the *same* statement exactly + once. A second consecutive auth-expiry (or any other error) propagates. + + Factory contract: ``connection_factory()`` must return a *ready-to-use* + connection — same account/user/auth and, where relevant, warehouse, role, + and current database as the original. The backend issues no session-setup + (``USE``) statements; it addresses every object with a fully-qualified + ``{schema}`` name, so the factory owns all session configuration. This + composes with — and does not replace — the driver's + ``client_session_keep_alive`` heartbeat: heartbeats reduce idle expiry but + cannot eliminate it (network blips, very long idle, a stalled heartbeat + thread), and this path is the backstop for the residual cases. + + If only ``connection`` is given (no factory), behavior is unchanged: an + auth-expiry error propagates exactly as before, with no reconnect. + + Limitation: the retry runs on a brand-new session, so session-scoped state + created by earlier statements (notably the ``TEMPORARY`` staging tables used + by the pandas flush paths) does not exist there; a statement retried after + such a reconnect fails with an ordinary, non-auth error, which propagates. 
# Methods of SnowflakeLedgerBackend: indent one level when placed in the class
# body. Inside the class, the bare names ``_exec`` / ``_exec_no_result`` used
# below resolve to the module-level helpers of the same name, not these methods.
def _reconnect(self, stale: Any) -> None:
    """Swap in a fresh connection, guarding against concurrent reconnects.

    Args:
        stale: The connection object the caller observed failing.

    Does nothing when no ``connection_factory`` was configured. Otherwise the
    swap happens under ``self._reconnect_lock``; after taking the lock the
    current ``self._session`` is compared (by identity) with ``stale``, and if
    they differ another caller has already reconnected, so that session is
    reused and the factory is not called again.

    Raises:
        Whatever ``connection_factory()`` raises; ``self._session`` is left
        untouched in that case.
    """
    if self._connection_factory is None:
        return
    with self._reconnect_lock:
        if self._session is not stale:
            # A concurrent caller already reconnected; reuse that session.
            return
        # NOTE(review): the stale session is dropped without close(); confirm
        # the driver releases it on garbage collection or close it upstream.
        self._session = self._connection_factory()


def _run_with_reconnect(self, runner: Callable[[Any, str], Any], sql: str) -> Any:
    """Run ``runner(session, sql)``, reconnecting and retrying once on auth expiry.

    Shared implementation behind ``_exec`` and ``_exec_no_result`` so both
    follow exactly the same retry policy.

    Args:
        runner: Callable taking ``(session, sql)`` that executes the statement.
        sql: The statement text; the identical text is used for the retry.

    Returns:
        Whatever ``runner`` returns.

    Raises:
        The original exception when no factory is configured or the error is
        not an auth expiry; otherwise whatever the single retry raises.

    The retry runs on a different session, so session-scoped state created by
    earlier statements (e.g. ``TEMPORARY`` tables) is not available to it.
    """
    # Snapshot the session so _reconnect can tell whether another thread has
    # already replaced the one this call saw failing.
    session = self._session
    try:
        return runner(session, sql)
    except Exception as exc:
        if self._connection_factory is None or not _is_auth_expiry_error(exc):
            raise
        self._reconnect(session)
        # Retry exactly once on the fresh session. A second auth expiry
        # (or any other error) propagates, chained to the first failure.
        return runner(self._session, sql)


def _exec(self, sql: str) -> list[dict[str, Any]]:
    """Run a result-returning statement, self-healing on auth expiry."""
    return self._run_with_reconnect(_exec, sql)


def _exec_no_result(self, sql: str) -> None:
    """Run a non-result statement, self-healing on auth expiry."""
    self._run_with_reconnect(_exec_no_result, sql)
self._exec_no_result( f""" CREATE OR REPLACE TEMPORARY TABLE {staging} ( SNAPSHOT_HASH VARCHAR, MODEL_HASH VARCHAR, PARENT_HASH VARCHAR, @@ -287,8 +378,7 @@ def _flush_snapshots_pandas(self) -> bool: if self._database: wp_kwargs["database"] = self._database write_pandas(conn, df, "SNAPSHOTS_STAGING", **wp_kwargs) # type: ignore[arg-type] - _exec_no_result( - self._session, + self._exec_no_result( f""" INSERT INTO {self._schema}.SNAPSHOTS (SNAPSHOT_HASH, MODEL_HASH, PARENT_HASH, TIMESTAMP, ACTOR, EVENT_TYPE, SOURCE, PAYLOAD, TAGS) @@ -297,7 +387,7 @@ def _flush_snapshots_pandas(self) -> bool: FROM {staging} s WHERE NOT EXISTS (SELECT 1 FROM {self._schema}.SNAPSHOTS t WHERE t.SNAPSHOT_HASH = s.SNAPSHOT_HASH)""", ) - _exec_no_result(self._session, f"DROP TABLE IF EXISTS {staging}") + self._exec_no_result(f"DROP TABLE IF EXISTS {staging}") return True def _flush_snapshots_sql(self) -> None: @@ -312,8 +402,7 @@ def _flush_snapshots_sql(self) -> None: f"{_esc(s.actor)}, {_esc(s.event_type)}, {_esc(s.source)}" for s in batch ) - _exec_no_result( - self._session, + self._exec_no_result( f""" INSERT INTO {self._schema}.SNAPSHOTS (SNAPSHOT_HASH, MODEL_HASH, PARENT_HASH, TIMESTAMP, ACTOR, EVENT_TYPE, SOURCE) @@ -322,9 +411,8 @@ def _flush_snapshots_sql(self) -> None: ) def _ensure_tables(self) -> None: - _exec_no_result(self._session, f"CREATE SCHEMA IF NOT EXISTS {self._schema}") - _exec_no_result( - self._session, + self._exec_no_result(f"CREATE SCHEMA IF NOT EXISTS {self._schema}") + self._exec_no_result( f""" CREATE TABLE IF NOT EXISTS {self._schema}.MODELS ( MODEL_HASH VARCHAR PRIMARY KEY, NAME VARCHAR UNIQUE NOT NULL, @@ -340,15 +428,13 @@ def _ensure_tables(self) -> None: # the "already exists" case — other DDL errors (missing permission, transient # failure) must surface so startup doesn't silently leave MERGEs broken. 
try: - _exec_no_result( - self._session, + self._exec_no_result( f"ALTER TABLE {self._schema}.MODELS ADD COLUMN METADATA VARIANT", ) except Exception as e: if "already exists" not in str(e).lower(): raise - _exec_no_result( - self._session, + self._exec_no_result( f""" CREATE TABLE IF NOT EXISTS {self._schema}.SNAPSHOTS ( SNAPSHOT_HASH VARCHAR PRIMARY KEY, MODEL_HASH VARCHAR NOT NULL, @@ -356,8 +442,7 @@ def _ensure_tables(self) -> None: ACTOR VARCHAR NOT NULL, EVENT_TYPE VARCHAR NOT NULL, SOURCE VARCHAR, PAYLOAD VARIANT, TAGS VARIANT)""", ) - _exec_no_result( - self._session, + self._exec_no_result( f""" CREATE TABLE IF NOT EXISTS {self._schema}.TAGS ( MODEL_HASH VARCHAR NOT NULL, NAME VARCHAR NOT NULL, @@ -370,8 +455,7 @@ def save_model(self, model: ModelRef) -> None: def get_model(self, model_hash: str) -> ModelRef | None: self._flush_models() - rows = _exec( - self._session, + rows = self._exec( f"SELECT * FROM {self._schema}.MODELS WHERE MODEL_HASH = {_esc(model_hash)}", ) return _row_to_model_ref(rows[0]) if rows else None @@ -380,9 +464,7 @@ def get_model_by_name(self, name: str) -> ModelRef | None: for m in self._model_buffer: if m.name == name: return m - rows = _exec( - self._session, f"SELECT * FROM {self._schema}.MODELS WHERE NAME = {_esc(name)}" - ) + rows = self._exec(f"SELECT * FROM {self._schema}.MODELS WHERE NAME = {_esc(name)}") return _row_to_model_ref(rows[0]) if rows else None def list_models(self, **filters: str) -> list[ModelRef]: @@ -405,7 +487,7 @@ def list_models(self, **filters: str) -> list[ModelRef]: sql += f" LIMIT {int(limit)}" if offset is not None: sql += f" OFFSET {int(offset)}" - return [_row_to_model_ref(r) for r in _exec(self._session, sql)] + return [_row_to_model_ref(r) for r in self._exec(sql)] def count_models(self, **filters: str) -> int: """Count models matching filters without fetching all rows.""" @@ -423,7 +505,7 @@ def count_models(self, **filters: str) -> int: ) if conditions: sql += " WHERE " + " AND 
".join(conditions) - rows = _exec(self._session, sql) + rows = self._exec(sql) return rows[0]["CNT"] if rows else 0 def update_model(self, model: ModelRef) -> None: @@ -434,8 +516,7 @@ def append_snapshot(self, snapshot: Snapshot) -> None: def get_snapshot(self, snapshot_hash: str) -> Snapshot | None: self._flush_snapshots() - rows = _exec( - self._session, + rows = self._exec( f"SELECT * FROM {self._schema}.SNAPSHOTS WHERE SNAPSHOT_HASH = {_esc(snapshot_hash)}", ) return _row_to_snapshot(rows[0]) if rows else None @@ -446,7 +527,7 @@ def list_snapshots(self, model_hash: str, **filters: str) -> list[Snapshot]: for k, v in filters.items(): sql += f" AND {k.upper()} = {_esc(v)}" sql += " ORDER BY TIMESTAMP" - return [_row_to_snapshot(r) for r in _exec(self._session, sql)] + return [_row_to_snapshot(r) for r in self._exec(sql)] def list_all_snapshots(self, event_type: str | None = None) -> list[Snapshot]: """Bulk load all snapshots — 1 query instead of N per-model queries.""" @@ -454,7 +535,7 @@ def list_all_snapshots(self, event_type: str | None = None) -> list[Snapshot]: sql = f"SELECT * FROM {self._schema}.SNAPSHOTS" if event_type: sql += f" WHERE EVENT_TYPE = {_esc(event_type)}" - return [_row_to_snapshot(r) for r in _exec(self._session, sql)] + return [_row_to_snapshot(r) for r in self._exec(sql)] def list_snapshot_content_hashes(self, event_type: str | None = None) -> dict[str, str]: """Read _content_hash from payloads — returns {model_hash: content_hash}. 
@@ -470,9 +551,7 @@ def list_snapshot_content_hashes(self, event_type: str | None = None) -> dict[st QUALIFY ROW_NUMBER() OVER (PARTITION BY MODEL_HASH ORDER BY TIMESTAMP DESC) = 1 """ return { - r["MODEL_HASH"]: r["CONTENT_HASH"] - for r in _exec(self._session, sql) - if r.get("CONTENT_HASH") + r["MODEL_HASH"]: r["CONTENT_HASH"] for r in self._exec(sql) if r.get("CONTENT_HASH") } def composite_summary( @@ -606,7 +685,7 @@ def composite_summary( LEFT JOIN validations v ON v.COMPOSITE_HASH = c.MODEL_HASH LEFT JOIN open_obs oo ON oo.COMPOSITE_HASH = c.MODEL_HASH ORDER BY c.NAME""" - rows = _exec(self._session, sql) + rows = self._exec(sql) results = [] for r in rows: raw = r.get("METADATA") or {} @@ -637,8 +716,7 @@ def latest_snapshot(self, model_hash: str, tag: str | None = None) -> Snapshot | if t: return self.get_snapshot(t.snapshot_hash) return None - rows = _exec( - self._session, + rows = self._exec( f"SELECT * FROM {self._schema}.SNAPSHOTS WHERE MODEL_HASH = {_esc(model_hash)} ORDER BY TIMESTAMP DESC LIMIT 1", ) return _row_to_snapshot(rows[0]) if rows else None @@ -657,11 +735,10 @@ def list_snapshots_before( if event_type: sql += f" AND EVENT_TYPE = {_esc(event_type)}" sql += " ORDER BY TIMESTAMP" - return [_row_to_snapshot(r) for r in _exec(self._session, sql)] + return [_row_to_snapshot(r) for r in self._exec(sql)] def set_tag(self, tag: Tag) -> None: - _exec_no_result( - self._session, + self._exec_no_result( f""" MERGE INTO {self._schema}.TAGS t USING (SELECT {_esc(tag.model_hash)} AS model_hash, {_esc(tag.name)} AS name, @@ -674,8 +751,7 @@ def set_tag(self, tag: Tag) -> None: ) def get_tag(self, model_hash: str, name: str) -> Tag | None: - rows = _exec( - self._session, + rows = self._exec( f"SELECT * FROM {self._schema}.TAGS WHERE MODEL_HASH = {_esc(model_hash)} AND NAME = {_esc(name)}", ) return _row_to_tag(rows[0]) if rows else None @@ -683,8 +759,7 @@ def get_tag(self, model_hash: str, name: str) -> Tag | None: def list_tags(self, model_hash: str) 
-> list[Tag]: return [ _row_to_tag(r) - for r in _exec( - self._session, + for r in self._exec( f"SELECT * FROM {self._schema}.TAGS WHERE MODEL_HASH = {_esc(model_hash)} ORDER BY NAME", ) ] @@ -692,8 +767,7 @@ def list_tags(self, model_hash: str) -> list[Tag]: def count_all_snapshots(self) -> int: """Count all snapshots in a single query.""" self._flush_snapshots() - rows = _exec( - self._session, + rows = self._exec( f"SELECT COUNT(*) AS CNT FROM {self._schema}.SNAPSHOTS", ) return rows[0]["CNT"] if rows else 0 @@ -708,8 +782,7 @@ def model_summaries( self._flush_snapshots() in_clause = ", ".join(_esc(h) for h in model_hashes) - agg_rows = _exec( - self._session, + agg_rows = self._exec( f""" SELECT MODEL_HASH, MAX(TIMESTAMP) AS LAST_EVENT, @@ -719,8 +792,7 @@ def model_summaries( GROUP BY MODEL_HASH""", ) - plat_rows = _exec( - self._session, + plat_rows = self._exec( f""" SELECT MODEL_HASH, COALESCE(PAYLOAD:platform::VARCHAR, SOURCE) AS PLATFORM @@ -790,14 +862,12 @@ def changelog_page( where = (" WHERE " + " AND ".join(conditions)) if conditions else "" - count_rows = _exec( - self._session, + count_rows = self._exec( f"SELECT COUNT(*) AS CNT FROM {self._schema}.SNAPSHOTS s{where}", ) total = count_rows[0]["CNT"] if count_rows else 0 - data_rows = _exec( - self._session, + data_rows = self._exec( f""" SELECT s.SNAPSHOT_HASH, s.MODEL_HASH, s.PARENT_HASH, s.TIMESTAMP, s.ACTOR, s.EVENT_TYPE, s.SOURCE, @@ -839,8 +909,7 @@ def batch_dependencies( self._flush_snapshots() self._flush_models() - dep_rows = _exec( - self._session, + dep_rows = self._exec( f""" SELECT EVENT_TYPE, @@ -884,8 +953,7 @@ def batch_dependencies( f"NAME IN ({', '.join(_esc(n) for n in lookup_names)})" if lookup_names else None ) or_clause = " OR ".join(filter(None, [hash_cond, name_cond])) - model_rows = _exec( - self._session, + model_rows = self._exec( f"SELECT MODEL_HASH, NAME FROM {self._schema}.MODELS WHERE {or_clause}", ) model_by_hash = {r["MODEL_HASH"]: r["NAME"] for r in model_rows} @@ 
class AuthExpiredError(Exception):
    """Stand-in for the driver's expired-token error.

    Carries ``errno`` and ``msg`` attributes so code that duck-types those
    attributes treats it like the real error, without importing the optional
    driver.
    """

    def __init__(self, errno=390114, msg="390114: Authentication token has expired"):
        Exception.__init__(self, msg)
        self.msg = msg
        self.errno = errno


class _Cursor:
    """Minimal cursor: no column description and no rows."""

    description = None

    def fetchall(self):
        return list()


class FakeConnection:
    """Cursor-style fake connection exposing ``execute(sql)``.

    Successful statements are appended to ``executed``. When ``raise_exc`` is
    given, ``execute`` raises it for the first ``raise_times`` calls, or for
    every call when ``fail_all`` is true. ``raise_exc`` may be an exception
    class (instantiated with no arguments on each failure) or an instance
    (raised as-is).
    """

    def __init__(self, raise_exc=None, raise_times=0, fail_all=False):
        self._calls = 0
        self._fail_all = fail_all
        self._raise_times = raise_times
        self._raise_exc = raise_exc
        self.executed: list[str] = []

    def _should_fail(self):
        """True when the current call (already counted) is configured to raise."""
        if self._raise_exc is None:
            return False
        return self._fail_all or self._calls <= self._raise_times

    def execute(self, sql):
        self._calls += 1
        if self._should_fail():
            failure = self._raise_exc
            if isinstance(failure, type):
                failure = failure()
            raise failure
        self.executed.append(sql)
        return _Cursor()


def _backend_with_factory(factory, **kwargs):
    """Build a read-only backend on ``TEST_SCHEMA`` that connects via ``factory``."""
    from model_ledger.backends import snowflake as snowflake_backend

    # read_only=True skips the _ensure_tables DDL so each test drives exactly
    # the statement it cares about.
    return snowflake_backend.SnowflakeLedgerBackend(
        schema="TEST_SCHEMA", read_only=True, connection_factory=factory, **kwargs
    )
+ result = backend.count_all_snapshots() + assert result == 0 # _Cursor.description is None -> _exec returns [] + assert backend._session is fresh + # The retried statement actually ran on the fresh connection. + assert any("COUNT(*)" in s for s in fresh.executed) + + def test_non_auth_programming_error_is_not_retried(self): + """A non-auth error (e.g. bad SQL / missing table) must propagate + without any reconnect attempt.""" + + class BadSqlError(Exception): + errno = 1003 # not 390114 + msg = "SQL compilation error: object does not exist" + + stale = FakeConnection(raise_exc=BadSqlError, fail_all=True) + reconnects = {"n": 0} + + def factory(): + reconnects["n"] += 1 + return stale + + backend = _backend_with_factory(factory) + before = reconnects["n"] # the one construction-time call + with pytest.raises(BadSqlError): + backend.count_all_snapshots() + # No reconnect happened beyond the one-time construction call. + assert reconnects["n"] == before + + def test_second_consecutive_auth_expiry_propagates(self): + """If the fresh connection ALSO raises auth-expiry, the second failure + propagates — we retry exactly once, never in a loop.""" + first = FakeConnection(raise_exc=AuthExpiredError, fail_all=True) + second = FakeConnection(raise_exc=AuthExpiredError, fail_all=True) + conns = iter([first, second]) + + backend = _backend_with_factory(lambda: next(conns)) + with pytest.raises(AuthExpiredError): + backend.count_all_snapshots() + # Reconnected exactly once (swapped to `second`), then gave up. + assert backend._session is second + + def test_message_only_auth_expiry_is_detected(self): + """Drivers that leave errno unset but embed the code + phrase in the + message are still recognized as auth expiry.""" + + class MessageOnlyError(Exception): + errno = None + msg = "390114 (08001): Authentication token has expired. Reauthenticate." 
+ + stale = FakeConnection(raise_exc=MessageOnlyError, raise_times=1) + fresh = FakeConnection() + conns = iter([stale, fresh]) + + backend = _backend_with_factory(lambda: next(conns)) + backend.count_all_snapshots() + assert backend._session is fresh + + def test_concurrency_guard_single_reconnect(self): + """Two threads hit the expired session concurrently; the lock + re-check + ensure the factory reconnects exactly once and both threads end on the + same fresh session.""" + import threading + + # The first connection fails for both threads' first attempt. After the + # single reconnect, the fresh connection succeeds for everyone. + stale = FakeConnection(raise_exc=AuthExpiredError, fail_all=True) + fresh = FakeConnection() + + factory_calls = {"n": 0} + factory_lock = threading.Lock() + first_yielded = {"done": False} + + def factory(): + with factory_lock: + factory_calls["n"] += 1 + # Construction yields the stale connection first. + if not first_yielded["done"]: + first_yielded["done"] = True + return stale + return fresh + + backend = _backend_with_factory(factory) + assert backend._session is stale + construction_calls = factory_calls["n"] + + barrier = threading.Barrier(2) + errors: list[BaseException] = [] + + def worker(): + try: + barrier.wait() + backend.count_all_snapshots() + except BaseException as exc: # noqa: BLE001 + errors.append(exc) + + t1 = threading.Thread(target=worker) + t2 = threading.Thread(target=worker) + t1.start() + t2.start() + t1.join() + t2.join() + + assert errors == [] + assert backend._session is fresh + # Exactly one reconnect beyond the construction-time factory call. 
+ assert factory_calls["n"] == construction_calls + 1 + + def test_no_result_statement_also_self_heals(self): + """The reconnect-retry-once path covers non-result statements too + (writes/DDL), not just queries.""" + stale = FakeConnection(raise_exc=AuthExpiredError, raise_times=1) + fresh = FakeConnection() + conns = iter([stale, fresh]) + + backend = _backend_with_factory(lambda: next(conns)) + tag = Tag( + model_hash="m1", + name="latest", + snapshot_hash="snap1", + updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + backend.set_tag(tag) # routes through _exec_no_result + assert backend._session is fresh + assert any("MERGE INTO" in s for s in fresh.executed) + + def test_requires_connection_or_factory(self): + from model_ledger.backends.snowflake import SnowflakeLedgerBackend + + with pytest.raises(ValueError, match="connection"): + SnowflakeLedgerBackend(schema="TEST_SCHEMA")