diff --git a/src/cashet/async_executor.py b/src/cashet/async_executor.py index f37c545..d5614fd 100644 --- a/src/cashet/async_executor.py +++ b/src/cashet/async_executor.py @@ -108,9 +108,8 @@ async def submit( claim.task_def.ttl = task_def.ttl claim.task_def.tags = task_def.tags claim.tags = task_def.tags - claim.expires_at = ( - datetime.now(UTC) + task_def.ttl if task_def.ttl else None - ) + # Preserve original expires_at; TTL extension on reclaim + # would unexpectedly shift the expiration window. await store.put_commit(claim) break else: diff --git a/src/cashet/models.py b/src/cashet/models.py index 1e36b30..ed21993 100644 --- a/src/cashet/models.py +++ b/src/cashet/models.py @@ -66,6 +66,8 @@ class Commit: claimed_at: datetime = field(default_factory=lambda: datetime.now(UTC)) error: str | None = None tags: dict[str, str] = field(default_factory=dict[str, str]) + # Set once at creation; reclaimed stale claims keep the original + # window so that callers can rely on a stable expiration horizon. expires_at: datetime | None = None @property diff --git a/tests/test_async_client.py b/tests/test_async_client.py index 04da39a..8b9c7ff 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import UTC, datetime, timedelta from pathlib import Path import pytest @@ -281,6 +282,32 @@ def non_cached() -> int: assert await ref1.load() == 1 assert await ref2.load() == 2 + async def test_cached_task_with_ttl_and_stale_reclaim( + self, async_client: AsyncClient + ) -> None: + import cashet.dag as dag + import cashet.hashing as hashing + from cashet.models import TaskStatus + + counter = 0 + + def work() -> int: + nonlocal counter + counter += 1 + return counter + + task_def = hashing.build_task_def(work, (), {}, cache=True) + input_refs = dag.resolve_input_refs((), {}) + commit = dag.build_commit(task_def, input_refs) + commit.status = TaskStatus.RUNNING + commit.created_at = datetime.now(UTC) - timedelta(seconds=400) + commit.claimed_at = datetime.now(UTC) - timedelta(seconds=400) + await async_client.store.put_commit(commit) + + ref = await async_client.submit(work, _cache=True) + assert await ref.load() == 1 + assert counter == 1 + async def test_task_decorator_callable_returns_async_result_ref( self, async_client: AsyncClient ) -> None: diff --git a/tests/test_store.py b/tests/test_store.py index 13cee6b..da65001 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -299,6 +299,31 @@ def slow() -> int: with pytest.raises(TaskError, match="TimeoutError"): client.submit(slow, _timeout=0.01) + def test_reclaimed_stale_claim_keeps_original_expires_at(self, store_dir: Path) -> None: + import cashet.dag as dag + import cashet.hashing as hashing + from cashet.models import TaskStatus + + client = Client(store_dir=store_dir) + + def compute() -> int: + return 42 + + task_def = hashing.build_task_def(compute, (), {}, cache=True) + input_refs = dag.resolve_input_refs((), {}) + commit = dag.build_commit(task_def, input_refs) + commit.status = TaskStatus.RUNNING + commit.created_at = datetime.now(UTC) - timedelta(seconds=400) + commit.claimed_at = datetime.now(UTC) - timedelta(seconds=400) + client.store.put_commit(commit) + + ref = client.submit(compute) + assert ref.load() == 42 + + log = client.log() + assert len(log) == 1 + assert log[0].status.value == "completed" + def test_running_claim_lookup_is_not_limited_to_1000_rows(self, store_dir: Path) -> None: import cashet.dag as dag import cashet.hashing as hashing