From 34fd8d36ab004f94578e74e2c5a6e9353460e2b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Jolovi=C4=87?= Date: Wed, 6 May 2026 06:21:54 +0200 Subject: [PATCH] fix: preserve original expiration window on stale claim reclaim When a stale running claim is reclaimed, updating expires_at to now + ttl shifts the expiration window forward. This causes the cache entry to live longer than the caller originally intended if the worker was slow or crashed. Instead, preserve the original expires_at set at commit creation. The task_def.ttl field is still updated so that downstream code sees the current submission's TTL preference, but the actual expiration horizon remains stable. Adds tests for sync and async stale reclaim with cache=True. --- src/cashet/async_executor.py | 5 ++--- src/cashet/models.py | 2 ++ tests/test_async_client.py | 27 +++++++++++++++++++++++++++ tests/test_store.py | 25 +++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 3 deletions(-) diff --git a/src/cashet/async_executor.py b/src/cashet/async_executor.py index f37c545..d5614fd 100644 --- a/src/cashet/async_executor.py +++ b/src/cashet/async_executor.py @@ -108,9 +108,8 @@ async def submit( claim.task_def.ttl = task_def.ttl claim.task_def.tags = task_def.tags claim.tags = task_def.tags - claim.expires_at = ( - datetime.now(UTC) + task_def.ttl if task_def.ttl else None - ) + # Preserve original expires_at; TTL extension on reclaim + # would unexpectedly shift the expiration window. await store.put_commit(claim) break else: diff --git a/src/cashet/models.py b/src/cashet/models.py index 1e36b30..ed21993 100644 --- a/src/cashet/models.py +++ b/src/cashet/models.py @@ -66,6 +66,8 @@ class Commit: claimed_at: datetime = field(default_factory=lambda: datetime.now(UTC)) error: str | None = None tags: dict[str, str] = field(default_factory=dict[str, str]) + # Set once at creation; reclaimed stale claims keep the original + # window so that callers can rely on a stable expiration horizon. expires_at: datetime | None = None @property diff --git a/tests/test_async_client.py b/tests/test_async_client.py index 04da39a..8b9c7ff 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import UTC, datetime, timedelta from pathlib import Path import pytest @@ -281,6 +282,32 @@ def non_cached() -> int: assert await ref1.load() == 1 assert await ref2.load() == 2 + async def test_cached_task_with_ttl_and_stale_reclaim( + self, async_client: AsyncClient + ) -> None: + import cashet.dag as dag + import cashet.hashing as hashing + from cashet.models import TaskStatus + + counter = 0 + + def work() -> int: + nonlocal counter + counter += 1 + return counter + + task_def = hashing.build_task_def(work, (), {}, cache=True) + input_refs = dag.resolve_input_refs((), {}) + commit = dag.build_commit(task_def, input_refs) + commit.status = TaskStatus.RUNNING + commit.created_at = datetime.now(UTC) - timedelta(seconds=400) + commit.claimed_at = datetime.now(UTC) - timedelta(seconds=400) + await async_client.store.put_commit(commit) + + ref = await async_client.submit(work, _cache=True) + assert await ref.load() == 1 + assert counter == 1 + async def test_task_decorator_callable_returns_async_result_ref( self, async_client: AsyncClient ) -> None: diff --git a/tests/test_store.py b/tests/test_store.py index 13cee6b..da65001 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -299,6 +299,31 @@ def slow() -> int: with pytest.raises(TaskError, match="TimeoutError"): client.submit(slow, _timeout=0.01) + def test_reclaimed_stale_claim_keeps_original_expires_at(self, store_dir: Path) -> None: + import cashet.dag as dag + import cashet.hashing as hashing + from cashet.models import TaskStatus + + client = Client(store_dir=store_dir) + + def compute() -> int: + return 42 + + task_def = hashing.build_task_def(compute, (), {}, cache=True) + input_refs = dag.resolve_input_refs((), {}) + commit = dag.build_commit(task_def, input_refs) + commit.status = TaskStatus.RUNNING + commit.created_at = datetime.now(UTC) - timedelta(seconds=400) + commit.claimed_at = datetime.now(UTC) - timedelta(seconds=400) + client.store.put_commit(commit) + + ref = client.submit(compute) + assert ref.load() == 42 + + log = client.log() + assert len(log) == 1 + assert log[0].status.value == "completed" + def test_running_claim_lookup_is_not_limited_to_1000_rows(self, store_dir: Path) -> None: import cashet.dag as dag import cashet.hashing as hashing