From 1b74feea9db15efd3f3e68212cb01f012212fc50 Mon Sep 17 00:00:00 2001 From: Denzyl Holder Layne Date: Tue, 28 Apr 2026 12:05:40 -0400 Subject: [PATCH 1/5] fix: Classify non-retriable customer errors for Durable API calls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Non-retriable customer errors from Lambda (e.g., KMSAccessDeniedException, KMSDisabledException) arrive as HTTP 502 during CheckpointDurableExecution and GetDurableExecutionState API calls. Without this change, these 502s are classified as retriable invocation errors, causing the SDK to retry invocations that will never self-resolve. Extract shared classification logic into BotoClientError with a _NON_RETRIABLE_CUSTOMER_ERROR_CODES set and _classify_error_category method. Both CheckpointError and GetExecutionStateError now inherit from_exception() with unified classification so that non-retriable errors (KMS 502s, 4xx client errors) return Status: FAILED immediately, while retriable errors (5xx, 429, network errors) continue to raise for Lambda retry. Add is_retriable() to InvocationError hierarchy so execution.py handlers use a single interface instead of isinstance checks. Testing — parameterized classification tests across both error types for all four KMS codes, 4xx, 429, 5xx, and 502 scenarios. Integration tests for non-retriable/retriable errors across all three execution.py code paths: initial pagination, background thread, and user thread. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --- .../exceptions.py | 123 ++++++++++---- .../execution.py | 55 ++++-- tests/exceptions_test.py | 160 ++++++++++++++---- tests/execution_test.py | 156 ++++++++++++++++- 4 files changed, 409 insertions(+), 85 deletions(-) diff --git a/src/aws_durable_execution_sdk_python/exceptions.py b/src/aws_durable_execution_sdk_python/exceptions.py index 336996c..e88df53 100644 --- a/src/aws_durable_execution_sdk_python/exceptions.py +++ b/src/aws_durable_execution_sdk_python/exceptions.py @@ -13,6 +13,22 @@ BAD_REQUEST_ERROR: int = 400 TOO_MANY_REQUESTS_ERROR: int = 429 SERVICE_ERROR: int = 500 +INVALID_PARAMETER_VALUE_EXCEPTION: str = "InvalidParameterValueException" +INVALID_CHECKPOINT_TOKEN_PREFIX: str = "Invalid Checkpoint Token" + +# Non-retriable customer error codes that arrive as non-4xx (e.g. HTTP 502) from Lambda. +# Unlike typical 5xx errors, these require customer intervention (e.g., fixing +# a KMS key configuration) and will never succeed on retry. +# Add new non-retriable error codes here — they are automatically classified +# as EXECUTION (non-retriable) by _classify_error_category(). +_NON_RETRIABLE_CUSTOMER_ERROR_CODES: frozenset[str] = frozenset( + { + "KMSAccessDeniedException", + "KMSDisabledException", + "KMSInvalidStateException", + "KMSNotFoundException", + } +) if TYPE_CHECKING: import datetime @@ -77,6 +93,14 @@ def __init__( ): super().__init__(message, termination_reason) + def is_retriable(self) -> bool: + """Whether this error is retriable. Returns True by default. + + Subclasses override to implement classification logic based on + error codes and HTTP status codes. + """ + return True + class CallbackError(ExecutionError): """Error in callback handling.""" @@ -86,10 +110,28 @@ def __init__(self, message: str, callback_id: str | None = None): self.callback_id = callback_id +class DurableApiErrorCategory(Enum): + INVOCATION = "INVOCATION" + EXECUTION = "EXECUTION" + + +# Backward-compatible alias +CheckpointErrorCategory = DurableApiErrorCategory + + class BotoClientError(InvocationError): + """Error from a Lambda API call (e.g., CheckpointDurableExecution, GetDurableExecutionState). + + Extends InvocationError because the default behavior for API failures is to retry + the Lambda invocation. However, some errors are non-retriable (e.g., 4xx client errors, + KMS key misconfiguration) and should fail the execution instead. The error_category field + and is_retriable() method distinguish these cases at runtime. + """ + def __init__( self, message: str, + error_category: DurableApiErrorCategory = DurableApiErrorCategory.INVOCATION, error: AwsErrorObj | None = None, response_metadata: AwsErrorMetadata | None = None, termination_reason=TerminationReason.INVOCATION_ERROR, @@ -97,16 +139,57 @@ def __init__( super().__init__(message=message, termination_reason=termination_reason) self.error: AwsErrorObj | None = error self.response_metadata: AwsErrorMetadata | None = response_metadata + self.error_category: DurableApiErrorCategory = error_category @classmethod def from_exception(cls, exception: Exception) -> Self: response = getattr(exception, "response", {}) response_metadata = response.get("ResponseMetadata") error = response.get("Error") + error_category = BotoClientError._classify_error_category(error, response_metadata) return cls( - message=str(exception), error=error, response_metadata=response_metadata + message=str(exception), error_category=error_category, error=error, response_metadata=response_metadata ) + @staticmethod + def _classify_error_category( + error: AwsErrorObj | None, + response_metadata: AwsErrorMetadata | None, + ) -> DurableApiErrorCategory: + """Classify a Durable API error as retriable (INVOCATION) or non-retriable (EXECUTION). + + Classification rules: + - Non-retriable customer error codes (e.g., KMS key issues) → EXECUTION + These arrive as HTTP 502 but require customer intervention to fix. + - 4xx errors → EXECUTION, except: + - 429 (TooManyRequests) → INVOCATION (throttling is transient) + - InvalidParameterValueException with "Invalid Checkpoint Token" → INVOCATION + (stale token from a concurrent checkpoint; next invocation gets a fresh token) + - 5xx, network errors → INVOCATION + """ + error_code: str | None = (error and error.get("Code")) or None + if error_code and error_code in _NON_RETRIABLE_CUSTOMER_ERROR_CODES: + return DurableApiErrorCategory.EXECUTION + + status_code: int | None = (response_metadata and response_metadata.get("HTTPStatusCode")) or None + if ( + status_code + and BAD_REQUEST_ERROR <= status_code < SERVICE_ERROR + and status_code != TOO_MANY_REQUESTS_ERROR + and error + and not ( + (error.get("Code") or "") == INVALID_PARAMETER_VALUE_EXCEPTION + and (error.get("Message") or "").startswith(INVALID_CHECKPOINT_TOKEN_PREFIX) + ) + ): + return DurableApiErrorCategory.EXECUTION + + return DurableApiErrorCategory.INVOCATION + + def is_retriable(self) -> bool: + """Whether this error is retriable based on error_category.""" + return self.error_category == DurableApiErrorCategory.INVOCATION + def build_logger_extras(self) -> dict: extras: dict = {} # preserve PascalCase to be consistent with other langauges @@ -125,55 +208,23 @@ def __init__(self, message: str, step_id: str | None = None): self.step_id = step_id -class CheckpointErrorCategory(Enum): - INVOCATION = "INVOCATION" - EXECUTION = "EXECUTION" - - class CheckpointError(BotoClientError): """Failure to checkpoint. Will terminate the lambda.""" def __init__( self, message: str, - error_category: CheckpointErrorCategory, + error_category: DurableApiErrorCategory = DurableApiErrorCategory.INVOCATION, error: AwsErrorObj | None = None, response_metadata: AwsErrorMetadata | None = None, ): super().__init__( message, + error_category, error, response_metadata, termination_reason=TerminationReason.CHECKPOINT_FAILED, ) - self.error_category: CheckpointErrorCategory = error_category - - @classmethod - def from_exception(cls, exception: Exception) -> CheckpointError: - base = BotoClientError.from_exception(exception) - metadata: AwsErrorMetadata | None = base.response_metadata - error: AwsErrorObj | None = base.error - error_category: CheckpointErrorCategory = CheckpointErrorCategory.INVOCATION - - # 4xx errors (except 429) are permanent failures (EXECUTION), unless it's an - # InvalidParameterValueException with "Invalid Checkpoint Token" which is retriable (INVOCATION). - # 5xx, 429, and network errors are retriable (INVOCATION). - status_code: int | None = (metadata and metadata.get("HTTPStatusCode")) or None - if ( - status_code - and BAD_REQUEST_ERROR <= status_code < SERVICE_ERROR - and status_code != TOO_MANY_REQUESTS_ERROR - and error - and not ( - (error.get("Code") or "") == "InvalidParameterValueException" - and (error.get("Message") or "").startswith("Invalid Checkpoint Token") - ) - ): - error_category = CheckpointErrorCategory.EXECUTION - return CheckpointError(str(exception), error_category, error, metadata) - - def is_retriable(self): - return self.error_category == CheckpointErrorCategory.INVOCATION class ValidationError(DurableExecutionsError): @@ -186,11 +237,13 @@ class GetExecutionStateError(BotoClientError): def __init__( self, message: str, + error_category: DurableApiErrorCategory = DurableApiErrorCategory.INVOCATION, error: AwsErrorObj | None = None, response_metadata: AwsErrorMetadata | None = None, ): super().__init__( message, + error_category, error, response_metadata, termination_reason=TerminationReason.INVOCATION_ERROR, diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index 403111d..e45ade7 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -260,11 +260,26 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: else ReplayStatus.NEW, ) - execution_state.fetch_paginated_operations( - invocation_input.initial_execution_state.operations, - invocation_input.checkpoint_token, - invocation_input.initial_execution_state.next_marker, - ) + try: + execution_state.fetch_paginated_operations( + invocation_input.initial_execution_state.operations, + invocation_input.checkpoint_token, + invocation_input.initial_execution_state.next_marker, + ) + except BotoClientError as e: + # Non-retriable Durable API errors (e.g., customer configuration issues, + # 4xx client errors) will never succeed on retry — fail the execution immediately. + if not e.is_retriable(): + logger.exception( + "Non-retriable Durable API error during initial state fetch. Must fail execution " + "without retry.", + extra=e.build_logger_extras(), + ) + return DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + error=ErrorObject.from_exception(e), + ).to_dict() + raise raw_input_payload: str | None = execution_state.get_input_payload() @@ -356,11 +371,20 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: "Checkpoint processing failed", extra=bg_error.source_exception.build_logger_extras(), ) + # Non-retriable Durable API errors (e.g., customer configuration issues, + # 4xx client errors) will never succeed on retry — fail the execution immediately. + if not bg_error.source_exception.is_retriable(): + logger.exception( + "Non-retriable Durable API error from background thread. Must fail execution " + "without retry.", + extra=bg_error.source_exception.build_logger_extras(), + ) + return DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + error=ErrorObject.from_exception(bg_error.source_exception), + ).to_dict() else: logger.exception("Checkpoint processing failed") - # handle the original exception - if isinstance(bg_error.source_exception, CheckpointError): - return handle_checkpoint_error(bg_error.source_exception).to_dict() raise bg_error.source_exception from bg_error except SuspendExecution: @@ -377,12 +401,23 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: extra=e.build_logger_extras(), ) return handle_checkpoint_error(e).to_dict() - except InvocationError: + except InvocationError as e: + # Non-retriable Durable API errors (e.g., customer configuration issues, + # 4xx client errors) will never succeed on retry — fail the execution immediately. + if not e.is_retriable(): + logger.exception( + "Non-retriable Durable API error. Must fail execution without retry.", + extra=e.build_logger_extras(), # type: ignore[attr-defined] + ) + return DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + error=ErrorObject.from_exception(e), + ).to_dict() logger.exception("Invocation error. Must terminate.") # Throw the error to trigger Lambda retry raise except ExecutionError as e: - logger.exception("Execution error. Must terminate without retry.") + logger.exception("Execution error. Must fail execution without retry.") return DurableExecutionInvocationOutput( status=InvocationStatus.FAILED, error=ErrorObject.from_exception(e), diff --git a/tests/exceptions_test.py b/tests/exceptions_test.py index f350b42..f0c0b70 100644 --- a/tests/exceptions_test.py +++ b/tests/exceptions_test.py @@ -7,12 +7,15 @@ from botocore.exceptions import ClientError # type: ignore[import-untyped] from aws_durable_execution_sdk_python.exceptions import ( + BotoClientError, CallableRuntimeError, CallableRuntimeErrorSerializableDetails, CheckpointError, CheckpointErrorCategory, + DurableApiErrorCategory, DurableExecutionsError, ExecutionError, + GetExecutionStateError, InvocationError, OrderedLockError, OrphanedChildException, @@ -87,10 +90,13 @@ def test_checkpoint_error_classification_payload_size_exceeded_execution(): assert not result.is_retriable() -def test_checkpoint_error_classification_other_4xx_execution(): - """Test other 4xx errors are execution errors.""" +def test_checkpoint_error_classification_invalid_param_without_token_execution(): + """Test 4xx InvalidParameterValueException without Invalid Checkpoint Token is execution error.""" error_response = { - "Error": {"Code": "ValidationException", "Message": "Invalid parameter value"}, + "Error": { + "Code": "InvalidParameterValueException", + "Message": "Some other invalid parameter", + }, "ResponseMetadata": {"HTTPStatusCode": 400}, } client_error = ClientError(error_response, "Checkpoint") @@ -101,58 +107,94 @@ def test_checkpoint_error_classification_other_4xx_execution(): assert not result.is_retriable() -def test_checkpoint_error_classification_429_invocation(): - """Test 429 errors are invocation errors (retryable).""" - error_response = { - "Error": {"Code": "TooManyRequestsException", "Message": "Rate limit exceeded"}, - "ResponseMetadata": {"HTTPStatusCode": 429}, - } - client_error = ClientError(error_response, "Checkpoint") +# ============================================================================= +# Shared Durable API error classification tests (BotoClientError._classify_error_category) +# These test the shared classification logic through each BotoClientError subclass. +# ============================================================================= - result = CheckpointError.from_exception(client_error) - assert result.error_category == CheckpointErrorCategory.INVOCATION - assert result.is_retriable() +@pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) +@pytest.mark.parametrize( + "error_code", + [ + "KMSAccessDeniedException", + "KMSDisabledException", + "KMSInvalidStateException", + "KMSNotFoundException", + ], +) +def test_durable_api_error_non_retriable_customer_error_codes(error_cls, error_code: str): + """Test that non-retriable customer error codes (HTTP 502) are classified as EXECUTION.""" + error_response = { + "Error": {"Code": error_code, "Message": f"{error_code} error"}, + "ResponseMetadata": {"HTTPStatusCode": 502}, + } + client_error = ClientError(error_response, "Invoke") + result = error_cls.from_exception(client_error) + assert result.error_category == DurableApiErrorCategory.EXECUTION + assert not result.is_retriable() -def test_checkpoint_error_classification_invalid_param_without_token_execution(): - """Test 4xx InvalidParameterValueException without Invalid Checkpoint Token is execution error.""" +@pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) +def test_durable_api_error_4xx_non_retriable(error_cls): + """Test 4xx errors are classified as EXECUTION (non-retriable).""" error_response = { - "Error": { - "Code": "InvalidParameterValueException", - "Message": "Some other invalid parameter", - }, + "Error": {"Code": "ValidationException", "Message": "Invalid parameter"}, "ResponseMetadata": {"HTTPStatusCode": 400}, } - client_error = ClientError(error_response, "Checkpoint") + client_error = ClientError(error_response, "Invoke") + result = error_cls.from_exception(client_error) + assert result.error_category == DurableApiErrorCategory.EXECUTION + assert not result.is_retriable() - result = CheckpointError.from_exception(client_error) - assert result.error_category == CheckpointErrorCategory.EXECUTION - assert not result.is_retriable() +@pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) +def test_durable_api_error_429_retriable(error_cls): + """Test 429 errors are classified as INVOCATION (retriable).""" + error_response = { + "Error": {"Code": "TooManyRequestsException", "Message": "Rate limit exceeded"}, + "ResponseMetadata": {"HTTPStatusCode": 429}, + } + client_error = ClientError(error_response, "Invoke") + result = error_cls.from_exception(client_error) + assert result.error_category == DurableApiErrorCategory.INVOCATION + assert result.is_retriable() -def test_checkpoint_error_classification_5xx_invocation(): - """Test 5xx errors are invocation errors.""" +@pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) +def test_durable_api_error_5xx_retriable(error_cls): + """Test 5xx errors are classified as INVOCATION (retriable).""" error_response = { "Error": {"Code": "InternalServerError", "Message": "Service unavailable"}, "ResponseMetadata": {"HTTPStatusCode": 500}, } - client_error = ClientError(error_response, "Checkpoint") - - result = CheckpointError.from_exception(client_error) - - assert result.error_category == CheckpointErrorCategory.INVOCATION + client_error = ClientError(error_response, "Invoke") + result = error_cls.from_exception(client_error) + assert result.error_category == DurableApiErrorCategory.INVOCATION assert result.is_retriable() -def test_checkpoint_error_classification_unknown_invocation(): - """Test unknown errors are invocation errors.""" - unknown_error = Exception("Network timeout") +@pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) +def test_durable_api_error_retriable_502(error_cls): + """Test that 502 errors with unrecognized error codes are retriable.""" + error_response = { + "Error": { + "Code": "ServiceException", + "Message": "Service encountered an internal error.", + }, + "ResponseMetadata": {"HTTPStatusCode": 502}, + } + client_error = ClientError(error_response, "Invoke") + result = error_cls.from_exception(client_error) + assert result.error_category == DurableApiErrorCategory.INVOCATION + assert result.is_retriable() - result = CheckpointError.from_exception(unknown_error) - assert result.error_category == CheckpointErrorCategory.INVOCATION +@pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) +def test_durable_api_error_unknown_retriable(error_cls): + """Test unknown errors (no HTTP response) are classified as INVOCATION (retriable).""" + result = error_cls.from_exception(Exception("Network timeout")) + assert result.error_category == DurableApiErrorCategory.INVOCATION assert result.is_retriable() @@ -391,3 +433,51 @@ def test_orphaned_child_exception_with_operation_id(): exception = OrphanedChildException("parent completed", operation_id="child_op_456") assert exception.operation_id == "child_op_456" assert str(exception) == "parent completed" + + +@pytest.mark.parametrize( + ("error_code", "status_code", "expected_retriable"), + [ + ("KMSAccessDeniedException", 502, False), + ("ServiceException", 500, True), + ("ServiceException", 502, True), + ], +) +def test_boto_client_error_is_retriable(error_code: str, status_code: int, expected_retriable: bool): + """Test BotoClientError.is_retriable() classification.""" + error_response = { + "Error": {"Code": error_code, "Message": "test error"}, + "ResponseMetadata": {"HTTPStatusCode": status_code}, + } + client_error = ClientError(error_response, "Invoke") + result = BotoClientError.from_exception(client_error) + assert result.is_retriable() == expected_retriable + + +def test_boto_client_error_is_retriable_no_error(): + """Test BotoClientError.is_retriable() returns True with no error info.""" + result = BotoClientError.from_exception(Exception("network error")) + assert result.is_retriable() + + +# ============================================================================= +# DurableApiErrorCategory backward compatibility +# ============================================================================= + + +def test_durable_api_error_category_backward_compatible_alias(): + """Test CheckpointErrorCategory is a backward-compatible alias for DurableApiErrorCategory.""" + assert CheckpointErrorCategory is DurableApiErrorCategory + assert CheckpointErrorCategory.INVOCATION is DurableApiErrorCategory.INVOCATION + assert CheckpointErrorCategory.EXECUTION is DurableApiErrorCategory.EXECUTION + + +# ============================================================================= +# is_retriable() tests +# ============================================================================= + + +def test_invocation_error_is_retriable_default(): + """Test InvocationError.is_retriable() returns True by default.""" + error = InvocationError("some error") + assert error.is_retriable() \ No newline at end of file diff --git a/tests/execution_test.py b/tests/execution_test.py index eeaf949..605b467 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -14,7 +14,9 @@ BotoClientError, CheckpointError, CheckpointErrorCategory, + DurableApiErrorCategory, ExecutionError, + GetExecutionStateError, InvocationError, SuspendExecution, ) @@ -1763,11 +1765,12 @@ def test_handler(event: Any, context: DurableContext) -> dict: assert response["Status"] == InvocationStatus.FAILED.value assert response["Error"]["ErrorType"] == "CheckpointError" - mock_logger.exception.assert_called_once() - call_args = mock_logger.exception.call_args - assert "Checkpoint processing failed" in call_args[0][0] - assert call_args[1]["extra"]["Error"] == error_obj - assert call_args[1]["extra"]["ResponseMetadata"] == metadata_obj + mock_logger.exception.assert_called() + # First call: "Checkpoint processing failed" with error extras + first_call = mock_logger.exception.call_args_list[0] + assert "Checkpoint processing failed" in first_call[0][0] + assert first_call[1]["extra"]["Error"] == error_obj + assert first_call[1]["extra"]["ResponseMetadata"] == metadata_obj def test_durable_execution_logs_boto_client_error_extras_from_background_thread(): @@ -2646,3 +2649,146 @@ def test_from_dict_leaves_timestamps_as_integers(): _ = operations[1].wait_details.scheduled_end_timestamp < datetime.datetime.now( tz=datetime.UTC ) + + +# ============================================================================= +# Non-retriable Durable API error handling tests +# ============================================================================= + + +def _make_invocation_input(mock_client, next_marker=""): + """Helper to create a standard test invocation input.""" + operation = Operation( + operation_id="exec1", + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + execution_details=ExecutionDetails(input_payload="{}"), + ) + return DurableExecutionInvocationInputWithClient( + durable_execution_arn="arn:test:execution", + checkpoint_token="token123", # noqa: S106 + initial_execution_state=InitialExecutionState(operations=[operation], next_marker=next_marker), + service_client=mock_client, + ) + + +def _make_lambda_context(): + """Helper to create a standard mock Lambda context.""" + ctx = Mock() + ctx.aws_request_id = "test-request" + ctx.client_context = None + ctx.identity = None + ctx._epoch_deadline_time_in_ms = 1000000 # noqa: SLF001 + ctx.invoked_function_arn = None + ctx.tenant_id = None + return ctx + + +def test_durable_execution_non_retriable_invocation_error_returns_failed(): + """Test that non-retriable InvocationError returns FAILED instead of retrying.""" + mock_client = Mock(spec=DurableServiceClient) + non_retriable_error = GetExecutionStateError( + message="KMS access denied", + error_category=DurableApiErrorCategory.EXECUTION, + error={"Code": "KMSAccessDeniedException", "Message": "KMS access denied"}, + response_metadata={"HTTPStatusCode": 502}, + ) + + @durable_execution + def test_handler(event: Any, context: DurableContext) -> dict: + raise non_retriable_error + + result = test_handler(_make_invocation_input(mock_client), _make_lambda_context()) + assert result["Status"] == InvocationStatus.FAILED.value + assert result["Error"]["ErrorType"] == "GetExecutionStateError" + + +def test_durable_execution_retriable_invocation_error_raises(): + """Test that retriable InvocationError raises to trigger Lambda retry.""" + mock_client = Mock(spec=DurableServiceClient) + retriable_error = GetExecutionStateError( + message="Service error", + error={"Code": "ServiceException", "Message": "Internal error"}, + response_metadata={"HTTPStatusCode": 500}, + ) + + @durable_execution + def test_handler(event: Any, context: DurableContext) -> dict: + raise retriable_error + + with pytest.raises(GetExecutionStateError, match="Service error"): + test_handler(_make_invocation_input(mock_client), _make_lambda_context()) + + +def test_durable_execution_non_retriable_background_thread_error_returns_failed(): + """Test that non-retriable error from background thread returns FAILED.""" + mock_client = Mock(spec=DurableServiceClient) + non_retriable_error = GetExecutionStateError( + message="KMS key disabled", + error_category=DurableApiErrorCategory.EXECUTION, + error={"Code": "KMSDisabledException", "Message": "KMS key disabled"}, + response_metadata={"HTTPStatusCode": 502}, + ) + mock_client.checkpoint.side_effect = lambda *a, **kw: (_ for _ in ()).throw(non_retriable_error) + + @durable_execution + def test_handler(event: Any, context: DurableContext) -> dict: + context.step(lambda ctx: "step_result") + return {"result": "success"} + + result = test_handler(_make_invocation_input(mock_client), _make_lambda_context()) + assert result["Status"] == InvocationStatus.FAILED.value + assert result["Error"]["ErrorType"] == "GetExecutionStateError" + + +@pytest.mark.parametrize( + ("error_code", "status_code", "error_category"), + [ + ("KMSAccessDeniedException", 502, DurableApiErrorCategory.EXECUTION), + ("ValidationException", 400, DurableApiErrorCategory.EXECUTION), + ], +) +def test_durable_execution_non_retriable_initial_pagination_error_returns_failed( + error_code: str, status_code: int, error_category: DurableApiErrorCategory +): + """Test that non-retriable errors during initial pagination return FAILED.""" + mock_client = Mock(spec=DurableServiceClient) + non_retriable_error = GetExecutionStateError( + message=f"{error_code} error", + error_category=error_category, + error={"Code": error_code, "Message": f"{error_code} error"}, + response_metadata={"HTTPStatusCode": status_code}, + ) + mock_client.get_execution_state.side_effect = non_retriable_error + + @durable_execution + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client, next_marker="next-page-marker"), + _make_lambda_context(), + ) + assert result["Status"] == InvocationStatus.FAILED.value + assert result["Error"]["ErrorType"] == "GetExecutionStateError" + + +def test_durable_execution_retriable_initial_pagination_error_raises(): + """Test that retriable error during initial pagination raises to trigger Lambda retry.""" + mock_client = Mock(spec=DurableServiceClient) + retriable_error = GetExecutionStateError( + message="Service error", + error={"Code": "ServiceException", "Message": "Internal error"}, + response_metadata={"HTTPStatusCode": 500}, + ) + mock_client.get_execution_state.side_effect = retriable_error + + @durable_execution + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + with pytest.raises(GetExecutionStateError, match="Service error"): + test_handler( + _make_invocation_input(mock_client, next_marker="next-page-marker"), + _make_lambda_context(), + ) \ No newline at end of file From e66458bd8a29404cdb0a57d362009a0c29637334 Mon Sep 17 00:00:00 2001 From: Denzyl Holder Layne Date: Tue, 28 Apr 2026 12:11:22 -0400 Subject: [PATCH 2/5] Use retryable instead of retriable. --- .../exceptions.py | 26 ++++---- .../execution.py | 20 +++--- tests/exceptions_test.py | 64 +++++++++---------- tests/execution_test.py | 42 ++++++------ 4 files changed, 76 insertions(+), 76 deletions(-) diff --git a/src/aws_durable_execution_sdk_python/exceptions.py b/src/aws_durable_execution_sdk_python/exceptions.py index e88df53..ec5799d 100644 --- a/src/aws_durable_execution_sdk_python/exceptions.py +++ b/src/aws_durable_execution_sdk_python/exceptions.py @@ -16,12 +16,12 @@ INVALID_PARAMETER_VALUE_EXCEPTION: str = "InvalidParameterValueException" INVALID_CHECKPOINT_TOKEN_PREFIX: str = "Invalid Checkpoint Token" -# Non-retriable customer error codes that arrive as non-4xx (e.g. HTTP 502) from Lambda. +# Non-retryable customer error codes that arrive as non-4xx (e.g. HTTP 502) from Lambda. # Unlike typical 5xx errors, these require customer intervention (e.g., fixing # a KMS key configuration) and will never succeed on retry. -# Add new non-retriable error codes here — they are automatically classified -# as EXECUTION (non-retriable) by _classify_error_category(). -_NON_RETRIABLE_CUSTOMER_ERROR_CODES: frozenset[str] = frozenset( +# Add new non-retryable error codes here — they are automatically classified +# as EXECUTION (non-retryable) by _classify_error_category(). +_NON_RETRYABLE_CUSTOMER_ERROR_CODES: frozenset[str] = frozenset( { "KMSAccessDeniedException", "KMSDisabledException", @@ -93,8 +93,8 @@ def __init__( ): super().__init__(message, termination_reason) - def is_retriable(self) -> bool: - """Whether this error is retriable. Returns True by default. + def is_retryable(self) -> bool: + """Whether this error is retryable. Returns True by default. Subclasses override to implement classification logic based on error codes and HTTP status codes. @@ -123,9 +123,9 @@ class BotoClientError(InvocationError): """Error from a Lambda API call (e.g., CheckpointDurableExecution, GetDurableExecutionState). Extends InvocationError because the default behavior for API failures is to retry - the Lambda invocation. However, some errors are non-retriable (e.g., 4xx client errors, + the Lambda invocation. However, some errors are non-retryable (e.g., 4xx client errors, KMS key misconfiguration) and should fail the execution instead. The error_category field - and is_retriable() method distinguish these cases at runtime. + and is_retryable() method distinguish these cases at runtime. """ def __init__( @@ -156,10 +156,10 @@ def _classify_error_category( error: AwsErrorObj | None, response_metadata: AwsErrorMetadata | None, ) -> DurableApiErrorCategory: - """Classify a Durable API error as retriable (INVOCATION) or non-retriable (EXECUTION). + """Classify a Durable API error as retryable (INVOCATION) or non-retryable (EXECUTION). Classification rules: - - Non-retriable customer error codes (e.g., KMS key issues) → EXECUTION + - Non-retryable customer error codes (e.g., KMS key issues) → EXECUTION These arrive as HTTP 502 but require customer intervention to fix. - 4xx errors → EXECUTION, except: - 429 (TooManyRequests) → INVOCATION (throttling is transient) @@ -168,7 +168,7 @@ def _classify_error_category( - 5xx, network errors → INVOCATION """ error_code: str | None = (error and error.get("Code")) or None - if error_code and error_code in _NON_RETRIABLE_CUSTOMER_ERROR_CODES: + if error_code and error_code in _NON_RETRYABLE_CUSTOMER_ERROR_CODES: return DurableApiErrorCategory.EXECUTION status_code: int | None = (response_metadata and response_metadata.get("HTTPStatusCode")) or None @@ -186,8 +186,8 @@ def _classify_error_category( return DurableApiErrorCategory.INVOCATION - def is_retriable(self) -> bool: - """Whether this error is retriable based on error_category.""" + def is_retryable(self) -> bool: + """Whether this error is retryable based on error_category.""" return self.error_category == DurableApiErrorCategory.INVOCATION def build_logger_extras(self) -> dict: diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index e45ade7..8172263 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -267,11 +267,11 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: invocation_input.initial_execution_state.next_marker, ) except BotoClientError as e: - # Non-retriable Durable API errors (e.g., customer configuration issues, + # Non-retryable Durable API errors (e.g., customer configuration issues, # 4xx client errors) will never succeed on retry — fail the execution immediately. - if not e.is_retriable(): + if not e.is_retryable(): logger.exception( - "Non-retriable Durable API error during initial state fetch. Must fail execution " + "Non-retryable Durable API error during initial state fetch. Must fail execution " "without retry.", extra=e.build_logger_extras(), ) @@ -371,11 +371,11 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: "Checkpoint processing failed", extra=bg_error.source_exception.build_logger_extras(), ) - # Non-retriable Durable API errors (e.g., customer configuration issues, + # Non-retryable Durable API errors (e.g., customer configuration issues, # 4xx client errors) will never succeed on retry — fail the execution immediately. - if not bg_error.source_exception.is_retriable(): + if not bg_error.source_exception.is_retryable(): logger.exception( - "Non-retriable Durable API error from background thread. Must fail execution " + "Non-retryable Durable API error from background thread. Must fail execution " "without retry.", extra=bg_error.source_exception.build_logger_extras(), ) @@ -402,11 +402,11 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: ) return handle_checkpoint_error(e).to_dict() except InvocationError as e: - # Non-retriable Durable API errors (e.g., customer configuration issues, + # Non-retryable Durable API errors (e.g., customer configuration issues, # 4xx client errors) will never succeed on retry — fail the execution immediately. - if not e.is_retriable(): + if not e.is_retryable(): logger.exception( - "Non-retriable Durable API error. Must fail execution without retry.", + "Non-retryable Durable API error. Must fail execution without retry.", extra=e.build_logger_extras(), # type: ignore[attr-defined] ) return DurableExecutionInvocationOutput( @@ -463,7 +463,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: def handle_checkpoint_error(error: CheckpointError) -> DurableExecutionInvocationOutput: - if error.is_retriable(): + if error.is_retryable(): raise error from None # Terminate Lambda immediately and have it be retried return DurableExecutionInvocationOutput( status=InvocationStatus.FAILED, error=ErrorObject.from_exception(error) diff --git a/tests/exceptions_test.py b/tests/exceptions_test.py index f0c0b70..5e158d2 100644 --- a/tests/exceptions_test.py +++ b/tests/exceptions_test.py @@ -70,7 +70,7 @@ def test_checkpoint_error_classification_invalid_token_invocation(): result = CheckpointError.from_exception(client_error) assert result.error_category == CheckpointErrorCategory.INVOCATION - assert result.is_retriable() + assert result.is_retryable() def test_checkpoint_error_classification_payload_size_exceeded_execution(): @@ -87,7 +87,7 @@ def test_checkpoint_error_classification_payload_size_exceeded_execution(): result = CheckpointError.from_exception(client_error) assert result.error_category == CheckpointErrorCategory.EXECUTION - assert not result.is_retriable() + assert not result.is_retryable() def test_checkpoint_error_classification_invalid_param_without_token_execution(): @@ -104,7 +104,7 @@ def test_checkpoint_error_classification_invalid_param_without_token_execution() result = CheckpointError.from_exception(client_error) assert result.error_category == CheckpointErrorCategory.EXECUTION - assert not result.is_retriable() + assert not result.is_retryable() # ============================================================================= @@ -123,8 +123,8 @@ def test_checkpoint_error_classification_invalid_param_without_token_execution() "KMSNotFoundException", ], ) -def test_durable_api_error_non_retriable_customer_error_codes(error_cls, error_code: str): - """Test that non-retriable customer error codes (HTTP 502) are classified as EXECUTION.""" +def test_durable_api_error_non_retryable_customer_error_codes(error_cls, error_code: str): + """Test that non-retryable customer error codes (HTTP 502) are classified as EXECUTION.""" error_response = { "Error": {"Code": error_code, "Message": f"{error_code} error"}, "ResponseMetadata": {"HTTPStatusCode": 502}, @@ -132,12 +132,12 @@ def test_durable_api_error_non_retriable_customer_error_codes(error_cls, error_c client_error = ClientError(error_response, "Invoke") result = error_cls.from_exception(client_error) assert result.error_category == DurableApiErrorCategory.EXECUTION - assert not result.is_retriable() + assert not result.is_retryable() @pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) -def test_durable_api_error_4xx_non_retriable(error_cls): - """Test 4xx errors are classified as EXECUTION (non-retriable).""" +def test_durable_api_error_4xx_non_retryable(error_cls): + """Test 4xx errors are classified as EXECUTION (non-retryable).""" error_response = { "Error": {"Code": "ValidationException", "Message": "Invalid parameter"}, "ResponseMetadata": {"HTTPStatusCode": 400}, @@ -145,12 +145,12 @@ def test_durable_api_error_4xx_non_retriable(error_cls): client_error = ClientError(error_response, "Invoke") result = error_cls.from_exception(client_error) assert result.error_category == DurableApiErrorCategory.EXECUTION - assert not result.is_retriable() + assert not result.is_retryable() @pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) -def test_durable_api_error_429_retriable(error_cls): - """Test 429 errors are classified as INVOCATION (retriable).""" +def test_durable_api_error_429_retryable(error_cls): + """Test 429 errors are classified as INVOCATION (retryable).""" error_response = { "Error": {"Code": "TooManyRequestsException", "Message": "Rate limit exceeded"}, "ResponseMetadata": {"HTTPStatusCode": 429}, @@ -158,12 +158,12 @@ def test_durable_api_error_429_retriable(error_cls): client_error = ClientError(error_response, "Invoke") result = error_cls.from_exception(client_error) assert result.error_category == DurableApiErrorCategory.INVOCATION - assert result.is_retriable() + assert result.is_retryable() @pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) -def test_durable_api_error_5xx_retriable(error_cls): - """Test 5xx errors are classified as INVOCATION (retriable).""" +def test_durable_api_error_5xx_retryable(error_cls): + """Test 5xx errors are classified as INVOCATION (retryable).""" error_response = { "Error": {"Code": "InternalServerError", "Message": "Service unavailable"}, "ResponseMetadata": {"HTTPStatusCode": 500}, @@ -171,12 +171,12 @@ def test_durable_api_error_5xx_retriable(error_cls): client_error = ClientError(error_response, "Invoke") result = error_cls.from_exception(client_error) assert result.error_category == DurableApiErrorCategory.INVOCATION - assert result.is_retriable() + assert result.is_retryable() @pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) -def test_durable_api_error_retriable_502(error_cls): - """Test that 502 errors with unrecognized error codes are retriable.""" +def test_durable_api_error_retryable_502(error_cls): + """Test that 502 errors with unrecognized error codes are retryable.""" error_response = { "Error": { "Code": "ServiceException", @@ -187,15 +187,15 @@ def test_durable_api_error_retriable_502(error_cls): client_error = ClientError(error_response, "Invoke") result = error_cls.from_exception(client_error) assert result.error_category == DurableApiErrorCategory.INVOCATION - assert result.is_retriable() + assert result.is_retryable() @pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) -def test_durable_api_error_unknown_retriable(error_cls): - """Test unknown errors (no HTTP response) are classified as INVOCATION (retriable).""" +def test_durable_api_error_unknown_retryable(error_cls): + """Test unknown errors (no HTTP response) are classified as INVOCATION (retryable).""" result = error_cls.from_exception(Exception("Network timeout")) assert result.error_category == DurableApiErrorCategory.INVOCATION - assert result.is_retriable() + assert result.is_retryable() def test_validation_error(): @@ -436,28 +436,28 @@ def test_orphaned_child_exception_with_operation_id(): @pytest.mark.parametrize( - ("error_code", "status_code", "expected_retriable"), + ("error_code", "status_code", "expected_retryable"), [ ("KMSAccessDeniedException", 502, False), ("ServiceException", 500, True), ("ServiceException", 502, True), ], ) -def test_boto_client_error_is_retriable(error_code: str, status_code: int, expected_retriable: bool): - """Test BotoClientError.is_retriable() classification.""" +def test_boto_client_error_is_retryable(error_code: str, status_code: int, expected_retryable: bool): + """Test BotoClientError.is_retryable() classification.""" error_response = { "Error": {"Code": error_code, "Message": "test error"}, "ResponseMetadata": {"HTTPStatusCode": status_code}, } client_error = ClientError(error_response, "Invoke") result = BotoClientError.from_exception(client_error) - assert result.is_retriable() == expected_retriable + assert result.is_retryable() == expected_retryable -def test_boto_client_error_is_retriable_no_error(): - """Test BotoClientError.is_retriable() returns True with no error info.""" +def test_boto_client_error_is_retryable_no_error(): + """Test BotoClientError.is_retryable() returns True with no error info.""" result = BotoClientError.from_exception(Exception("network error")) - assert result.is_retriable() + assert result.is_retryable() # ============================================================================= @@ -473,11 +473,11 @@ def test_durable_api_error_category_backward_compatible_alias(): # ============================================================================= -# is_retriable() tests +# is_retryable() tests # ============================================================================= -def test_invocation_error_is_retriable_default(): - """Test InvocationError.is_retriable() returns True by default.""" +def test_invocation_error_is_retryable_default(): + """Test InvocationError.is_retryable() returns True by default.""" error = InvocationError("some error") - assert error.is_retriable() \ No newline at end of file + assert error.is_retryable() \ No newline at end of file diff --git a/tests/execution_test.py b/tests/execution_test.py index 605b467..ecdac4c 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -2652,7 +2652,7 @@ def test_from_dict_leaves_timestamps_as_integers(): # ============================================================================= -# Non-retriable Durable API error handling tests +# Non-retryable Durable API error handling tests # ============================================================================= @@ -2684,10 +2684,10 @@ def _make_lambda_context(): return ctx -def test_durable_execution_non_retriable_invocation_error_returns_failed(): - """Test that non-retriable InvocationError returns FAILED instead of retrying.""" +def test_durable_execution_non_retryable_invocation_error_returns_failed(): + """Test that non-retryable InvocationError returns FAILED instead of retrying.""" mock_client = Mock(spec=DurableServiceClient) - non_retriable_error = GetExecutionStateError( + non_retryable_error = GetExecutionStateError( message="KMS access denied", error_category=DurableApiErrorCategory.EXECUTION, error={"Code": "KMSAccessDeniedException", "Message": "KMS access denied"}, @@ -2696,17 +2696,17 @@ def test_durable_execution_non_retriable_invocation_error_returns_failed(): @durable_execution def test_handler(event: Any, context: DurableContext) -> dict: - raise non_retriable_error + raise non_retryable_error result = test_handler(_make_invocation_input(mock_client), _make_lambda_context()) assert result["Status"] == InvocationStatus.FAILED.value assert result["Error"]["ErrorType"] == "GetExecutionStateError" -def test_durable_execution_retriable_invocation_error_raises(): - """Test that retriable InvocationError raises to trigger Lambda retry.""" +def test_durable_execution_retryable_invocation_error_raises(): + """Test that retryable InvocationError raises to trigger Lambda retry.""" mock_client = Mock(spec=DurableServiceClient) - retriable_error = GetExecutionStateError( + retryable_error = GetExecutionStateError( message="Service error", error={"Code": "ServiceException", "Message": "Internal error"}, response_metadata={"HTTPStatusCode": 500}, @@ -2714,22 +2714,22 @@ def test_durable_execution_retriable_invocation_error_raises(): @durable_execution def test_handler(event: Any, context: DurableContext) -> dict: - raise retriable_error + raise retryable_error with pytest.raises(GetExecutionStateError, match="Service error"): test_handler(_make_invocation_input(mock_client), _make_lambda_context()) -def test_durable_execution_non_retriable_background_thread_error_returns_failed(): - """Test that non-retriable error from background thread returns FAILED.""" +def test_durable_execution_non_retryable_background_thread_error_returns_failed(): + """Test that non-retryable error from background thread returns FAILED.""" mock_client = Mock(spec=DurableServiceClient) - non_retriable_error = GetExecutionStateError( + non_retryable_error = GetExecutionStateError( message="KMS key disabled", error_category=DurableApiErrorCategory.EXECUTION, error={"Code": "KMSDisabledException", "Message": "KMS key disabled"}, response_metadata={"HTTPStatusCode": 502}, ) - mock_client.checkpoint.side_effect = lambda *a, **kw: (_ for _ in ()).throw(non_retriable_error) + mock_client.checkpoint.side_effect = lambda *a, **kw: (_ for _ in ()).throw(non_retryable_error) @durable_execution def test_handler(event: Any, context: DurableContext) -> dict: @@ -2748,18 +2748,18 @@ def test_handler(event: Any, context: DurableContext) -> dict: ("ValidationException", 400, DurableApiErrorCategory.EXECUTION), ], ) -def test_durable_execution_non_retriable_initial_pagination_error_returns_failed( +def test_durable_execution_non_retryable_initial_pagination_error_returns_failed( error_code: str, status_code: int, error_category: DurableApiErrorCategory ): - """Test that non-retriable errors during initial pagination return FAILED.""" + """Test that non-retryable errors during initial pagination return FAILED.""" mock_client = Mock(spec=DurableServiceClient) - non_retriable_error = GetExecutionStateError( + non_retryable_error = GetExecutionStateError( message=f"{error_code} error", error_category=error_category, error={"Code": error_code, "Message": f"{error_code} error"}, response_metadata={"HTTPStatusCode": status_code}, ) - mock_client.get_execution_state.side_effect = non_retriable_error + mock_client.get_execution_state.side_effect = non_retryable_error @durable_execution def test_handler(event: Any, context: DurableContext) -> dict: @@ -2773,15 +2773,15 @@ def test_handler(event: Any, context: DurableContext) -> dict: assert result["Error"]["ErrorType"] == "GetExecutionStateError" -def test_durable_execution_retriable_initial_pagination_error_raises(): - """Test that retriable error during initial pagination raises to trigger Lambda retry.""" +def test_durable_execution_retryable_initial_pagination_error_raises(): + """Test that retryable error during initial pagination raises to trigger Lambda retry.""" mock_client = Mock(spec=DurableServiceClient) - retriable_error = GetExecutionStateError( + retryable_error = GetExecutionStateError( message="Service error", error={"Code": "ServiceException", "Message": "Internal error"}, response_metadata={"HTTPStatusCode": 500}, ) - mock_client.get_execution_state.side_effect = retriable_error + mock_client.get_execution_state.side_effect = retryable_error @durable_execution def test_handler(event: Any, context: DurableContext) -> dict: From e88faa9fda63d1fda010ebdf93e053a4e6151f2e Mon Sep 17 00:00:00 2001 From: Denzyl Holder Layne Date: Tue, 28 Apr 2026 12:17:21 -0400 Subject: [PATCH 3/5] Add backward-compatible is_retriable alias. --- src/aws_durable_execution_sdk_python/exceptions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/aws_durable_execution_sdk_python/exceptions.py b/src/aws_durable_execution_sdk_python/exceptions.py index ec5799d..cb340e7 100644 --- a/src/aws_durable_execution_sdk_python/exceptions.py +++ b/src/aws_durable_execution_sdk_python/exceptions.py @@ -190,6 +190,9 @@ def is_retryable(self) -> bool: """Whether this error is retryable based on error_category.""" return self.error_category == DurableApiErrorCategory.INVOCATION + # Backward-compatible alias + is_retriable = is_retryable + def build_logger_extras(self) -> dict: extras: dict = {} # preserve PascalCase to be consistent with other langauges From b76d2139eb0eed203a5f097ef1566d47a1a08f37 Mon Sep 17 00:00:00 2001 From: Denzyl Holder Layne Date: Tue, 28 Apr 2026 12:27:47 -0400 Subject: [PATCH 4/5] Fix formatting. --- .../exceptions.py | 17 +++++++++++++---- tests/exceptions_test.py | 14 +++++++++----- tests/execution_test.py | 12 ++++++++---- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/aws_durable_execution_sdk_python/exceptions.py b/src/aws_durable_execution_sdk_python/exceptions.py index cb340e7..a3fe319 100644 --- a/src/aws_durable_execution_sdk_python/exceptions.py +++ b/src/aws_durable_execution_sdk_python/exceptions.py @@ -146,9 +146,14 @@ def from_exception(cls, exception: Exception) -> Self: response = getattr(exception, "response", {}) response_metadata = response.get("ResponseMetadata") error = response.get("Error") - error_category = BotoClientError._classify_error_category(error, response_metadata) + error_category = BotoClientError._classify_error_category( + error, response_metadata + ) return cls( - message=str(exception), error_category=error_category, error=error, response_metadata=response_metadata + message=str(exception), + error_category=error_category, + error=error, + response_metadata=response_metadata, ) @staticmethod @@ -171,7 +176,9 @@ def _classify_error_category( if error_code and error_code in _NON_RETRYABLE_CUSTOMER_ERROR_CODES: return DurableApiErrorCategory.EXECUTION - status_code: int | None = (response_metadata and response_metadata.get("HTTPStatusCode")) or None + status_code: int | None = ( + response_metadata and response_metadata.get("HTTPStatusCode") + ) or None if ( status_code and BAD_REQUEST_ERROR <= status_code < SERVICE_ERROR @@ -179,7 +186,9 @@ def _classify_error_category( and error and not ( (error.get("Code") or "") == INVALID_PARAMETER_VALUE_EXCEPTION - and (error.get("Message") or "").startswith(INVALID_CHECKPOINT_TOKEN_PREFIX) + and (error.get("Message") or "").startswith( + INVALID_CHECKPOINT_TOKEN_PREFIX + ) ) ): return DurableApiErrorCategory.EXECUTION diff --git a/tests/exceptions_test.py b/tests/exceptions_test.py index 5e158d2..a18d534 100644 --- a/tests/exceptions_test.py +++ b/tests/exceptions_test.py @@ -123,13 +123,15 @@ def test_checkpoint_error_classification_invalid_param_without_token_execution() "KMSNotFoundException", ], ) -def test_durable_api_error_non_retryable_customer_error_codes(error_cls, error_code: str): +def test_durable_api_error_non_retryable_customer_error_codes( + error_cls, error_code: str +): """Test that non-retryable customer error codes (HTTP 502) are classified as EXECUTION.""" error_response = { "Error": {"Code": error_code, "Message": f"{error_code} error"}, "ResponseMetadata": {"HTTPStatusCode": 502}, } - client_error = ClientError(error_response, "Invoke") + client_error = ClientError(error_response, "Invoke") # type: ignore[arg-type] result = error_cls.from_exception(client_error) assert result.error_category == DurableApiErrorCategory.EXECUTION assert not result.is_retryable() @@ -443,13 +445,15 @@ def test_orphaned_child_exception_with_operation_id(): ("ServiceException", 502, True), ], ) -def test_boto_client_error_is_retryable(error_code: str, status_code: int, expected_retryable: bool): +def test_boto_client_error_is_retryable( + error_code: str, status_code: int, expected_retryable: bool +): """Test BotoClientError.is_retryable() classification.""" error_response = { "Error": {"Code": error_code, "Message": "test error"}, "ResponseMetadata": {"HTTPStatusCode": status_code}, } - client_error = ClientError(error_response, "Invoke") + client_error = ClientError(error_response, "Invoke") # type: ignore[arg-type] result = BotoClientError.from_exception(client_error) assert result.is_retryable() == expected_retryable @@ -480,4 +484,4 @@ def test_durable_api_error_category_backward_compatible_alias(): def test_invocation_error_is_retryable_default(): """Test InvocationError.is_retryable() returns True by default.""" error = InvocationError("some error") - assert error.is_retryable() \ No newline at end of file + assert error.is_retryable() diff --git a/tests/execution_test.py b/tests/execution_test.py index ecdac4c..e247d01 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -2667,7 +2667,9 @@ def _make_invocation_input(mock_client, next_marker=""): return DurableExecutionInvocationInputWithClient( durable_execution_arn="arn:test:execution", checkpoint_token="token123", # noqa: S106 - initial_execution_state=InitialExecutionState(operations=[operation], next_marker=next_marker), + initial_execution_state=InitialExecutionState( + operations=[operation], next_marker=next_marker + ), service_client=mock_client, ) @@ -2729,7 +2731,9 @@ def test_durable_execution_non_retryable_background_thread_error_returns_failed( error={"Code": "KMSDisabledException", "Message": "KMS key disabled"}, response_metadata={"HTTPStatusCode": 502}, ) - mock_client.checkpoint.side_effect = lambda *a, **kw: (_ for _ in ()).throw(non_retryable_error) + mock_client.checkpoint.side_effect = lambda *a, **kw: (_ for _ in ()).throw( + non_retryable_error + ) @durable_execution def test_handler(event: Any, context: DurableContext) -> dict: @@ -2757,7 +2761,7 @@ def test_durable_execution_non_retryable_initial_pagination_error_returns_failed message=f"{error_code} error", error_category=error_category, error={"Code": error_code, "Message": f"{error_code} error"}, - response_metadata={"HTTPStatusCode": status_code}, + response_metadata={"HTTPStatusCode": status_code}, # type: ignore[arg-type] ) mock_client.get_execution_state.side_effect = non_retryable_error @@ -2791,4 +2795,4 @@ def test_handler(event: Any, context: DurableContext) -> dict: test_handler( _make_invocation_input(mock_client, next_marker="next-page-marker"), _make_lambda_context(), - ) \ No newline at end of file + ) From ad94b0880e5e7d561df0215d69300dae14fb0082 Mon Sep 17 00:00:00 2001 From: Denzyl Holder Layne Date: Wed, 29 Apr 2026 11:46:02 -0400 Subject: [PATCH 5/5] Log Durable API errors during state fetch, and save partially fetched state. --- src/aws_durable_execution_sdk_python/state.py | 37 ++++++-- tests/state_test.py | 84 +++++++++++++++++++ 2 files changed, 112 insertions(+), 9 deletions(-) diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 48076c9..0d9cb0f 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -16,6 +16,7 @@ BackgroundThreadError, CallableRuntimeError, DurableExecutionsError, + GetExecutionStateError, OrphanedChildException, ) from aws_durable_execution_sdk_python.lambda_service import ( @@ -275,20 +276,38 @@ def fetch_paginated_operations( initial_operations: initial operations to be added to ExecutionState checkpoint_token: checkpoint token used to call Durable Functions API. next_marker: a marker indicates that there are paginated operations. + + Raises: + GetExecutionStateError: If the API call fails. The error is logged + with structured extras before re-raising. Callers are responsible + for deciding whether to fail the execution or allow Lambda retry + based on is_retryable(). """ all_operations: list[Operation] = ( initial_operations.copy() if initial_operations else [] ) - while next_marker: - output: StateOutput = self._service_client.get_execution_state( - durable_execution_arn=self.durable_execution_arn, - checkpoint_token=checkpoint_token, - next_marker=next_marker, + try: + while next_marker: + output: StateOutput = self._service_client.get_execution_state( + durable_execution_arn=self.durable_execution_arn, + checkpoint_token=checkpoint_token, + next_marker=next_marker, + ) + all_operations.extend(output.operations) + next_marker = output.next_marker + except GetExecutionStateError as e: + logger.exception( + "Durable API error during state fetch.", + extra=e.build_logger_extras(), ) - all_operations.extend(output.operations) - next_marker = output.next_marker - with self._operations_lock: - self.operations.update({op.operation_id: op for op in all_operations}) + raise + finally: + # Always store whatever operations we successfully fetched + if all_operations: + with self._operations_lock: + self.operations.update( + {op.operation_id: op for op in all_operations} + ) def get_input_payload(self) -> str | None: # It is possible that backend will not provide an execution operation diff --git a/tests/state_test.py b/tests/state_test.py index 19d641e..2ea0ab3 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -16,6 +16,8 @@ from aws_durable_execution_sdk_python.exceptions import ( BackgroundThreadError, CallableRuntimeError, + DurableApiErrorCategory, + GetExecutionStateError, OrphanedChildException, ) from aws_durable_execution_sdk_python.identifier import OperationIdentifier @@ -724,6 +726,88 @@ def mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marke assert operation.operation_id == expected_op.operation_id +def test_fetch_paginated_operations_stores_partial_results_on_error(): + """Test that operations from successful pages are stored even when a later page fails.""" + mock_lambda_client = Mock(spec=LambdaClient) + + non_retryable_error = GetExecutionStateError( + message="KMS access denied", + error_category=DurableApiErrorCategory.EXECUTION, + error={"Code": "KMSAccessDeniedException", "Message": "KMS access denied"}, + response_metadata={"HTTPStatusCode": 502}, + ) + + def mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marker): + if next_marker == "marker1": + return StateOutput( + operations=[ + Operation( + operation_id="1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + ], + next_marker="marker2", + ) + raise non_retryable_error + + mock_lambda_client.get_execution_state.side_effect = mock_get_execution_state + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_lambda_client, + ) + + with pytest.raises(GetExecutionStateError): + state.fetch_paginated_operations( + initial_operations=[ + Operation( + operation_id="0", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + ], + checkpoint_token="test_token", # noqa: S106 + next_marker="marker1", + ) + + # Initial operation + page 1 should be stored despite page 2 failing + assert "0" in state.operations + assert "1" in state.operations + assert len(state.operations) == 2 + + +def test_fetch_paginated_operations_logs_error(caplog): + """Test that GetExecutionStateError is logged with structured extras.""" + mock_lambda_client = Mock(spec=LambdaClient) + + error = GetExecutionStateError( + message="Service error", + error_category=DurableApiErrorCategory.INVOCATION, + error={"Code": "ServiceException", "Message": "Service error"}, + response_metadata={"HTTPStatusCode": 500}, + ) + mock_lambda_client.get_execution_state.side_effect = error + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_lambda_client, + ) + + with pytest.raises(GetExecutionStateError): + state.fetch_paginated_operations( + initial_operations=[], + checkpoint_token="test_token", # noqa: S106 + next_marker="marker1", + ) + + assert "Durable API error during state fetch." in caplog.text + + # ============================================================================ # Checkpoint Batching Tests # ============================================================================