Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions pyrit/backend/routes/scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
/api/scenarios/runs — scenario execution lifecycle
"""

from typing import Any

from fastapi import APIRouter, HTTPException, Query, status

from pyrit.backend.models.common import ProblemDetail
Expand Down Expand Up @@ -199,7 +201,7 @@ async def cancel_scenario_run(scenario_result_id: str) -> ScenarioRunSummary: #
409: {"model": ProblemDetail, "description": "Run not yet completed"},
},
)
async def get_scenario_run_results(scenario_result_id: str) -> dict: # pyrit-async-suffix-exempt
async def get_scenario_run_results(scenario_result_id: str) -> dict[str, Any]: # pyrit-async-suffix-exempt
"""
Get detailed results for a completed scenario run.

Expand All @@ -209,7 +211,7 @@ async def get_scenario_run_results(scenario_result_id: str) -> dict: # pyrit-as
scenario_result_id: The scenario_result_id.

Returns:
dict: ScenarioResult.to_dict() payload.
dict: ``ScenarioResult.model_dump(mode="json", by_alias=True)`` payload.
"""
service = get_scenario_run_service()
try:
Expand All @@ -222,4 +224,4 @@ async def get_scenario_run_results(scenario_result_id: str) -> dict: # pyrit-as
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Scenario run '{scenario_result_id}' not found",
)
return result.to_dict()
return result.model_dump(mode="json", by_alias=True)
4 changes: 2 additions & 2 deletions pyrit/cli/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,12 @@ async def print_scenario_result_async(*, result_dict: dict[str, Any]) -> None:
Print detailed scenario results using the output module.

Args:
result_dict: ``ScenarioResult.to_dict()`` payload from the REST API.
result_dict: ``ScenarioResult.model_dump(mode="json", by_alias=True)`` payload from the REST API.
"""
from pyrit.models.scenario_result import ScenarioResult
from pyrit.output.scenario_result.pretty import PrettyScenarioResultMemoryPrinter

scenario_result = ScenarioResult.from_dict(result_dict)
scenario_result = ScenarioResult.model_validate(result_dict)
printer = PrettyScenarioResultMemoryPrinter()
await printer.write_async(scenario_result)

Expand Down
2 changes: 1 addition & 1 deletion pyrit/cli/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ async def get_scenario_run_results_async(self, *, scenario_result_id: str) -> di
Get detailed results for a completed scenario run.

Returns:
dict: ``ScenarioResult.to_dict()`` payload.
dict: ``ScenarioResult.model_dump(mode="json", by_alias=True)`` payload.
"""
return await self._get_json_async(path=f"/api/scenarios/runs/{scenario_result_id}/results")

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/backend/test_scenario_run_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class TestGetScenarioRunResultsRoute:
def test_get_results_returns_200(self, client: TestClient) -> None:
"""Test that getting results of a completed run returns 200."""
mock_scenario_result = MagicMock()
mock_scenario_result.to_dict.return_value = {
mock_scenario_result.model_dump.return_value = {
"id": "result-uuid",
"scenario_identifier": {"name": "foundry.red_team_agent", "version": 1},
"scenario_run_state": "COMPLETED",
Expand Down
19 changes: 11 additions & 8 deletions tests/unit/cli/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,23 +319,26 @@ async def test_print_scenario_result_async_uses_pretty_printer():
fake_printer.write_async = AsyncMock()

with (
patch("pyrit.models.scenario_result.ScenarioResult.from_dict", return_value=fake_scenario) as from_dict_mock,
patch(
"pyrit.models.scenario_result.ScenarioResult.model_validate", return_value=fake_scenario
) as model_validate_mock,
patch(
"pyrit.output.scenario_result.pretty.PrettyScenarioResultMemoryPrinter", return_value=fake_printer
) as printer_cls,
):
await _output.print_scenario_result_async(result_dict=result_dict)

from_dict_mock.assert_called_once_with(result_dict)
model_validate_mock.assert_called_once_with(result_dict)
printer_cls.assert_called_once_with()
fake_printer.write_async.assert_awaited_once_with(fake_scenario)


async def test_print_scenario_result_async_roundtrip_with_real_payload():
"""
Integration smoke test: a real ScenarioResult.to_dict() payload must flow
through ScenarioResult.from_dict() inside print_scenario_result_async
without raising. Locks the REST contract used by the CLI thin client.
Integration smoke test: a real ``ScenarioResult.model_dump(mode="json", by_alias=True)``
payload must flow through ``ScenarioResult.model_validate(...)`` inside
``print_scenario_result_async`` without raising. Locks the REST contract used by the CLI
thin client.
"""
from datetime import datetime, timezone

Expand All @@ -361,10 +364,10 @@ async def test_print_scenario_result_async_roundtrip_with_real_payload():
attack_results={"strat_a": [attack]},
scenario_run_state="COMPLETED",
)
payload = original.to_dict()
payload = original.model_dump(mode="json", by_alias=True)

# Drive print_scenario_result_async through the real from_dict path; only
# stub the printer to keep the test fast.
# Drive print_scenario_result_async through the real model_validate path;
# only stub the printer to keep the test fast.
fake_printer = MagicMock()
fake_printer.write_async = AsyncMock()
with patch(
Expand Down
Loading