diff --git a/docs/running_agents.md b/docs/running_agents.md index d72d47a2ce..22842c97ce 100644 --- a/docs/running_agents.md +++ b/docs/running_agents.md @@ -465,7 +465,7 @@ Set the hook per run via `run_config` to redact sensitive data, trim long histor ### Error handlers -All `Runner` entry points accept `error_handlers`, a dict keyed by error kind. The supported keys are `"max_turns"` and `"model_refusal"`. Use them when you want to return a controlled final output instead of raising `MaxTurnsExceeded` or `ModelRefusalError`. +All `Runner` entry points accept `error_handlers`, a dict keyed by error kind. The supported keys are `"max_turns"`, `"model_refusal"`, and `"invalid_final_output"`. Use them when you want to return a controlled final output instead of raising `MaxTurnsExceeded`, `ModelRefusalError`, or `ModelBehaviorError`. ```python from agents import ( @@ -528,6 +528,38 @@ result = Runner.run_sync( print(result.final_output) ``` +Use `"invalid_final_output"` when the model's final message does not validate against the agent's `output_type` and you want a fallback instead of ending the run with `ModelBehaviorError`. The handler only fires for final output validation failures; other `ModelBehaviorError` cases, such as calls to nonexistent tools, are not affected. + +```python +from pydantic import BaseModel + +from agents import Agent, ModelBehaviorError, RunErrorHandlerInput, Runner + + +class Recipe(BaseModel): + ingredients: list[str] + parse_error: str | None = None + + +def on_invalid_final_output(data: RunErrorHandlerInput[None]) -> Recipe: + assert isinstance(data.error, ModelBehaviorError) + return Recipe(ingredients=[], parse_error=data.error.message) + + +agent = Agent( + name="Recipe assistant", + instructions="Return a structured recipe.", + output_type=Recipe, +) + +result = Runner.run_sync( + agent, + "Plan tonight's dinner.", + error_handlers={"invalid_final_output": on_invalid_final_output}, +) +print(result.final_output) +``` + ## Durable execution integrations and human-in-the-loop For tool approval pause/resume patterns, start with the dedicated [Human-in-the-loop guide](human_in_the_loop.md). diff --git a/src/agents/run_error_handlers.py b/src/agents/run_error_handlers.py index 6f345852eb..cdf83f3149 100644 --- a/src/agents/run_error_handlers.py +++ b/src/agents/run_error_handlers.py @@ -7,7 +7,7 @@ from typing_extensions import TypedDict from .agent import Agent -from .exceptions import MaxTurnsExceeded, ModelRefusalError +from .exceptions import MaxTurnsExceeded, ModelBehaviorError, ModelRefusalError from .items import ModelResponse, RunItem, TResponseInputItem from .run_context import RunContextWrapper, TContext from .util._types import MaybeAwaitable @@ -27,7 +27,7 @@ class RunErrorData: @dataclass class RunErrorHandlerInput(Generic[TContext]): - error: MaxTurnsExceeded | ModelRefusalError + error: MaxTurnsExceeded | ModelRefusalError | ModelBehaviorError context: RunContextWrapper[TContext] run_data: RunErrorData @@ -52,6 +52,7 @@ class RunErrorHandlers(TypedDict, Generic[TContext], total=False): max_turns: RunErrorHandler[TContext] model_refusal: RunErrorHandler[TContext] + invalid_final_output: RunErrorHandler[TContext] __all__ = [ diff --git a/src/agents/run_internal/error_handlers.py b/src/agents/run_internal/error_handlers.py index 81a94a2002..2106e067bf 100644 --- a/src/agents/run_internal/error_handlers.py +++ b/src/agents/run_internal/error_handlers.py @@ -128,7 +128,7 @@ def create_message_output_item(agent: Agent[Any], output_text: str) -> MessageOu async def resolve_run_error_handler_result( *, error_handlers: RunErrorHandlers[TContext] | None, - error: MaxTurnsExceeded | ModelRefusalError, + error: MaxTurnsExceeded | ModelRefusalError | ModelBehaviorError, context_wrapper: RunContextWrapper[TContext], run_data: RunErrorData, ) -> RunErrorHandlerResult | None: @@ -136,6 +136,8 @@ async def resolve_run_error_handler_result( return None if isinstance(error, ModelRefusalError): handler = error_handlers.get("model_refusal") + elif isinstance(error, ModelBehaviorError): + handler = error_handlers.get("invalid_final_output") else: handler = error_handlers.get("max_turns") if handler is None: diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index 2c95cf2e13..1f18de0e2f 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -807,7 +807,30 @@ async def execute_tools_and_side_effects( tool_output_guardrail_results=tool_output_guardrail_results, ) if output_schema and not output_schema.is_plain_text() and potential_final_output_text: - final_output = output_schema.validate_json(potential_final_output_text) + try: + final_output = output_schema.validate_json(potential_final_output_text) + except ModelBehaviorError as invalid_output_error: + run_error_data = build_run_error_data( + input=original_input, + new_items=pre_step_items + new_step_items, + raw_responses=[new_response], + last_agent=public_agent, + ) + handler_result = await resolve_run_error_handler_result( + error_handlers=error_handlers, + error=invalid_output_error, + context_wrapper=context_wrapper, + run_data=run_error_data, + ) + if handler_result is None: + raise + + final_output = validate_handler_final_output( + public_agent, handler_result.final_output + ) + if handler_result.include_in_history: + output_text = format_final_output_text(public_agent, final_output) + new_step_items.append(create_message_output_item(public_agent, output_text)) return await execute_final_output_call( public_agent=public_agent, original_input=original_input, diff --git a/tests/test_max_turns.py b/tests/test_max_turns.py index 0a21aaf385..21fbffb45d 100644 --- a/tests/test_max_turns.py +++ b/tests/test_max_turns.py @@ -11,6 +11,7 @@ ItemHelpers, MaxTurnsExceeded, MessageOutputItem, + ModelBehaviorError, ModelRefusalError, RunErrorHandlerResult, Runner, @@ -234,6 +235,84 @@ async def test_streamed_refusal_handler_returns_output(): ) +@pytest.mark.asyncio +async def test_non_streamed_invalid_final_output_raises_without_handler(): + model = FakeModel(initial_output=[get_text_message("not valid json")]) + agent = Agent(name="test_1", model=model, output_type=FooModel) + + with pytest.raises(ModelBehaviorError): + await Runner.run(agent, input="user_message", max_turns=3) + + +@pytest.mark.asyncio +async def test_non_streamed_invalid_final_output_handler_returns_structured_output(): + model = FakeModel(initial_output=[get_text_message("not valid json")]) + agent = Agent(name="test_1", model=model, output_type=FooModel) + + def handler(data): + assert isinstance(data.error, ModelBehaviorError) + assert data.run_data.raw_responses + return FooModel(summary="safe fallback") + + result = await Runner.run( + agent, + input="user_message", + max_turns=3, + error_handlers={"invalid_final_output": handler}, + ) + + assert isinstance(result.final_output, FooModel) + assert result.final_output.summary == "safe fallback" + assert ItemHelpers.text_message_outputs(result.new_items).endswith( + '{"summary":"safe fallback"}' + ) + + +@pytest.mark.asyncio +async def test_non_streamed_invalid_final_output_handler_can_skip_history(): + model = FakeModel(initial_output=[get_text_message("not valid json")]) + agent = Agent(name="test_1", model=model, output_type=FooModel) + + result = await Runner.run( + agent, + input="user_message", + error_handlers={ + "invalid_final_output": lambda data: RunErrorHandlerResult( + final_output=FooModel(summary="safe fallback"), + include_in_history=False, + ), + }, + ) + + assert result.final_output.summary == "safe fallback" + # The model's own invalid message is still part of the turn's items; only the + # handler's fallback is skipped, so no extra synthetic message is appended for it. + assert ItemHelpers.text_message_outputs(result.new_items) == "not valid json" + + +@pytest.mark.asyncio +async def test_streamed_invalid_final_output_handler_returns_output(): + model = FakeModel(initial_output=[get_text_message("not valid json")]) + agent = Agent(name="test_1", model=model, output_type=FooModel) + + result = Runner.run_streamed( + agent, + input="user_message", + error_handlers={ + "invalid_final_output": lambda data: FooModel(summary="safe fallback"), + }, + ) + + events = [event async for event in result.stream_events()] + + assert result.final_output.summary == "safe fallback" + run_item_events = [event for event in events if isinstance(event, RunItemStreamEvent)] + assert any( + event.name == "message_output_created" and isinstance(event.item, MessageOutputItem) + for event in run_item_events + ) + + @pytest.mark.asyncio async def test_structured_output_non_streamed_max_turns(): model = FakeModel()