diff --git a/src/aws_durable_execution_sdk_python/exceptions.py b/src/aws_durable_execution_sdk_python/exceptions.py index 336996c..a3fe319 100644 --- a/src/aws_durable_execution_sdk_python/exceptions.py +++ b/src/aws_durable_execution_sdk_python/exceptions.py @@ -13,6 +13,22 @@ BAD_REQUEST_ERROR: int = 400 TOO_MANY_REQUESTS_ERROR: int = 429 SERVICE_ERROR: int = 500 +INVALID_PARAMETER_VALUE_EXCEPTION: str = "InvalidParameterValueException" +INVALID_CHECKPOINT_TOKEN_PREFIX: str = "Invalid Checkpoint Token" + +# Non-retryable customer error codes that arrive as non-4xx (e.g. HTTP 502) from Lambda. +# Unlike typical 5xx errors, these require customer intervention (e.g., fixing +# a KMS key configuration) and will never succeed on retry. +# Add new non-retryable error codes here — they are automatically classified +# as EXECUTION (non-retryable) by _classify_error_category(). +_NON_RETRYABLE_CUSTOMER_ERROR_CODES: frozenset[str] = frozenset( + { + "KMSAccessDeniedException", + "KMSDisabledException", + "KMSInvalidStateException", + "KMSNotFoundException", + } +) if TYPE_CHECKING: import datetime @@ -77,6 +93,14 @@ def __init__( ): super().__init__(message, termination_reason) + def is_retryable(self) -> bool: + """Whether this error is retryable. Returns True by default. + + Subclasses override to implement classification logic based on + error codes and HTTP status codes. + """ + return True + class CallbackError(ExecutionError): """Error in callback handling.""" @@ -86,10 +110,28 @@ def __init__(self, message: str, callback_id: str | None = None): self.callback_id = callback_id +class DurableApiErrorCategory(Enum): + INVOCATION = "INVOCATION" + EXECUTION = "EXECUTION" + + +# Backward-compatible alias +CheckpointErrorCategory = DurableApiErrorCategory + + class BotoClientError(InvocationError): + """Error from a Lambda API call (e.g., CheckpointDurableExecution, GetDurableExecutionState). + + Extends InvocationError because the default behavior for API failures is to retry + the Lambda invocation. However, some errors are non-retryable (e.g., 4xx client errors, + KMS key misconfiguration) and should fail the execution instead. The error_category field + and is_retryable() method distinguish these cases at runtime. + """ + def __init__( self, message: str, + error_category: DurableApiErrorCategory = DurableApiErrorCategory.INVOCATION, error: AwsErrorObj | None = None, response_metadata: AwsErrorMetadata | None = None, termination_reason=TerminationReason.INVOCATION_ERROR, @@ -97,16 +139,69 @@ def __init__( super().__init__(message=message, termination_reason=termination_reason) self.error: AwsErrorObj | None = error self.response_metadata: AwsErrorMetadata | None = response_metadata + self.error_category: DurableApiErrorCategory = error_category @classmethod def from_exception(cls, exception: Exception) -> Self: response = getattr(exception, "response", {}) response_metadata = response.get("ResponseMetadata") error = response.get("Error") + error_category = BotoClientError._classify_error_category( + error, response_metadata + ) return cls( - message=str(exception), error=error, response_metadata=response_metadata + message=str(exception), + error_category=error_category, + error=error, + response_metadata=response_metadata, ) + @staticmethod + def _classify_error_category( + error: AwsErrorObj | None, + response_metadata: AwsErrorMetadata | None, + ) -> DurableApiErrorCategory: + """Classify a Durable API error as retryable (INVOCATION) or non-retryable (EXECUTION). + + Classification rules: + - Non-retryable customer error codes (e.g., KMS key issues) → EXECUTION + These arrive as HTTP 502 but require customer intervention to fix. + - 4xx errors → EXECUTION, except: + - 429 (TooManyRequests) → INVOCATION (throttling is transient) + - InvalidParameterValueException with "Invalid Checkpoint Token" → INVOCATION + (stale token from a concurrent checkpoint; next invocation gets a fresh token) + - 5xx, network errors → INVOCATION + """ + error_code: str | None = (error and error.get("Code")) or None + if error_code and error_code in _NON_RETRYABLE_CUSTOMER_ERROR_CODES: + return DurableApiErrorCategory.EXECUTION + + status_code: int | None = ( + response_metadata and response_metadata.get("HTTPStatusCode") + ) or None + if ( + status_code + and BAD_REQUEST_ERROR <= status_code < SERVICE_ERROR + and status_code != TOO_MANY_REQUESTS_ERROR + and error + and not ( + (error.get("Code") or "") == INVALID_PARAMETER_VALUE_EXCEPTION + and (error.get("Message") or "").startswith( + INVALID_CHECKPOINT_TOKEN_PREFIX + ) + ) + ): + return DurableApiErrorCategory.EXECUTION + + return DurableApiErrorCategory.INVOCATION + + def is_retryable(self) -> bool: + """Whether this error is retryable based on error_category.""" + return self.error_category == DurableApiErrorCategory.INVOCATION + + # Backward-compatible alias + is_retriable = is_retryable + def build_logger_extras(self) -> dict: extras: dict = {} # preserve PascalCase to be consistent with other langauges @@ -125,55 +220,23 @@ def __init__(self, message: str, step_id: str | None = None): self.step_id = step_id -class CheckpointErrorCategory(Enum): - INVOCATION = "INVOCATION" - EXECUTION = "EXECUTION" - - class CheckpointError(BotoClientError): """Failure to checkpoint. Will terminate the lambda.""" def __init__( self, message: str, - error_category: CheckpointErrorCategory, + error_category: DurableApiErrorCategory = DurableApiErrorCategory.INVOCATION, error: AwsErrorObj | None = None, response_metadata: AwsErrorMetadata | None = None, ): super().__init__( message, + error_category, error, response_metadata, termination_reason=TerminationReason.CHECKPOINT_FAILED, ) - self.error_category: CheckpointErrorCategory = error_category - - @classmethod - def from_exception(cls, exception: Exception) -> CheckpointError: - base = BotoClientError.from_exception(exception) - metadata: AwsErrorMetadata | None = base.response_metadata - error: AwsErrorObj | None = base.error - error_category: CheckpointErrorCategory = CheckpointErrorCategory.INVOCATION - - # 4xx errors (except 429) are permanent failures (EXECUTION), unless it's an - # InvalidParameterValueException with "Invalid Checkpoint Token" which is retriable (INVOCATION). - # 5xx, 429, and network errors are retriable (INVOCATION). - status_code: int | None = (metadata and metadata.get("HTTPStatusCode")) or None - if ( - status_code - and BAD_REQUEST_ERROR <= status_code < SERVICE_ERROR - and status_code != TOO_MANY_REQUESTS_ERROR - and error - and not ( - (error.get("Code") or "") == "InvalidParameterValueException" - and (error.get("Message") or "").startswith("Invalid Checkpoint Token") - ) - ): - error_category = CheckpointErrorCategory.EXECUTION - return CheckpointError(str(exception), error_category, error, metadata) - - def is_retriable(self): - return self.error_category == CheckpointErrorCategory.INVOCATION class ValidationError(DurableExecutionsError): @@ -186,11 +249,13 @@ class GetExecutionStateError(BotoClientError): def __init__( self, message: str, + error_category: DurableApiErrorCategory = DurableApiErrorCategory.INVOCATION, error: AwsErrorObj | None = None, response_metadata: AwsErrorMetadata | None = None, ): super().__init__( message, + error_category, error, response_metadata, termination_reason=TerminationReason.INVOCATION_ERROR, diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index 403111d..8172263 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -260,11 +260,26 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: else ReplayStatus.NEW, ) - execution_state.fetch_paginated_operations( - invocation_input.initial_execution_state.operations, - invocation_input.checkpoint_token, - invocation_input.initial_execution_state.next_marker, - ) + try: + execution_state.fetch_paginated_operations( + invocation_input.initial_execution_state.operations, + invocation_input.checkpoint_token, + invocation_input.initial_execution_state.next_marker, + ) + except BotoClientError as e: + # Non-retryable Durable API errors (e.g., customer configuration issues, + # 4xx client errors) will never succeed on retry — fail the execution immediately. + if not e.is_retryable(): + logger.exception( + "Non-retryable Durable API error during initial state fetch. Must fail execution " + "without retry.", + extra=e.build_logger_extras(), + ) + return DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + error=ErrorObject.from_exception(e), + ).to_dict() + raise raw_input_payload: str | None = execution_state.get_input_payload() @@ -356,11 +371,20 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: "Checkpoint processing failed", extra=bg_error.source_exception.build_logger_extras(), ) + # Non-retryable Durable API errors (e.g., customer configuration issues, + # 4xx client errors) will never succeed on retry — fail the execution immediately. + if not bg_error.source_exception.is_retryable(): + logger.exception( + "Non-retryable Durable API error from background thread. Must fail execution " + "without retry.", + extra=bg_error.source_exception.build_logger_extras(), + ) + return DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + error=ErrorObject.from_exception(bg_error.source_exception), + ).to_dict() else: logger.exception("Checkpoint processing failed") - # handle the original exception - if isinstance(bg_error.source_exception, CheckpointError): - return handle_checkpoint_error(bg_error.source_exception).to_dict() raise bg_error.source_exception from bg_error except SuspendExecution: @@ -377,12 +401,23 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: extra=e.build_logger_extras(), ) return handle_checkpoint_error(e).to_dict() - except InvocationError: + except InvocationError as e: + # Non-retryable Durable API errors (e.g., customer configuration issues, + # 4xx client errors) will never succeed on retry — fail the execution immediately. + if not e.is_retryable(): + logger.exception( + "Non-retryable Durable API error. Must fail execution without retry.", + extra=e.build_logger_extras(), # type: ignore[attr-defined] + ) + return DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + error=ErrorObject.from_exception(e), + ).to_dict() logger.exception("Invocation error. Must terminate.") # Throw the error to trigger Lambda retry raise except ExecutionError as e: - logger.exception("Execution error. Must terminate without retry.") + logger.exception("Execution error. Must fail execution without retry.") return DurableExecutionInvocationOutput( status=InvocationStatus.FAILED, error=ErrorObject.from_exception(e), @@ -428,7 +463,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: def handle_checkpoint_error(error: CheckpointError) -> DurableExecutionInvocationOutput: - if error.is_retriable(): + if error.is_retryable(): raise error from None # Terminate Lambda immediately and have it be retried return DurableExecutionInvocationOutput( status=InvocationStatus.FAILED, error=ErrorObject.from_exception(error) diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 48076c9..0d9cb0f 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -16,6 +16,7 @@ BackgroundThreadError, CallableRuntimeError, DurableExecutionsError, + GetExecutionStateError, OrphanedChildException, ) from aws_durable_execution_sdk_python.lambda_service import ( @@ -275,20 +276,38 @@ def fetch_paginated_operations( initial_operations: initial operations to be added to ExecutionState checkpoint_token: checkpoint token used to call Durable Functions API. next_marker: a marker indicates that there are paginated operations. + + Raises: + GetExecutionStateError: If the API call fails. The error is logged + with structured extras before re-raising. Callers are responsible + for deciding whether to fail the execution or allow Lambda retry + based on is_retryable(). """ all_operations: list[Operation] = ( initial_operations.copy() if initial_operations else [] ) - while next_marker: - output: StateOutput = self._service_client.get_execution_state( - durable_execution_arn=self.durable_execution_arn, - checkpoint_token=checkpoint_token, - next_marker=next_marker, + try: + while next_marker: + output: StateOutput = self._service_client.get_execution_state( + durable_execution_arn=self.durable_execution_arn, + checkpoint_token=checkpoint_token, + next_marker=next_marker, + ) + all_operations.extend(output.operations) + next_marker = output.next_marker + except GetExecutionStateError as e: + logger.exception( + "Durable API error during state fetch.", + extra=e.build_logger_extras(), ) - all_operations.extend(output.operations) - next_marker = output.next_marker - with self._operations_lock: - self.operations.update({op.operation_id: op for op in all_operations}) + raise + finally: + # Always store whatever operations we successfully fetched + if all_operations: + with self._operations_lock: + self.operations.update( + {op.operation_id: op for op in all_operations} + ) def get_input_payload(self) -> str | None: # It is possible that backend will not provide an execution operation diff --git a/tests/exceptions_test.py b/tests/exceptions_test.py index f350b42..a18d534 100644 --- a/tests/exceptions_test.py +++ b/tests/exceptions_test.py @@ -7,12 +7,15 @@ from botocore.exceptions import ClientError # type: ignore[import-untyped] from aws_durable_execution_sdk_python.exceptions import ( + BotoClientError, CallableRuntimeError, CallableRuntimeErrorSerializableDetails, CheckpointError, CheckpointErrorCategory, + DurableApiErrorCategory, DurableExecutionsError, ExecutionError, + GetExecutionStateError, InvocationError, OrderedLockError, OrphanedChildException, @@ -67,7 +70,7 @@ def test_checkpoint_error_classification_invalid_token_invocation(): result = CheckpointError.from_exception(client_error) assert result.error_category == CheckpointErrorCategory.INVOCATION - assert result.is_retriable() + assert result.is_retryable() def test_checkpoint_error_classification_payload_size_exceeded_execution(): @@ -84,13 +87,16 @@ def test_checkpoint_error_classification_payload_size_exceeded_execution(): result = CheckpointError.from_exception(client_error) assert result.error_category == CheckpointErrorCategory.EXECUTION - assert not result.is_retriable() + assert not result.is_retryable() -def test_checkpoint_error_classification_other_4xx_execution(): - """Test other 4xx errors are execution errors.""" +def test_checkpoint_error_classification_invalid_param_without_token_execution(): + """Test 4xx InvalidParameterValueException without Invalid Checkpoint Token is execution error.""" error_response = { - "Error": {"Code": "ValidationException", "Message": "Invalid parameter value"}, + "Error": { + "Code": "InvalidParameterValueException", + "Message": "Some other invalid parameter", + }, "ResponseMetadata": {"HTTPStatusCode": 400}, } client_error = ClientError(error_response, "Checkpoint") @@ -98,62 +104,100 @@ def test_checkpoint_error_classification_other_4xx_execution(): result = CheckpointError.from_exception(client_error) assert result.error_category == CheckpointErrorCategory.EXECUTION - assert not result.is_retriable() + assert not result.is_retryable() -def test_checkpoint_error_classification_429_invocation(): - """Test 429 errors are invocation errors (retryable).""" - error_response = { - "Error": {"Code": "TooManyRequestsException", "Message": "Rate limit exceeded"}, - "ResponseMetadata": {"HTTPStatusCode": 429}, - } - client_error = ClientError(error_response, "Checkpoint") +# ============================================================================= +# Shared Durable API error classification tests (BotoClientError._classify_error_category) +# These test the shared classification logic through each BotoClientError subclass. +# ============================================================================= - result = CheckpointError.from_exception(client_error) - assert result.error_category == CheckpointErrorCategory.INVOCATION - assert result.is_retriable() +@pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) +@pytest.mark.parametrize( + "error_code", + [ + "KMSAccessDeniedException", + "KMSDisabledException", + "KMSInvalidStateException", + "KMSNotFoundException", + ], +) +def test_durable_api_error_non_retryable_customer_error_codes( + error_cls, error_code: str +): + """Test that non-retryable customer error codes (HTTP 502) are classified as EXECUTION.""" + error_response = { + "Error": {"Code": error_code, "Message": f"{error_code} error"}, + "ResponseMetadata": {"HTTPStatusCode": 502}, + } + client_error = ClientError(error_response, "Invoke") # type: ignore[arg-type] + result = error_cls.from_exception(client_error) + assert result.error_category == DurableApiErrorCategory.EXECUTION + assert not result.is_retryable() -def test_checkpoint_error_classification_invalid_param_without_token_execution(): - """Test 4xx InvalidParameterValueException without Invalid Checkpoint Token is execution error.""" +@pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) +def test_durable_api_error_4xx_non_retryable(error_cls): + """Test 4xx errors are classified as EXECUTION (non-retryable).""" error_response = { - "Error": { - "Code": "InvalidParameterValueException", - "Message": "Some other invalid parameter", - }, + "Error": {"Code": "ValidationException", "Message": "Invalid parameter"}, "ResponseMetadata": {"HTTPStatusCode": 400}, } - client_error = ClientError(error_response, "Checkpoint") + client_error = ClientError(error_response, "Invoke") + result = error_cls.from_exception(client_error) + assert result.error_category == DurableApiErrorCategory.EXECUTION + assert not result.is_retryable() - result = CheckpointError.from_exception(client_error) - assert result.error_category == CheckpointErrorCategory.EXECUTION - assert not result.is_retriable() +@pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) +def test_durable_api_error_429_retryable(error_cls): + """Test 429 errors are classified as INVOCATION (retryable).""" + error_response = { + "Error": {"Code": "TooManyRequestsException", "Message": "Rate limit exceeded"}, + "ResponseMetadata": {"HTTPStatusCode": 429}, + } + client_error = ClientError(error_response, "Invoke") + result = error_cls.from_exception(client_error) + assert result.error_category == DurableApiErrorCategory.INVOCATION + assert result.is_retryable() -def test_checkpoint_error_classification_5xx_invocation(): - """Test 5xx errors are invocation errors.""" +@pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) +def test_durable_api_error_5xx_retryable(error_cls): + """Test 5xx errors are classified as INVOCATION (retryable).""" error_response = { "Error": {"Code": "InternalServerError", "Message": "Service unavailable"}, "ResponseMetadata": {"HTTPStatusCode": 500}, } - client_error = ClientError(error_response, "Checkpoint") - - result = CheckpointError.from_exception(client_error) - - assert result.error_category == CheckpointErrorCategory.INVOCATION - assert result.is_retriable() + client_error = ClientError(error_response, "Invoke") + result = error_cls.from_exception(client_error) + assert result.error_category == DurableApiErrorCategory.INVOCATION + assert result.is_retryable() -def test_checkpoint_error_classification_unknown_invocation(): - """Test unknown errors are invocation errors.""" - unknown_error = Exception("Network timeout") +@pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) +def test_durable_api_error_retryable_502(error_cls): + """Test that 502 errors with unrecognized error codes are retryable.""" + error_response = { + "Error": { + "Code": "ServiceException", + "Message": "Service encountered an internal error.", + }, + "ResponseMetadata": {"HTTPStatusCode": 502}, + } + client_error = ClientError(error_response, "Invoke") + result = error_cls.from_exception(client_error) + assert result.error_category == DurableApiErrorCategory.INVOCATION + assert result.is_retryable() - result = CheckpointError.from_exception(unknown_error) - assert result.error_category == CheckpointErrorCategory.INVOCATION - assert result.is_retriable() +@pytest.mark.parametrize("error_cls", [CheckpointError, GetExecutionStateError]) +def test_durable_api_error_unknown_retryable(error_cls): + """Test unknown errors (no HTTP response) are classified as INVOCATION (retryable).""" + result = error_cls.from_exception(Exception("Network timeout")) + assert result.error_category == DurableApiErrorCategory.INVOCATION + assert result.is_retryable() def test_validation_error(): @@ -391,3 +435,53 @@ def test_orphaned_child_exception_with_operation_id(): exception = OrphanedChildException("parent completed", operation_id="child_op_456") assert exception.operation_id == "child_op_456" assert str(exception) == "parent completed" + + +@pytest.mark.parametrize( + ("error_code", "status_code", "expected_retryable"), + [ + ("KMSAccessDeniedException", 502, False), + ("ServiceException", 500, True), + ("ServiceException", 502, True), + ], +) +def test_boto_client_error_is_retryable( + error_code: str, status_code: int, expected_retryable: bool +): + """Test BotoClientError.is_retryable() classification.""" + error_response = { + "Error": {"Code": error_code, "Message": "test error"}, + "ResponseMetadata": {"HTTPStatusCode": status_code}, + } + client_error = ClientError(error_response, "Invoke") # type: ignore[arg-type] + result = BotoClientError.from_exception(client_error) + assert result.is_retryable() == expected_retryable + + +def test_boto_client_error_is_retryable_no_error(): + """Test BotoClientError.is_retryable() returns True with no error info.""" + result = BotoClientError.from_exception(Exception("network error")) + assert result.is_retryable() + + +# ============================================================================= +# DurableApiErrorCategory backward compatibility +# ============================================================================= + + +def test_durable_api_error_category_backward_compatible_alias(): + """Test CheckpointErrorCategory is a backward-compatible alias for DurableApiErrorCategory.""" + assert CheckpointErrorCategory is DurableApiErrorCategory + assert CheckpointErrorCategory.INVOCATION is DurableApiErrorCategory.INVOCATION + assert CheckpointErrorCategory.EXECUTION is DurableApiErrorCategory.EXECUTION + + +# ============================================================================= +# is_retryable() tests +# ============================================================================= + + +def test_invocation_error_is_retryable_default(): + """Test InvocationError.is_retryable() returns True by default.""" + error = InvocationError("some error") + assert error.is_retryable() diff --git a/tests/execution_test.py b/tests/execution_test.py index eeaf949..e247d01 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -14,7 +14,9 @@ BotoClientError, CheckpointError, CheckpointErrorCategory, + DurableApiErrorCategory, ExecutionError, + GetExecutionStateError, InvocationError, SuspendExecution, ) @@ -1763,11 +1765,12 @@ def test_handler(event: Any, context: DurableContext) -> dict: assert response["Status"] == InvocationStatus.FAILED.value assert response["Error"]["ErrorType"] == "CheckpointError" - mock_logger.exception.assert_called_once() - call_args = mock_logger.exception.call_args - assert "Checkpoint processing failed" in call_args[0][0] - assert call_args[1]["extra"]["Error"] == error_obj - assert call_args[1]["extra"]["ResponseMetadata"] == metadata_obj + mock_logger.exception.assert_called() + # First call: "Checkpoint processing failed" with error extras + first_call = mock_logger.exception.call_args_list[0] + assert "Checkpoint processing failed" in first_call[0][0] + assert first_call[1]["extra"]["Error"] == error_obj + assert first_call[1]["extra"]["ResponseMetadata"] == metadata_obj def test_durable_execution_logs_boto_client_error_extras_from_background_thread(): @@ -2646,3 +2649,150 @@ def test_from_dict_leaves_timestamps_as_integers(): _ = operations[1].wait_details.scheduled_end_timestamp < datetime.datetime.now( tz=datetime.UTC ) + + +# ============================================================================= +# Non-retryable Durable API error handling tests +# ============================================================================= + + +def _make_invocation_input(mock_client, next_marker=""): + """Helper to create a standard test invocation input.""" + operation = Operation( + operation_id="exec1", + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + execution_details=ExecutionDetails(input_payload="{}"), + ) + return DurableExecutionInvocationInputWithClient( + durable_execution_arn="arn:test:execution", + checkpoint_token="token123", # noqa: S106 + initial_execution_state=InitialExecutionState( + operations=[operation], next_marker=next_marker + ), + service_client=mock_client, + ) + + +def _make_lambda_context(): + """Helper to create a standard mock Lambda context.""" + ctx = Mock() + ctx.aws_request_id = "test-request" + ctx.client_context = None + ctx.identity = None + ctx._epoch_deadline_time_in_ms = 1000000 # noqa: SLF001 + ctx.invoked_function_arn = None + ctx.tenant_id = None + return ctx + + +def test_durable_execution_non_retryable_invocation_error_returns_failed(): + """Test that non-retryable InvocationError returns FAILED instead of retrying.""" + mock_client = Mock(spec=DurableServiceClient) + non_retryable_error = GetExecutionStateError( + message="KMS access denied", + error_category=DurableApiErrorCategory.EXECUTION, + error={"Code": "KMSAccessDeniedException", "Message": "KMS access denied"}, + response_metadata={"HTTPStatusCode": 502}, + ) + + @durable_execution + def test_handler(event: Any, context: DurableContext) -> dict: + raise non_retryable_error + + result = test_handler(_make_invocation_input(mock_client), _make_lambda_context()) + assert result["Status"] == InvocationStatus.FAILED.value + assert result["Error"]["ErrorType"] == "GetExecutionStateError" + + +def test_durable_execution_retryable_invocation_error_raises(): + """Test that retryable InvocationError raises to trigger Lambda retry.""" + mock_client = Mock(spec=DurableServiceClient) + retryable_error = GetExecutionStateError( + message="Service error", + error={"Code": "ServiceException", "Message": "Internal error"}, + response_metadata={"HTTPStatusCode": 500}, + ) + + @durable_execution + def test_handler(event: Any, context: DurableContext) -> dict: + raise retryable_error + + with pytest.raises(GetExecutionStateError, match="Service error"): + test_handler(_make_invocation_input(mock_client), _make_lambda_context()) + + +def test_durable_execution_non_retryable_background_thread_error_returns_failed(): + """Test that non-retryable error from background thread returns FAILED.""" + mock_client = Mock(spec=DurableServiceClient) + non_retryable_error = GetExecutionStateError( + message="KMS key disabled", + error_category=DurableApiErrorCategory.EXECUTION, + error={"Code": "KMSDisabledException", "Message": "KMS key disabled"}, + response_metadata={"HTTPStatusCode": 502}, + ) + mock_client.checkpoint.side_effect = lambda *a, **kw: (_ for _ in ()).throw( + non_retryable_error + ) + + @durable_execution + def test_handler(event: Any, context: DurableContext) -> dict: + context.step(lambda ctx: "step_result") + return {"result": "success"} + + result = test_handler(_make_invocation_input(mock_client), _make_lambda_context()) + assert result["Status"] == InvocationStatus.FAILED.value + assert result["Error"]["ErrorType"] == "GetExecutionStateError" + + +@pytest.mark.parametrize( + ("error_code", "status_code", "error_category"), + [ + ("KMSAccessDeniedException", 502, DurableApiErrorCategory.EXECUTION), + ("ValidationException", 400, DurableApiErrorCategory.EXECUTION), + ], +) +def test_durable_execution_non_retryable_initial_pagination_error_returns_failed( + error_code: str, status_code: int, error_category: DurableApiErrorCategory +): + """Test that non-retryable errors during initial pagination return FAILED.""" + mock_client = Mock(spec=DurableServiceClient) + non_retryable_error = GetExecutionStateError( + message=f"{error_code} error", + error_category=error_category, + error={"Code": error_code, "Message": f"{error_code} error"}, + response_metadata={"HTTPStatusCode": status_code}, # type: ignore[arg-type] + ) + mock_client.get_execution_state.side_effect = non_retryable_error + + @durable_execution + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client, next_marker="next-page-marker"), + _make_lambda_context(), + ) + assert result["Status"] == InvocationStatus.FAILED.value + assert result["Error"]["ErrorType"] == "GetExecutionStateError" + + +def test_durable_execution_retryable_initial_pagination_error_raises(): + """Test that retryable error during initial pagination raises to trigger Lambda retry.""" + mock_client = Mock(spec=DurableServiceClient) + retryable_error = GetExecutionStateError( + message="Service error", + error={"Code": "ServiceException", "Message": "Internal error"}, + response_metadata={"HTTPStatusCode": 500}, + ) + mock_client.get_execution_state.side_effect = retryable_error + + @durable_execution + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + with pytest.raises(GetExecutionStateError, match="Service error"): + test_handler( + _make_invocation_input(mock_client, next_marker="next-page-marker"), + _make_lambda_context(), + ) diff --git a/tests/state_test.py b/tests/state_test.py index 19d641e..2ea0ab3 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -16,6 +16,8 @@ from aws_durable_execution_sdk_python.exceptions import ( BackgroundThreadError, CallableRuntimeError, + DurableApiErrorCategory, + GetExecutionStateError, OrphanedChildException, ) from aws_durable_execution_sdk_python.identifier import OperationIdentifier @@ -724,6 +726,88 @@ def mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marke assert operation.operation_id == expected_op.operation_id +def test_fetch_paginated_operations_stores_partial_results_on_error(): + """Test that operations from successful pages are stored even when a later page fails.""" + mock_lambda_client = Mock(spec=LambdaClient) + + non_retryable_error = GetExecutionStateError( + message="KMS access denied", + error_category=DurableApiErrorCategory.EXECUTION, + error={"Code": "KMSAccessDeniedException", "Message": "KMS access denied"}, + response_metadata={"HTTPStatusCode": 502}, + ) + + def mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marker): + if next_marker == "marker1": + return StateOutput( + operations=[ + Operation( + operation_id="1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + ], + next_marker="marker2", + ) + raise non_retryable_error + + mock_lambda_client.get_execution_state.side_effect = mock_get_execution_state + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_lambda_client, + ) + + with pytest.raises(GetExecutionStateError): + state.fetch_paginated_operations( + initial_operations=[ + Operation( + operation_id="0", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + ], + checkpoint_token="test_token", # noqa: S106 + next_marker="marker1", + ) + + # Initial operation + page 1 should be stored despite page 2 failing + assert "0" in state.operations + assert "1" in state.operations + assert len(state.operations) == 2 + + +def test_fetch_paginated_operations_logs_error(caplog): + """Test that GetExecutionStateError is logged with structured extras.""" + mock_lambda_client = Mock(spec=LambdaClient) + + error = GetExecutionStateError( + message="Service error", + error_category=DurableApiErrorCategory.INVOCATION, + error={"Code": "ServiceException", "Message": "Service error"}, + response_metadata={"HTTPStatusCode": 500}, + ) + mock_lambda_client.get_execution_state.side_effect = error + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_lambda_client, + ) + + with pytest.raises(GetExecutionStateError): + state.fetch_paginated_operations( + initial_operations=[], + checkpoint_token="test_token", # noqa: S106 + next_marker="marker1", + ) + + assert "Durable API error during state fetch." in caplog.text + + # ============================================================================ # Checkpoint Batching Tests # ============================================================================