From bfd5cbf89cd41d4c4120de159e36e08cc04b6089 Mon Sep 17 00:00:00 2001 From: Forge Date: Wed, 24 Jun 2026 06:52:43 +0000 Subject: [PATCH 01/68] [AISOS-1888] Define StatsState mixin and StageStats TypedDict MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Detailed description: - Created src/forge/workflow/stats.py with two TypedDicts: * StageStats: per-stage metrics (stage_name, iteration_count, machine_time_seconds, human_time_seconds, input_tokens, output_tokens, started_at, ended_at) — all nullable timestamps use X | None convention * StatsState: workflow-level stats mixin (stats_stages, stats_pr_urls, stats_ci_cycles, stats_outcome, stats_outcome_reason, stats_comment_posted) - Modified src/forge/workflow/base.py to import and re-export StageStats and StatsState via __all__; added module docstring documenting all state mixins - Modified src/forge/workflow/__init__.py to re-export StageStats and StatsState - Created tests/unit/workflow/test_stats.py with 18 unit tests verifying field presence, type annotations, nullable semantics, construction patterns, and importability from both forge.workflow and forge.workflow.base Closes: AISOS-1888 --- src/forge/workflow/__init__.py | 3 + src/forge/workflow/base.py | 25 +++- src/forge/workflow/stats.py | 77 +++++++++++ tests/unit/workflow/test_stats.py | 217 ++++++++++++++++++++++++++++++ 4 files changed, 321 insertions(+), 1 deletion(-) create mode 100644 src/forge/workflow/stats.py create mode 100644 tests/unit/workflow/test_stats.py diff --git a/src/forge/workflow/__init__.py b/src/forge/workflow/__init__.py index 4d68775c..67e7587d 100644 --- a/src/forge/workflow/__init__.py +++ b/src/forge/workflow/__init__.py @@ -9,6 +9,7 @@ ) from forge.workflow.registry import create_default_router from forge.workflow.router import WorkflowRouter +from forge.workflow.stats import StageStats, StatsState __all__ = [ "BaseState", @@ -16,6 +17,8 @@ "CIIntegrationState", "PRIntegrationState", 
"ReviewIntegrationState", + "StageStats", + "StatsState", "WorkflowRouter", "create_default_router", ] diff --git a/src/forge/workflow/base.py b/src/forge/workflow/base.py index b3b0d161..ae197cb1 100644 --- a/src/forge/workflow/base.py +++ b/src/forge/workflow/base.py @@ -1,4 +1,16 @@ -"""Base workflow classes and state definitions.""" +"""Base workflow classes and state definitions. + +Mixin TypedDicts +---------------- +Compose workflow states from the following mixins: + +* :class:`PRIntegrationState` — for workflows that open pull requests. +* :class:`CIIntegrationState` — for workflows that run CI checks. +* :class:`ReviewIntegrationState` — for workflows with review stages. +* :class:`~forge.workflow.stats.StatsState` — for workflows that record + execution statistics (iteration counts, token usage, timing, outcome). + Defined in :mod:`forge.workflow.stats`. +""" from abc import ABC, abstractmethod from datetime import datetime @@ -8,6 +20,17 @@ from langgraph.graph.message import add_messages from forge.models.workflow import TicketType +from forge.workflow.stats import StageStats, StatsState + +__all__ = [ + "BaseState", + "BaseWorkflow", + "CIIntegrationState", + "PRIntegrationState", + "ReviewIntegrationState", + "StageStats", + "StatsState", +] class BaseState(TypedDict, total=False): diff --git a/src/forge/workflow/stats.py b/src/forge/workflow/stats.py new file mode 100644 index 00000000..15ac3d15 --- /dev/null +++ b/src/forge/workflow/stats.py @@ -0,0 +1,77 @@ +"""Statistics tracking data structures for workflow execution. + +This module defines the TypedDicts used to capture per-stage metrics and +overall workflow outcome data, as required by SC-001. +""" + +from typing import TypedDict + + +class StageStats(TypedDict, total=False): + """Per-stage execution metrics captured during workflow execution. + + Each stage in a workflow gets one StageStats entry, keyed by stage name + in the StatsState.stats_stages mapping. 
Fields are updated incrementally + as the stage progresses and finalised when the stage ends. + + Fields: + stage_name: Canonical name of the workflow stage (e.g. "implement"). + iteration_count: Number of times this stage has been (re-)entered, + including retries and revision loops. + machine_time_seconds: Wall-clock seconds spent executing automated work + (LLM calls, tool calls, CI waiting, etc.) — i.e. time the system + was actively doing something. + human_time_seconds: Wall-clock seconds the workflow was paused waiting + for human input (approval gates, revision requests, Q&A). + input_tokens: Cumulative LLM prompt tokens consumed by this stage. + output_tokens: Cumulative LLM completion tokens produced by this stage. + started_at: ISO-8601 timestamp when the stage first started, or None + if the stage has not yet been entered. + ended_at: ISO-8601 timestamp when the stage finished (either completed + or abandoned), or None if it is still in progress. + """ + + stage_name: str + iteration_count: int + machine_time_seconds: float + human_time_seconds: float + input_tokens: int + output_tokens: int + started_at: str | None + ended_at: str | None + + +class StatsState(TypedDict, total=False): + """Mixin TypedDict for workflow-level statistics tracking. + + Intended to be composed into workflow state classes alongside BaseState + and other integration mixins. All fields are optional (total=False) so + that existing workflows can adopt the mixin incrementally without + providing values upfront. + + Outcome values follow the convention: + "Completed" — workflow finished successfully. + "Blocked: <reason>" — workflow is waiting on an external blocker. + "Failed: <reason>" — workflow terminated due to an unrecoverable error. + + Fields: + stats_stages: Mapping from stage name to its StageStats snapshot. + Updated in-place as each stage starts and ends. + stats_pr_urls: URLs of all pull requests opened during this workflow + run (across all repositories). 
+ stats_ci_cycles: Number of CI fix-attempt cycles that were triggered + during the implementation phase. + stats_outcome: Final outcome string for the workflow run, or None while + the workflow is still in progress. + stats_outcome_reason: Human-readable elaboration on the outcome (e.g. + the blocking reason or error message), or None when not applicable. + stats_comment_posted: True once the summary statistics comment has been + posted to the Jira ticket (prevents double-posting on retries). + """ + + stats_stages: dict[str, StageStats] + stats_pr_urls: list[str] + stats_ci_cycles: int + stats_outcome: str | None + stats_outcome_reason: str | None + stats_comment_posted: bool diff --git a/tests/unit/workflow/test_stats.py b/tests/unit/workflow/test_stats.py new file mode 100644 index 00000000..7f79f4fe --- /dev/null +++ b/tests/unit/workflow/test_stats.py @@ -0,0 +1,217 @@ +"""Unit tests for StageStats and StatsState TypedDicts.""" + +from typing import get_type_hints + +import pytest + + +class TestStageStats: + """Tests for StageStats TypedDict.""" + + def test_stage_stats_has_all_required_fields(self): + """StageStats defines every field required by SC-001.""" + from forge.workflow.stats import StageStats + + hints = get_type_hints(StageStats) + + assert "stage_name" in hints + assert "iteration_count" in hints + assert "machine_time_seconds" in hints + assert "human_time_seconds" in hints + assert "input_tokens" in hints + assert "output_tokens" in hints + assert "started_at" in hints + assert "ended_at" in hints + + def test_stage_stats_field_types(self): + """StageStats fields carry the correct type annotations.""" + from forge.workflow.stats import StageStats + + hints = get_type_hints(StageStats) + + assert hints["stage_name"] is str + assert hints["iteration_count"] is int + assert hints["machine_time_seconds"] is float + assert hints["human_time_seconds"] is float + assert hints["input_tokens"] is int + assert hints["output_tokens"] is int + + def 
test_stage_stats_nullable_timestamps(self): + """started_at and ended_at accept None (X | None convention).""" + from forge.workflow.stats import StageStats + + hints = get_type_hints(StageStats, include_extras=False) + + # Under Python 3.11+ X | None becomes types.UnionType. + # str(str | None) is 'str | None' on 3.10+ union syntax. + started_hint = str(hints["started_at"]) + ended_hint = str(hints["ended_at"]) + + assert "str" in started_hint + assert "None" in started_hint + assert "str" in ended_hint + assert "None" in ended_hint + + def test_stage_stats_is_total_false(self): + """StageStats allows partial initialisation.""" + from forge.workflow.stats import StageStats + + # Should not raise — total=False makes all keys optional + partial: StageStats = {"stage_name": "implement", "iteration_count": 1} + assert partial["stage_name"] == "implement" + assert partial["iteration_count"] == 1 + + def test_stage_stats_full_construction(self): + """StageStats can be constructed with all fields populated.""" + from forge.workflow.stats import StageStats + + stats: StageStats = { + "stage_name": "implement", + "iteration_count": 3, + "machine_time_seconds": 120.5, + "human_time_seconds": 300.0, + "input_tokens": 4096, + "output_tokens": 2048, + "started_at": "2024-01-01T00:00:00Z", + "ended_at": "2024-01-01T00:07:00Z", + } + + assert stats["stage_name"] == "implement" + assert stats["iteration_count"] == 3 + assert stats["machine_time_seconds"] == 120.5 + assert stats["human_time_seconds"] == 300.0 + assert stats["input_tokens"] == 4096 + assert stats["output_tokens"] == 2048 + assert stats["started_at"] == "2024-01-01T00:00:00Z" + assert stats["ended_at"] == "2024-01-01T00:07:00Z" + + def test_stage_stats_nullable_timestamps_accept_none(self): + """started_at and ended_at can be explicitly set to None.""" + from forge.workflow.stats import StageStats + + stats: StageStats = { + "stage_name": "triage", + "started_at": None, + "ended_at": None, + } + assert 
stats["started_at"] is None + assert stats["ended_at"] is None + + +class TestStatsState: + """Tests for StatsState TypedDict mixin.""" + + def test_stats_state_has_all_required_fields(self): + """StatsState defines all workflow-level statistics fields.""" + from forge.workflow.stats import StatsState + + hints = get_type_hints(StatsState) + + assert "stats_stages" in hints + assert "stats_pr_urls" in hints + assert "stats_ci_cycles" in hints + assert "stats_outcome" in hints + assert "stats_outcome_reason" in hints + assert "stats_comment_posted" in hints + + def test_stats_state_is_total_false(self): + """StatsState allows partial initialisation.""" + from forge.workflow.stats import StatsState + + partial: StatsState = {"stats_ci_cycles": 0} + assert partial["stats_ci_cycles"] == 0 + + def test_stats_state_nullable_outcome_fields(self): + """stats_outcome and stats_outcome_reason accept None.""" + from forge.workflow.stats import StatsState + + hints = get_type_hints(StatsState, include_extras=False) + + outcome_hint = str(hints["stats_outcome"]) + reason_hint = str(hints["stats_outcome_reason"]) + + assert "str" in outcome_hint + assert "None" in outcome_hint + assert "str" in reason_hint + assert "None" in reason_hint + + def test_stats_state_full_construction(self): + """StatsState can be constructed with all fields populated.""" + from forge.workflow.stats import StageStats, StatsState + + stage: StageStats = { + "stage_name": "implement", + "iteration_count": 2, + "machine_time_seconds": 60.0, + "human_time_seconds": 0.0, + "input_tokens": 1000, + "output_tokens": 500, + "started_at": "2024-01-01T00:00:00Z", + "ended_at": "2024-01-01T00:01:00Z", + } + + state: StatsState = { + "stats_stages": {"implement": stage}, + "stats_pr_urls": ["https://github.com/org/repo/pull/42"], + "stats_ci_cycles": 1, + "stats_outcome": "Completed", + "stats_outcome_reason": None, + "stats_comment_posted": True, + } + + assert state["stats_stages"]["implement"]["stage_name"] == 
"implement" + assert state["stats_pr_urls"] == ["https://github.com/org/repo/pull/42"] + assert state["stats_ci_cycles"] == 1 + assert state["stats_outcome"] == "Completed" + assert state["stats_outcome_reason"] is None + assert state["stats_comment_posted"] is True + + @pytest.mark.parametrize( + "outcome", + [ + "Completed", + "Blocked: waiting for human approval", + "Failed: unrecoverable CI failure", + ], + ) + def test_stats_state_valid_outcome_values(self, outcome: str): + """stats_outcome accepts the three documented outcome patterns.""" + from forge.workflow.stats import StatsState + + state: StatsState = {"stats_outcome": outcome} + assert state["stats_outcome"] == outcome + + def test_stats_state_comment_posted_defaults_pattern(self): + """stats_comment_posted is a bool field.""" + from forge.workflow.stats import StatsState + + hints = get_type_hints(StatsState) + assert hints["stats_comment_posted"] is bool + + def test_stats_stages_is_dict_of_stage_stats(self): + """stats_stages maps string keys to StageStats dicts.""" + from forge.workflow.stats import StageStats, StatsState + + s1: StageStats = {"stage_name": "triage", "iteration_count": 1} + s2: StageStats = {"stage_name": "implement", "iteration_count": 3} + + state: StatsState = {"stats_stages": {"triage": s1, "implement": s2}} + assert len(state["stats_stages"]) == 2 + assert state["stats_stages"]["triage"]["stage_name"] == "triage" + assert state["stats_stages"]["implement"]["iteration_count"] == 3 + + +class TestStatsStateExportedFromPackage: + """Verify the new types are accessible via the workflow package.""" + + def test_stage_stats_importable_from_workflow(self): + """StageStats is exported from forge.workflow.""" + from forge.workflow import StageStats # noqa: F401 + + def test_stats_state_importable_from_workflow(self): + """StatsState is exported from forge.workflow.""" + from forge.workflow import StatsState # noqa: F401 + + def test_stats_state_importable_from_base(self): + 
"""StatsState is importable via forge.workflow.base (re-exported).""" + from forge.workflow.base import StatsState # noqa: F401 From 4e0974b351a9426f62b0834b3eba1af491e11481 Mon Sep 17 00:00:00 2001 From: Forge Date: Wed, 24 Jun 2026 07:02:04 +0000 Subject: [PATCH 02/68] [AISOS-1889] Integrate StatsState mixin into FeatureState and BugState Detailed description: - Added StatsState to FeatureState inheritance chain in feature/state.py - Added StatsState to BugState inheritance chain in bug/state.py - Updated create_initial_feature_state() to initialize all stats fields: stats_stages={}, stats_pr_urls=[], stats_ci_cycles=0, stats_outcome=None, stats_outcome_reason=None, stats_comment_posted=False - Updated create_initial_bug_state() with the same stats field defaults - Extended tests/unit/workflow/feature/test_state.py with TestFeatureStateStatsIntegration and TestBugStateStatsIntegration classes verifying inheritance (via __orig_bases__), field presence, and defaults All 1272 unit tests pass. 
Closes: AISOS-1889 --- src/forge/workflow/bug/state.py | 15 ++++- src/forge/workflow/feature/state.py | 15 ++++- tests/unit/workflow/feature/test_state.py | 80 +++++++++++++++++++++++ 3 files changed, 108 insertions(+), 2 deletions(-) diff --git a/src/forge/workflow/bug/state.py b/src/forge/workflow/bug/state.py index 486ee0e3..a8e5f81a 100644 --- a/src/forge/workflow/bug/state.py +++ b/src/forge/workflow/bug/state.py @@ -10,11 +10,17 @@ CIIntegrationState, PRIntegrationState, ReviewIntegrationState, + StatsState, ) class BugState( - BaseState, PRIntegrationState, CIIntegrationState, ReviewIntegrationState, total=False + BaseState, + PRIntegrationState, + CIIntegrationState, + ReviewIntegrationState, + StatsState, + total=False, ): """State specific to Bug workflow.""" @@ -135,6 +141,13 @@ def create_initial_bug_state(ticket_key: str, **kwargs: Any) -> BugState: "qualitative_review_failed": False, "reflect_rca_retry_count": 0, "yolo_mode": False, + # Stats fields + "stats_stages": {}, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "stats_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, } # Merge with kwargs, letting kwargs override defaults diff --git a/src/forge/workflow/feature/state.py b/src/forge/workflow/feature/state.py index a6c0ac3b..dbaae49d 100644 --- a/src/forge/workflow/feature/state.py +++ b/src/forge/workflow/feature/state.py @@ -10,11 +10,17 @@ CIIntegrationState, PRIntegrationState, ReviewIntegrationState, + StatsState, ) class FeatureState( - BaseState, PRIntegrationState, CIIntegrationState, ReviewIntegrationState, total=False + BaseState, + PRIntegrationState, + CIIntegrationState, + ReviewIntegrationState, + StatsState, + total=False, ): """State specific to Feature workflow.""" @@ -122,6 +128,13 @@ def create_initial_feature_state(ticket_key: str, **kwargs: Any) -> FeatureState "prd_pr_branch": None, "prd_pr_file_path": None, "yolo_mode": False, + # Stats fields + "stats_stages": {}, + "stats_pr_urls": [], + 
"stats_ci_cycles": 0, + "stats_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, } # Merge with kwargs, letting kwargs override defaults diff --git a/tests/unit/workflow/feature/test_state.py b/tests/unit/workflow/feature/test_state.py index 94fdfb02..ecf289f6 100644 --- a/tests/unit/workflow/feature/test_state.py +++ b/tests/unit/workflow/feature/test_state.py @@ -132,3 +132,83 @@ def test_bug_state_qa_defaults(self): assert state["qa_history"] == [] assert state["generation_context"] == {} assert state["is_question"] is False + + +class TestFeatureStateStatsIntegration: + """Tests for StatsState mixin integration in FeatureState.""" + + def test_feature_state_inherits_stats_state(self): + """FeatureState includes StatsState in its inheritance chain.""" + from forge.workflow.feature.state import FeatureState + from forge.workflow.stats import StatsState + + # TypedDict flattens to dict in __mro__; use __orig_bases__ instead. + assert StatsState in FeatureState.__orig_bases__ + + def test_feature_state_has_stats_fields(self): + """FeatureState type hints include all StatsState fields.""" + from typing import get_type_hints + + from forge.workflow.feature.state import FeatureState + + hints = get_type_hints(FeatureState) + + assert "stats_stages" in hints + assert "stats_pr_urls" in hints + assert "stats_ci_cycles" in hints + assert "stats_outcome" in hints + assert "stats_outcome_reason" in hints + assert "stats_comment_posted" in hints + + def test_create_initial_feature_state_stats_defaults(self): + """create_initial_feature_state() initialises all stats fields with correct defaults.""" + from forge.workflow.feature.state import create_initial_feature_state + + state = create_initial_feature_state("TEST-123") + + assert state["stats_stages"] == {} + assert state["stats_pr_urls"] == [] + assert state["stats_ci_cycles"] == 0 + assert state["stats_outcome"] is None + assert state["stats_outcome_reason"] is None + assert 
state["stats_comment_posted"] is False + + +class TestBugStateStatsIntegration: + """Tests for StatsState mixin integration in BugState.""" + + def test_bug_state_inherits_stats_state(self): + """BugState includes StatsState in its inheritance chain.""" + from forge.workflow.bug.state import BugState + from forge.workflow.stats import StatsState + + # TypedDict flattens to dict in __mro__; use __orig_bases__ instead. + assert StatsState in BugState.__orig_bases__ + + def test_bug_state_has_stats_fields(self): + """BugState type hints include all StatsState fields.""" + from typing import get_type_hints + + from forge.workflow.bug.state import BugState + + hints = get_type_hints(BugState) + + assert "stats_stages" in hints + assert "stats_pr_urls" in hints + assert "stats_ci_cycles" in hints + assert "stats_outcome" in hints + assert "stats_outcome_reason" in hints + assert "stats_comment_posted" in hints + + def test_create_initial_bug_state_stats_defaults(self): + """create_initial_bug_state() initialises all stats fields with correct defaults.""" + from forge.workflow.bug.state import create_initial_bug_state + + state = create_initial_bug_state("BUG-456") + + assert state["stats_stages"] == {} + assert state["stats_pr_urls"] == [] + assert state["stats_ci_cycles"] == 0 + assert state["stats_outcome"] is None + assert state["stats_outcome_reason"] is None + assert state["stats_comment_posted"] is False From 4922ace9591b9b511d2a3c6895dbfe5c5e5be9a3 Mon Sep 17 00:00:00 2001 From: Forge Date: Wed, 24 Jun 2026 07:05:39 +0000 Subject: [PATCH 03/68] [AISOS-1890] Implement core stats recording utility functions Detailed description: - Created src/forge/workflow/stats_utils.py with 7 public functions: - record_stage_start: initializes stage in stats_stages with UTC timestamp, zeroed metrics (iteration_count=0, machine/human time=0.0, tokens=0) - record_stage_end: sets ended_at and accumulates machine/human time metrics - record_tokens: accumulates (not replaces) 
input/output token counts per stage - increment_revision: increments iteration_count by 1 for a stage - increment_ci_cycle: increments workflow-level stats_ci_cycles counter - add_pr_url: appends URL to stats_pr_urls (idempotent, no duplicates) - set_outcome: sets stats_outcome and stats_outcome_reason fields - All functions return partial state dicts for LangGraph state merging - All functions handle missing/uninitialized stages gracefully via _get_stage helper - UTC timestamps use datetime.now(UTC).isoformat() format - Unused state param in set_outcome prefixed with _ per project conventions - Created tests/unit/workflow/test_stats_utils.py with 45 unit tests covering all functions including edge cases (non-existent stages, None values, accumulation, idempotency, re-entry behavior) Closes: AISOS-1890 --- src/forge/workflow/stats_utils.py | 185 +++++++++++++ tests/unit/workflow/test_stats_utils.py | 351 ++++++++++++++++++++++++ 2 files changed, 536 insertions(+) create mode 100644 src/forge/workflow/stats_utils.py create mode 100644 tests/unit/workflow/test_stats_utils.py diff --git a/src/forge/workflow/stats_utils.py b/src/forge/workflow/stats_utils.py new file mode 100644 index 00000000..8fed559b --- /dev/null +++ b/src/forge/workflow/stats_utils.py @@ -0,0 +1,185 @@ +"""Utility functions for recording workflow execution statistics. + +These helpers are called by workflow nodes to update stats fields in the +LangGraph state. Every function returns a dict suitable for merging into +the state via LangGraph's reducer (partial state updates). + +All timestamps are UTC ISO-8601 strings (e.g. "2024-01-01T12:00:00.000000+00:00"). 
+""" + +from datetime import UTC, datetime + + +def _utc_now() -> str: + """Return the current UTC time as an ISO-8601 string.""" + return datetime.now(UTC).isoformat() + + +def _get_stage(state: dict, stage_name: str) -> dict: + """Return a copy of the stage entry, or a zeroed default if absent.""" + stages: dict = state.get("stats_stages") or {} + existing = stages.get(stage_name) + if existing is None: + return { + "stage_name": stage_name, + "iteration_count": 0, + "machine_time_seconds": 0.0, + "human_time_seconds": 0.0, + "input_tokens": 0, + "output_tokens": 0, + "started_at": None, + "ended_at": None, + } + # Return a shallow copy so callers can mutate freely + return dict(existing) + + +def record_stage_start(state: dict, stage_name: str) -> dict: + """Initialize a stage entry in stats_stages with a started_at timestamp. + + If the stage already exists (e.g. a retry), the started_at timestamp is + updated to now but accumulated metrics are preserved. iteration_count is + left as-is; call :func:`increment_revision` to bump it. + + Args: + state: Current workflow state dict. + stage_name: Name of the stage being started (e.g. ``"implement"``). + + Returns: + Partial state update dict with ``stats_stages`` key. + """ + stages: dict = dict(state.get("stats_stages") or {}) + stage = _get_stage(state, stage_name) + stage["started_at"] = _utc_now() + stage["ended_at"] = None # reset end marker when re-entering + stages[stage_name] = stage + return {"stats_stages": stages} + + +def record_stage_end( + state: dict, + stage_name: str, + machine_time: float, + human_time: float = 0.0, +) -> dict: + """Mark a stage as ended and accumulate time metrics. + + Time values are *accumulated* (not replaced) so that repeated calls for + the same stage (e.g. after retries) add up correctly. + + Args: + state: Current workflow state dict. + stage_name: Name of the stage that has finished. + machine_time: Wall-clock seconds of automated work to add. 
+ human_time: Wall-clock seconds of human-wait time to add (default 0). + + Returns: + Partial state update dict with ``stats_stages`` key. + """ + stages: dict = dict(state.get("stats_stages") or {}) + stage = _get_stage(state, stage_name) + stage["ended_at"] = _utc_now() + stage["machine_time_seconds"] = stage.get("machine_time_seconds", 0.0) + machine_time + stage["human_time_seconds"] = stage.get("human_time_seconds", 0.0) + human_time + stages[stage_name] = stage + return {"stats_stages": stages} + + +def record_tokens( + state: dict, + stage_name: str, + input_tokens: int, + output_tokens: int, +) -> dict: + """Accumulate LLM token counts for a stage. + + Tokens are *accumulated* (not replaced) so that multiple LLM calls within + the same stage all contribute to the total. + + Args: + state: Current workflow state dict. + stage_name: Name of the stage consuming tokens. + input_tokens: Number of prompt tokens to add. + output_tokens: Number of completion tokens to add. + + Returns: + Partial state update dict with ``stats_stages`` key. + """ + stages: dict = dict(state.get("stats_stages") or {}) + stage = _get_stage(state, stage_name) + stage["input_tokens"] = stage.get("input_tokens", 0) + input_tokens + stage["output_tokens"] = stage.get("output_tokens", 0) + output_tokens + stages[stage_name] = stage + return {"stats_stages": stages} + + +def increment_revision(state: dict, stage_name: str) -> dict: + """Increment the iteration_count for a stage by 1. + + Should be called each time a stage is re-entered due to a revision + request or retry. + + Args: + state: Current workflow state dict. + stage_name: Name of the stage being revised. + + Returns: + Partial state update dict with ``stats_stages`` key. 
+ """ + stages: dict = dict(state.get("stats_stages") or {}) + stage = _get_stage(state, stage_name) + stage["iteration_count"] = stage.get("iteration_count", 0) + 1 + stages[stage_name] = stage + return {"stats_stages": stages} + + +def increment_ci_cycle(state: dict) -> dict: + """Increment the workflow-level CI fix-attempt cycle counter by 1. + + Args: + state: Current workflow state dict. + + Returns: + Partial state update dict with ``stats_ci_cycles`` key. + """ + current: int = state.get("stats_ci_cycles") or 0 + return {"stats_ci_cycles": current + 1} + + +def add_pr_url(state: dict, pr_url: str) -> dict: + """Append a PR URL to stats_pr_urls (idempotent — no duplicates). + + Args: + state: Current workflow state dict. + pr_url: The pull-request URL to record. + + Returns: + Partial state update dict with ``stats_pr_urls`` key. + """ + existing: list[str] = list(state.get("stats_pr_urls") or []) + if pr_url not in existing: + existing.append(pr_url) + return {"stats_pr_urls": existing} + + +def set_outcome(_state: dict, outcome: str, reason: str | None = None) -> dict: + """Set the workflow outcome and optional reason. + + Conventional outcome values: + - ``"Completed"`` — finished successfully. + - ``"Blocked: <reason>"`` — waiting on an external blocker. + - ``"Failed: <reason>"`` — terminated due to an unrecoverable error. + + Args: + _state: Current workflow state dict (unused — outcome is set unconditionally). + outcome: Outcome string to record. + reason: Optional human-readable elaboration (e.g. blocking reason). + + Returns: + Partial state update dict with ``stats_outcome`` and + ``stats_outcome_reason`` keys. 
+ """ + return { + "stats_outcome": outcome, + "stats_outcome_reason": reason, + } diff --git a/tests/unit/workflow/test_stats_utils.py b/tests/unit/workflow/test_stats_utils.py new file mode 100644 index 00000000..3a0ac578 --- /dev/null +++ b/tests/unit/workflow/test_stats_utils.py @@ -0,0 +1,351 @@ +"""Unit tests for forge.workflow.stats_utils.""" + +import pytest + +from forge.workflow.stats_utils import ( + add_pr_url, + increment_ci_cycle, + increment_revision, + record_stage_end, + record_stage_start, + record_tokens, + set_outcome, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _empty_state() -> dict: + """Return a minimal state with stats fields unset (simulates fresh run).""" + return {} + + +def _state_with_stage(stage_name: str, **overrides) -> dict: + """Return a state that already has one stage entry.""" + stage = { + "stage_name": stage_name, + "iteration_count": 0, + "machine_time_seconds": 0.0, + "human_time_seconds": 0.0, + "input_tokens": 0, + "output_tokens": 0, + "started_at": "2024-01-01T00:00:00+00:00", + "ended_at": None, + } + stage.update(overrides) + return {"stats_stages": {stage_name: stage}} + + +# --------------------------------------------------------------------------- +# record_stage_start +# --------------------------------------------------------------------------- + + +class TestRecordStageStart: + def test_initialises_stage_with_timestamp(self): + result = record_stage_start(_empty_state(), "implement") + + assert "stats_stages" in result + stage = result["stats_stages"]["implement"] + assert stage["started_at"] is not None + assert "T" in stage["started_at"] # ISO-8601 + + def test_zeroed_numeric_metrics(self): + result = record_stage_start(_empty_state(), "implement") + stage = result["stats_stages"]["implement"] + + assert stage["iteration_count"] == 0 + assert stage["machine_time_seconds"] 
== 0.0 + assert stage["human_time_seconds"] == 0.0 + assert stage["input_tokens"] == 0 + assert stage["output_tokens"] == 0 + + def test_ended_at_is_none_on_init(self): + result = record_stage_start(_empty_state(), "implement") + assert result["stats_stages"]["implement"]["ended_at"] is None + + def test_stage_name_recorded(self): + result = record_stage_start(_empty_state(), "triage") + assert result["stats_stages"]["triage"]["stage_name"] == "triage" + + def test_resets_ended_at_on_re_entry(self): + """Re-entering a stage clears ended_at (marks it in-progress again).""" + state = _state_with_stage("implement", ended_at="2024-01-01T01:00:00+00:00") + result = record_stage_start(state, "implement") + assert result["stats_stages"]["implement"]["ended_at"] is None + + def test_preserves_accumulated_metrics_on_re_entry(self): + """Re-entering should not zero out previously accumulated tokens.""" + state = _state_with_stage( + "implement", + input_tokens=500, + output_tokens=250, + machine_time_seconds=30.0, + ) + result = record_stage_start(state, "implement") + stage = result["stats_stages"]["implement"] + + assert stage["input_tokens"] == 500 + assert stage["output_tokens"] == 250 + assert stage["machine_time_seconds"] == 30.0 + + def test_handles_missing_stats_stages_key(self): + """Works when state has no stats_stages key at all.""" + result = record_stage_start({}, "plan") + assert "plan" in result["stats_stages"] + + def test_does_not_mutate_existing_stages(self): + """Other stages in stats_stages are preserved.""" + state = _state_with_stage("triage") + result = record_stage_start(state, "implement") + + assert "triage" in result["stats_stages"] + assert "implement" in result["stats_stages"] + + def test_returns_only_stats_stages_key(self): + result = record_stage_start(_empty_state(), "implement") + assert list(result.keys()) == ["stats_stages"] + + +# --------------------------------------------------------------------------- +# record_stage_end +# 
--------------------------------------------------------------------------- + + +class TestRecordStageEnd: + def test_sets_ended_at_timestamp(self): + state = _state_with_stage("implement") + result = record_stage_end(state, "implement", machine_time=60.0) + + assert result["stats_stages"]["implement"]["ended_at"] is not None + + def test_accumulates_machine_time(self): + state = _state_with_stage("implement", machine_time_seconds=10.0) + result = record_stage_end(state, "implement", machine_time=25.5) + + assert result["stats_stages"]["implement"]["machine_time_seconds"] == pytest.approx(35.5) + + def test_accumulates_human_time(self): + state = _state_with_stage("implement", human_time_seconds=100.0) + result = record_stage_end(state, "implement", machine_time=0.0, human_time=50.0) + + assert result["stats_stages"]["implement"]["human_time_seconds"] == pytest.approx(150.0) + + def test_human_time_defaults_to_zero(self): + state = _state_with_stage("implement") + result = record_stage_end(state, "implement", machine_time=10.0) + + assert result["stats_stages"]["implement"]["human_time_seconds"] == pytest.approx(0.0) + + def test_handles_non_existent_stage(self): + """Calling on a stage that was never started should not raise.""" + result = record_stage_end(_empty_state(), "ghost_stage", machine_time=5.0) + + stage = result["stats_stages"]["ghost_stage"] + assert stage["machine_time_seconds"] == pytest.approx(5.0) + assert stage["ended_at"] is not None + + def test_returns_only_stats_stages_key(self): + state = _state_with_stage("implement") + result = record_stage_end(state, "implement", machine_time=1.0) + assert list(result.keys()) == ["stats_stages"] + + +# --------------------------------------------------------------------------- +# record_tokens +# --------------------------------------------------------------------------- + + +class TestRecordTokens: + def test_accumulates_input_tokens(self): + state = _state_with_stage("implement", input_tokens=100) + 
result = record_tokens(state, "implement", input_tokens=200, output_tokens=0) + + assert result["stats_stages"]["implement"]["input_tokens"] == 300 + + def test_accumulates_output_tokens(self): + state = _state_with_stage("implement", output_tokens=50) + result = record_tokens(state, "implement", input_tokens=0, output_tokens=75) + + assert result["stats_stages"]["implement"]["output_tokens"] == 125 + + def test_accumulates_both_simultaneously(self): + state = _state_with_stage("implement", input_tokens=10, output_tokens=5) + result = record_tokens(state, "implement", input_tokens=20, output_tokens=10) + + stage = result["stats_stages"]["implement"] + assert stage["input_tokens"] == 30 + assert stage["output_tokens"] == 15 + + def test_handles_non_existent_stage(self): + """Should initialise a new stage entry if it does not exist.""" + result = record_tokens(_empty_state(), "new_stage", input_tokens=50, output_tokens=25) + + stage = result["stats_stages"]["new_stage"] + assert stage["input_tokens"] == 50 + assert stage["output_tokens"] == 25 + + def test_does_not_replace_tokens(self): + """Calling twice should add, not replace.""" + state = _state_with_stage("implement") + first = record_tokens(state, "implement", input_tokens=100, output_tokens=50) + second = record_tokens(first, "implement", input_tokens=100, output_tokens=50) + + assert second["stats_stages"]["implement"]["input_tokens"] == 200 + assert second["stats_stages"]["implement"]["output_tokens"] == 100 + + def test_returns_only_stats_stages_key(self): + result = record_tokens(_empty_state(), "impl", input_tokens=1, output_tokens=1) + assert list(result.keys()) == ["stats_stages"] + + +# --------------------------------------------------------------------------- +# increment_revision +# --------------------------------------------------------------------------- + + +class TestIncrementRevision: + def test_increments_iteration_count_by_one(self): + state = _state_with_stage("implement", 
iteration_count=2) + result = increment_revision(state, "implement") + + assert result["stats_stages"]["implement"]["iteration_count"] == 3 + + def test_starts_at_one_for_new_stage(self): + result = increment_revision(_empty_state(), "plan") + + assert result["stats_stages"]["plan"]["iteration_count"] == 1 + + def test_multiple_increments_accumulate(self): + state = _empty_state() + for _ in range(5): + state = {**state, **increment_revision(state, "implement")} + + assert state["stats_stages"]["implement"]["iteration_count"] == 5 + + def test_returns_only_stats_stages_key(self): + result = increment_revision(_empty_state(), "triage") + assert list(result.keys()) == ["stats_stages"] + + +# --------------------------------------------------------------------------- +# increment_ci_cycle +# --------------------------------------------------------------------------- + + +class TestIncrementCiCycle: + def test_increments_counter_from_zero(self): + result = increment_ci_cycle(_empty_state()) + assert result["stats_ci_cycles"] == 1 + + def test_increments_existing_counter(self): + state = {"stats_ci_cycles": 3} + result = increment_ci_cycle(state) + assert result["stats_ci_cycles"] == 4 + + def test_handles_none_counter(self): + state = {"stats_ci_cycles": None} + result = increment_ci_cycle(state) + assert result["stats_ci_cycles"] == 1 + + def test_multiple_increments(self): + state = _empty_state() + for _ in range(7): + state = {**state, **increment_ci_cycle(state)} + + assert state["stats_ci_cycles"] == 7 + + def test_returns_only_stats_ci_cycles_key(self): + result = increment_ci_cycle(_empty_state()) + assert list(result.keys()) == ["stats_ci_cycles"] + + +# --------------------------------------------------------------------------- +# add_pr_url +# --------------------------------------------------------------------------- + + +class TestAddPrUrl: + def test_appends_url_to_empty_list(self): + result = add_pr_url(_empty_state(), 
"https://github.com/org/repo/pull/1") + assert result["stats_pr_urls"] == ["https://github.com/org/repo/pull/1"] + + def test_appends_to_existing_list(self): + state = {"stats_pr_urls": ["https://github.com/org/repo/pull/1"]} + result = add_pr_url(state, "https://github.com/org/repo/pull/2") + + assert result["stats_pr_urls"] == [ + "https://github.com/org/repo/pull/1", + "https://github.com/org/repo/pull/2", + ] + + def test_idempotent_no_duplicates(self): + url = "https://github.com/org/repo/pull/1" + state = {"stats_pr_urls": [url]} + result = add_pr_url(state, url) + + assert result["stats_pr_urls"] == [url] + assert len(result["stats_pr_urls"]) == 1 + + def test_calling_twice_does_not_duplicate(self): + url = "https://github.com/org/repo/pull/42" + state = _empty_state() + state = {**state, **add_pr_url(state, url)} + state = {**state, **add_pr_url(state, url)} + + assert state["stats_pr_urls"].count(url) == 1 + + def test_handles_none_pr_urls(self): + state = {"stats_pr_urls": None} + result = add_pr_url(state, "https://example.com/pr/1") + assert result["stats_pr_urls"] == ["https://example.com/pr/1"] + + def test_returns_only_stats_pr_urls_key(self): + result = add_pr_url(_empty_state(), "https://example.com/pr/1") + assert list(result.keys()) == ["stats_pr_urls"] + + def test_preserves_order(self): + urls = [f"https://example.com/pr/{i}" for i in range(5)] + state = _empty_state() + for url in urls: + state = {**state, **add_pr_url(state, url)} + + assert state["stats_pr_urls"] == urls + + +# --------------------------------------------------------------------------- +# set_outcome +# --------------------------------------------------------------------------- + + +class TestSetOutcome: + def test_sets_outcome(self): + result = set_outcome(_empty_state(), "Completed") + assert result["stats_outcome"] == "Completed" + + def test_sets_reason_when_provided(self): + result = set_outcome(_empty_state(), "Blocked: awaiting review", "PR still open") + assert 
result["stats_outcome"] == "Blocked: awaiting review" + assert result["stats_outcome_reason"] == "PR still open" + + def test_reason_defaults_to_none(self): + result = set_outcome(_empty_state(), "Completed") + assert result["stats_outcome_reason"] is None + + def test_overwrites_previous_outcome(self): + state = {"stats_outcome": "Blocked", "stats_outcome_reason": "old reason"} + result = set_outcome(state, "Completed", None) + + assert result["stats_outcome"] == "Completed" + assert result["stats_outcome_reason"] is None + + def test_returns_both_keys(self): + result = set_outcome(_empty_state(), "Failed: timeout") + assert set(result.keys()) == {"stats_outcome", "stats_outcome_reason"} + + @pytest.mark.parametrize("outcome", ["Completed", "Blocked: foo", "Failed: bar"]) + def test_conventional_outcome_values(self, outcome: str): + result = set_outcome(_empty_state(), outcome) + assert result["stats_outcome"] == outcome From 83f2b8193a89809485b2811a51a86b3033dff4aa Mon Sep 17 00:00:00 2001 From: Forge Date: Wed, 24 Jun 2026 07:08:23 +0000 Subject: [PATCH 04/68] [AISOS-1891] Add cost alert threshold configuration Detailed description: - Added stats_cost_alert_enabled (bool, default: True) to Settings in src/forge/config.py - Added stats_cost_alert_threshold_tokens (int, default: 1_000_000) to Settings in src/forge/config.py - Both fields include Field descriptions documenting their purpose and behavior - Updated .env.example with a new Stats Cost Alert Configuration section documenting both settings - Added tests/unit/test_config_cost_alert.py with 7 unit tests covering defaults, type checking, and customization Closes: AISOS-1891 --- .env.example | 11 +++++++ src/forge/config.py | 18 ++++++++++++ tests/unit/test_config_cost_alert.py | 44 ++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+) create mode 100644 tests/unit/test_config_cost_alert.py diff --git a/.env.example b/.env.example index 0782e2a4..b8ea84c3 100644 --- a/.env.example +++ b/.env.example 
@@ -245,3 +245,14 @@ CI_FIX_MAX_RETRIES=5 CI_IGNORED_CHECKS=tide # Webhook acknowledgment timeout in seconds WEBHOOK_ACK_TIMEOUT=0.5 + +# ============================================================================= +# Stats Cost Alert Configuration +# ============================================================================= +# Enable cost alerting in workflow stats summaries. When enabled and aggregate +# token usage (input + output across all stages) exceeds the threshold, the +# stats summary will include a cost alert. +STATS_COST_ALERT_ENABLED=true +# Total token count threshold that triggers a cost alert (default: 1,000,000). +# Applies to aggregate token usage across all workflow stages. +STATS_COST_ALERT_THRESHOLD_TOKENS=1000000 diff --git a/src/forge/config.py b/src/forge/config.py index bcb2a93f..c50fbfc9 100644 --- a/src/forge/config.py +++ b/src/forge/config.py @@ -342,6 +342,24 @@ def ignored_ci_checks(self) -> list[str]: description="Enable Prometheus metrics endpoint in worker", ) + # Stats Cost Alert Configuration + stats_cost_alert_enabled: bool = Field( + default=True, + description=( + "Enable cost alerting in workflow stats summaries. " + "When enabled and aggregate token usage exceeds stats_cost_alert_threshold_tokens, " + "the stats summary will include a cost alert." + ), + ) + stats_cost_alert_threshold_tokens: int = Field( + default=1_000_000, + description=( + "Total token count threshold (input + output across all stages) that triggers " + "a cost alert in the workflow stats summary. Only active when " + "stats_cost_alert_enabled is True. Default: 1,000,000 tokens." 
+ ), + ) + # OpenTelemetry Configuration otlp_endpoint: str = Field( default="", diff --git a/tests/unit/test_config_cost_alert.py b/tests/unit/test_config_cost_alert.py new file mode 100644 index 00000000..75442edd --- /dev/null +++ b/tests/unit/test_config_cost_alert.py @@ -0,0 +1,44 @@ +"""Tests for stats cost alert threshold configuration settings.""" + +import pytest + +from forge.config import Settings + + +REQUIRED_SETTINGS = dict( + jira_base_url="https://test.atlassian.net", + jira_api_token="test", + jira_user_email="test@example.com", + github_token="test", + anthropic_api_key="test", +) + + +class TestStatsCostAlertConfig: + def test_default_cost_alert_enabled_is_true(self): + settings = Settings(**REQUIRED_SETTINGS) + assert settings.stats_cost_alert_enabled is True + + def test_default_cost_alert_threshold_tokens(self): + settings = Settings(**REQUIRED_SETTINGS) + assert settings.stats_cost_alert_threshold_tokens == 1_000_000 + + def test_cost_alert_enabled_can_be_disabled(self): + settings = Settings(**REQUIRED_SETTINGS, stats_cost_alert_enabled=False) + assert settings.stats_cost_alert_enabled is False + + def test_cost_alert_threshold_can_be_customized(self): + settings = Settings(**REQUIRED_SETTINGS, stats_cost_alert_threshold_tokens=500_000) + assert settings.stats_cost_alert_threshold_tokens == 500_000 + + def test_cost_alert_threshold_accepts_large_values(self): + settings = Settings(**REQUIRED_SETTINGS, stats_cost_alert_threshold_tokens=10_000_000) + assert settings.stats_cost_alert_threshold_tokens == 10_000_000 + + def test_cost_alert_threshold_is_int(self): + settings = Settings(**REQUIRED_SETTINGS) + assert isinstance(settings.stats_cost_alert_threshold_tokens, int) + + def test_cost_alert_enabled_is_bool(self): + settings = Settings(**REQUIRED_SETTINGS) + assert isinstance(settings.stats_cost_alert_enabled, bool) From becb2dffc73a37ef19eedc5902d5cfbf02616cc0 Mon Sep 17 00:00:00 2001 From: Forge Date: Wed, 24 Jun 2026 07:10:45 +0000 
Subject: [PATCH 05/68] [AISOS-1892] Define workflow stage constants for stats tracking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Detailed description: - Added 10 stage string constants to src/forge/workflow/stats.py: STAGE_PRD, STAGE_SPEC, STAGE_EPICS, STAGE_TASKS, STAGE_IMPLEMENTATION, STAGE_CI, STAGE_REVIEW (Feature workflow) and STAGE_TRIAGE, STAGE_RCA, STAGE_PLANNING (Bug workflow) - Added ALL_FEATURE_STAGES list (PRD → spec → epics → tasks → implementation → CI → review) - Added ALL_BUG_STAGES list (triage → rca → planning → implementation → CI → review) - Added TestStageConstants class to tests/unit/workflow/test_stats.py with 19 new tests covering individual constant values, list types, lengths, ordering, completeness, and import path All 37 tests in test_stats.py pass. Closes: AISOS-1892 --- src/forge/workflow/stats.py | 49 +++++++- tests/unit/workflow/test_stats.py | 190 +++++++++++++++++++++++++++++- 2 files changed, 237 insertions(+), 2 deletions(-) diff --git a/src/forge/workflow/stats.py b/src/forge/workflow/stats.py index 15ac3d15..b72e348c 100644 --- a/src/forge/workflow/stats.py +++ b/src/forge/workflow/stats.py @@ -1,11 +1,58 @@ """Statistics tracking data structures for workflow execution. This module defines the TypedDicts used to capture per-stage metrics and -overall workflow outcome data, as required by SC-001. +overall workflow outcome data, as required by SC-001. It also exports +canonical stage-name constants used by recording and formatting code to +ensure consistency across the codebase. """ from typing import TypedDict +# --------------------------------------------------------------------------- +# Workflow stage constants +# --------------------------------------------------------------------------- +# These string constants are the canonical identifiers for each named stage +# that is tracked in workflow statistics. 
Use these constants everywhere +# instead of bare strings so that typos are caught at import time. + +# Feature workflow stages +STAGE_PRD = "prd" +STAGE_SPEC = "spec" +STAGE_EPICS = "epics" +STAGE_TASKS = "tasks" +STAGE_IMPLEMENTATION = "implementation" +STAGE_CI = "ci" +STAGE_REVIEW = "review" + +# Bug workflow stages +STAGE_TRIAGE = "triage" +STAGE_RCA = "rca" +STAGE_PLANNING = "planning" + +# Ordered stage lists used by formatting code to display stages in the +# canonical sequence defined by the specification. + +#: Stages for the Feature workflow, in display order. +ALL_FEATURE_STAGES: list[str] = [ + STAGE_PRD, + STAGE_SPEC, + STAGE_EPICS, + STAGE_TASKS, + STAGE_IMPLEMENTATION, + STAGE_CI, + STAGE_REVIEW, +] + +#: Stages for the Bug workflow, in display order. +ALL_BUG_STAGES: list[str] = [ + STAGE_TRIAGE, + STAGE_RCA, + STAGE_PLANNING, + STAGE_IMPLEMENTATION, + STAGE_CI, + STAGE_REVIEW, +] + class StageStats(TypedDict, total=False): """Per-stage execution metrics captured during workflow execution. 
diff --git a/tests/unit/workflow/test_stats.py b/tests/unit/workflow/test_stats.py index 7f79f4fe..260e04f4 100644 --- a/tests/unit/workflow/test_stats.py +++ b/tests/unit/workflow/test_stats.py @@ -1,4 +1,4 @@ -"""Unit tests for StageStats and StatsState TypedDicts.""" +"""Unit tests for StageStats, StatsState TypedDicts, and stage constants.""" from typing import get_type_hints @@ -215,3 +215,191 @@ def test_stats_state_importable_from_workflow(self): def test_stats_state_importable_from_base(self): """StatsState is importable via forge.workflow.base (re-exported).""" from forge.workflow.base import StatsState # noqa: F401 + + +class TestStageConstants: + """Tests for workflow stage name constants and ordered stage lists.""" + + # ------------------------------------------------------------------ + # Individual constant values + # ------------------------------------------------------------------ + + def test_stage_prd_value(self): + from forge.workflow.stats import STAGE_PRD + + assert STAGE_PRD == "prd" + + def test_stage_spec_value(self): + from forge.workflow.stats import STAGE_SPEC + + assert STAGE_SPEC == "spec" + + def test_stage_epics_value(self): + from forge.workflow.stats import STAGE_EPICS + + assert STAGE_EPICS == "epics" + + def test_stage_tasks_value(self): + from forge.workflow.stats import STAGE_TASKS + + assert STAGE_TASKS == "tasks" + + def test_stage_implementation_value(self): + from forge.workflow.stats import STAGE_IMPLEMENTATION + + assert STAGE_IMPLEMENTATION == "implementation" + + def test_stage_ci_value(self): + from forge.workflow.stats import STAGE_CI + + assert STAGE_CI == "ci" + + def test_stage_review_value(self): + from forge.workflow.stats import STAGE_REVIEW + + assert STAGE_REVIEW == "review" + + def test_stage_rca_value(self): + from forge.workflow.stats import STAGE_RCA + + assert STAGE_RCA == "rca" + + def test_stage_triage_value(self): + from forge.workflow.stats import STAGE_TRIAGE + + assert STAGE_TRIAGE == "triage" + + 
def test_stage_planning_value(self): + from forge.workflow.stats import STAGE_PLANNING + + assert STAGE_PLANNING == "planning" + + # ------------------------------------------------------------------ + # ALL_FEATURE_STAGES list + # ------------------------------------------------------------------ + + def test_all_feature_stages_is_list(self): + """ALL_FEATURE_STAGES is a list of strings.""" + from forge.workflow.stats import ALL_FEATURE_STAGES + + assert isinstance(ALL_FEATURE_STAGES, list) + assert all(isinstance(s, str) for s in ALL_FEATURE_STAGES) + + def test_all_feature_stages_length(self): + """ALL_FEATURE_STAGES contains exactly 7 stages.""" + from forge.workflow.stats import ALL_FEATURE_STAGES + + assert len(ALL_FEATURE_STAGES) == 7 + + def test_all_feature_stages_order(self): + """ALL_FEATURE_STAGES lists stages in the canonical display order.""" + from forge.workflow.stats import ( + ALL_FEATURE_STAGES, + STAGE_CI, + STAGE_EPICS, + STAGE_IMPLEMENTATION, + STAGE_PRD, + STAGE_REVIEW, + STAGE_SPEC, + STAGE_TASKS, + ) + + assert ALL_FEATURE_STAGES == [ + STAGE_PRD, + STAGE_SPEC, + STAGE_EPICS, + STAGE_TASKS, + STAGE_IMPLEMENTATION, + STAGE_CI, + STAGE_REVIEW, + ] + + def test_all_feature_stages_completeness(self): + """ALL_FEATURE_STAGES contains every expected Feature stage.""" + from forge.workflow.stats import ( + ALL_FEATURE_STAGES, + STAGE_CI, + STAGE_EPICS, + STAGE_IMPLEMENTATION, + STAGE_PRD, + STAGE_REVIEW, + STAGE_SPEC, + STAGE_TASKS, + ) + + expected = {STAGE_PRD, STAGE_SPEC, STAGE_EPICS, STAGE_TASKS, STAGE_IMPLEMENTATION, STAGE_CI, STAGE_REVIEW} + assert set(ALL_FEATURE_STAGES) == expected + + # ------------------------------------------------------------------ + # ALL_BUG_STAGES list + # ------------------------------------------------------------------ + + def test_all_bug_stages_is_list(self): + """ALL_BUG_STAGES is a list of strings.""" + from forge.workflow.stats import ALL_BUG_STAGES + + assert isinstance(ALL_BUG_STAGES, list) + assert 
all(isinstance(s, str) for s in ALL_BUG_STAGES) + + def test_all_bug_stages_length(self): + """ALL_BUG_STAGES contains exactly 6 stages.""" + from forge.workflow.stats import ALL_BUG_STAGES + + assert len(ALL_BUG_STAGES) == 6 + + def test_all_bug_stages_order(self): + """ALL_BUG_STAGES lists stages in the canonical display order.""" + from forge.workflow.stats import ( + ALL_BUG_STAGES, + STAGE_CI, + STAGE_IMPLEMENTATION, + STAGE_PLANNING, + STAGE_RCA, + STAGE_REVIEW, + STAGE_TRIAGE, + ) + + assert ALL_BUG_STAGES == [ + STAGE_TRIAGE, + STAGE_RCA, + STAGE_PLANNING, + STAGE_IMPLEMENTATION, + STAGE_CI, + STAGE_REVIEW, + ] + + def test_all_bug_stages_completeness(self): + """ALL_BUG_STAGES contains every expected Bug stage.""" + from forge.workflow.stats import ( + ALL_BUG_STAGES, + STAGE_CI, + STAGE_IMPLEMENTATION, + STAGE_PLANNING, + STAGE_RCA, + STAGE_REVIEW, + STAGE_TRIAGE, + ) + + expected = {STAGE_TRIAGE, STAGE_RCA, STAGE_PLANNING, STAGE_IMPLEMENTATION, STAGE_CI, STAGE_REVIEW} + assert set(ALL_BUG_STAGES) == expected + + # ------------------------------------------------------------------ + # Export verification + # ------------------------------------------------------------------ + + def test_constants_importable_from_stats_module(self): + """All stage constants and lists are importable from forge.workflow.stats.""" + from forge.workflow.stats import ( # noqa: F401 + ALL_BUG_STAGES, + ALL_FEATURE_STAGES, + STAGE_CI, + STAGE_EPICS, + STAGE_IMPLEMENTATION, + STAGE_PLANNING, + STAGE_PRD, + STAGE_RCA, + STAGE_REVIEW, + STAGE_SPEC, + STAGE_TASKS, + STAGE_TRIAGE, + ) From ab3e0127516252c1a86b21ce5e3e0d418f408825 Mon Sep 17 00:00:00 2001 From: Forge Date: Wed, 24 Jun 2026 07:31:14 +0000 Subject: [PATCH 06/68] [AISOS-1893] Integrate stats recording into PRD and Spec generation nodes Detailed description: - prd_generation.py: Added record_stage_start at entry, record_tokens after LLM call (estimated from content length), increment_revision when regenerating from 
feedback, and record_stage_end with wall-clock machine time at all exit paths (success, early-return, exception) - spec_generation.py: Same instrumentation pattern using STAGE_SPEC - Both nodes use _estimate_tokens() helper (~4 chars/token) since the ForgeAgent interface returns plain strings without token metadata - Added tests/unit/workflow/nodes/test_prd_spec_stats.py with 26 tests covering all acceptance criteria (stage_start, tokens, revision increment, stage_end) for both generate and regenerate functions Closes: AISOS-1893 --- src/forge/workflow/nodes/prd_generation.py | 51 ++ src/forge/workflow/nodes/spec_generation.py | 51 ++ .../workflow/nodes/test_prd_spec_stats.py | 738 ++++++++++++++++++ 3 files changed, 840 insertions(+) create mode 100644 tests/unit/workflow/nodes/test_prd_spec_stats.py diff --git a/src/forge/workflow/nodes/prd_generation.py b/src/forge/workflow/nodes/prd_generation.py index 2b4a0529..5d0f2fc3 100644 --- a/src/forge/workflow/nodes/prd_generation.py +++ b/src/forge/workflow/nodes/prd_generation.py @@ -2,6 +2,7 @@ import logging import re +import time from datetime import UTC, datetime from typing import Any @@ -12,12 +13,24 @@ from forge.models.workflow import ForgeLabel from forge.orchestrator.checkpointer import set_pr_ticket_index from forge.workflow.feature.state import FeatureState as WorkflowState +from forge.workflow.stats import STAGE_PRD +from forge.workflow.stats_utils import ( + increment_revision, + record_stage_end, + record_stage_start, + record_tokens, +) from forge.workflow.utils import update_state_timestamp from forge.workflow.utils.jira_status import post_status_comment logger = logging.getLogger(__name__) +def _estimate_tokens(text: str) -> int: + """Estimate token count from text length (approx. 
4 chars per token).""" + return max(1, len(text) // 4) + + def _slugify(text: str, max_length: int = 60) -> str: """Convert text to URL-safe slug.""" slug = text.lower().strip() @@ -167,6 +180,10 @@ async def generate_prd(state: WorkflowState) -> WorkflowState: ticket_key = state["ticket_key"] logger.info(f"Generating PRD for {ticket_key}") + # Record stage start and begin timing + state = {**state, **record_stage_start(state, STAGE_PRD)} + node_start = time.monotonic() + jira = JiraClient() agent = ForgeAgent() prd_content = None @@ -185,8 +202,11 @@ async def generate_prd(state: WorkflowState) -> WorkflowState: if not raw_requirements.strip(): logger.warning(f"No description found for {ticket_key}") + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_PRD, machine_time) return { **state, + **end_stats, "last_error": "No requirements found in issue description", "current_node": "generate_prd", } @@ -206,6 +226,11 @@ async def generate_prd(state: WorkflowState) -> WorkflowState: # Generate PRD using Claude - primary operation prd_content = await agent.generate_prd(raw_requirements, context) + # Record token usage (estimated from content length) + input_tokens = _estimate_tokens(raw_requirements) + output_tokens = _estimate_tokens(prd_content) + state = {**state, **record_tokens(state, STAGE_PRD, input_tokens, output_tokens)} + # Publish PRD - either as GitHub PR or Jira update # Per-project opt-in: check forge.prd_proposals_repo project property proposals_repo = await _resolve_prd_proposals_repo(issue.project_key, jira) @@ -244,10 +269,15 @@ async def generate_prd(state: WorkflowState) -> WorkflowState: "generated_at": datetime.now(UTC).isoformat(), } + # Record stage end with elapsed wall-clock time + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_PRD, machine_time) + # If publish failed, set a warning but still advance (content exists) result = update_state_timestamp( { **state, + 
**end_stats, "prd_content": prd_content, "generation_context": generation_context, "current_node": "prd_approval_gate", @@ -264,8 +294,11 @@ async def generate_prd(state: WorkflowState) -> WorkflowState: await notify_error(state, str(e), "generate_prd") # If we have partial content, save it even on failure + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_PRD, machine_time) result_state = { **state, + **end_stats, "last_error": str(e), "current_node": "generate_prd", "retry_count": state.get("retry_count", 0) + 1, @@ -301,6 +334,11 @@ async def regenerate_prd_with_feedback(state: WorkflowState) -> WorkflowState: logger.info(f"Regenerating PRD for {ticket_key} with feedback") + # Record stage re-entry: start timer, increment revision count + state = {**state, **record_stage_start(state, STAGE_PRD)} + state = {**state, **increment_revision(state, STAGE_PRD)} + node_start = time.monotonic() + jira = JiraClient() agent = ForgeAgent() @@ -320,6 +358,11 @@ async def regenerate_prd_with_feedback(state: WorkflowState) -> WorkflowState: }, ) + # Record token usage (estimated from content length) + input_tokens = _estimate_tokens(original_prd) + _estimate_tokens(feedback) + output_tokens = _estimate_tokens(new_prd) + state = {**state, **record_tokens(state, STAGE_PRD, input_tokens, output_tokens)} + # Publish revised PRD if state.get("prd_pr_number"): await _update_prd_proposal_pr(ticket_key, new_prd, state) @@ -341,9 +384,14 @@ async def regenerate_prd_with_feedback(state: WorkflowState) -> WorkflowState: logger.info(f"PRD regenerated for {ticket_key} ({len(new_prd)} chars)") + # Record stage end with elapsed wall-clock time + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_PRD, machine_time) + return update_state_timestamp( { **state, + **end_stats, "prd_content": new_prd, "feedback_comment": None, "revision_requested": False, @@ -357,8 +405,11 @@ async def regenerate_prd_with_feedback(state: 
WorkflowState) -> WorkflowState: from forge.workflow.nodes.error_handler import notify_error await notify_error(state, str(e), "regenerate_prd") + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_PRD, machine_time) return { **state, + **end_stats, "last_error": str(e), "current_node": "regenerate_prd", "retry_count": state.get("retry_count", 0) + 1, diff --git a/src/forge/workflow/nodes/spec_generation.py b/src/forge/workflow/nodes/spec_generation.py index 40b14583..cd070540 100644 --- a/src/forge/workflow/nodes/spec_generation.py +++ b/src/forge/workflow/nodes/spec_generation.py @@ -1,6 +1,7 @@ """Specification generation node for LangGraph workflow.""" import logging +import time from datetime import UTC, datetime from typing import Any @@ -9,6 +10,13 @@ from forge.integrations.jira.client import JiraClient from forge.models.workflow import ForgeLabel from forge.workflow.feature.state import FeatureState as WorkflowState +from forge.workflow.stats import STAGE_SPEC +from forge.workflow.stats_utils import ( + increment_revision, + record_stage_end, + record_stage_start, + record_tokens, +) from forge.workflow.utils import update_state_timestamp from forge.workflow.utils.jira_status import post_status_comment from forge.workflow.utils.qa_summary import post_qa_summary_if_needed @@ -16,6 +24,11 @@ logger = logging.getLogger(__name__) +def _estimate_tokens(text: str) -> int: + """Estimate token count from text length (approx. 4 chars per token).""" + return max(1, len(text) // 4) + + async def generate_spec(state: WorkflowState) -> WorkflowState: """Generate a behavioral specification from the approved PRD. 
@@ -36,6 +49,10 @@ async def generate_spec(state: WorkflowState) -> WorkflowState: logger.info(f"Generating specification for {ticket_key}") + # Record stage start and begin timing + state = {**state, **record_stage_start(state, STAGE_SPEC)} + node_start = time.monotonic() + # Post Q&A summary for PRD if any qa_history = state.get("qa_history", []) if qa_history: @@ -60,8 +77,11 @@ async def generate_spec(state: WorkflowState) -> WorkflowState: if not prd_content.strip(): logger.warning(f"No PRD content found for {ticket_key}") + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_SPEC, machine_time) return { **state, + **end_stats, "last_error": "No PRD content available for spec generation", "current_node": "generate_spec", } @@ -79,6 +99,11 @@ async def generate_spec(state: WorkflowState) -> WorkflowState: # Generate specification using Claude - primary operation spec_content = await agent.generate_spec(prd_content, context) + # Record token usage (estimated from content length) + input_tokens = _estimate_tokens(prd_content) + output_tokens = _estimate_tokens(spec_content) + state = {**state, **record_tokens(state, STAGE_SPEC, input_tokens, output_tokens)} + # Store spec in Jira - secondary operation try: settings = get_settings() @@ -120,9 +145,14 @@ async def generate_spec(state: WorkflowState) -> WorkflowState: "generated_at": datetime.now(UTC).isoformat(), } + # Record stage end with elapsed wall-clock time + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_SPEC, machine_time) + return update_state_timestamp( { **state, + **end_stats, "spec_content": spec_content, "generation_context": generation_context, "current_node": "spec_approval_gate", @@ -136,8 +166,11 @@ async def generate_spec(state: WorkflowState) -> WorkflowState: await notify_error(state, str(e), "generate_spec") # If we have partial content, save it even on failure + machine_time = time.monotonic() - node_start + 
end_stats = record_stage_end(state, STAGE_SPEC, machine_time) result_state = { **state, + **end_stats, "last_error": str(e), "current_node": "generate_spec", "retry_count": state.get("retry_count", 0) + 1, @@ -169,6 +202,11 @@ async def regenerate_spec_with_feedback(state: WorkflowState) -> WorkflowState: logger.info(f"Regenerating spec for {ticket_key} with feedback") + # Record stage re-entry: start timer, increment revision count + state = {**state, **record_stage_start(state, STAGE_SPEC)} + state = {**state, **increment_revision(state, STAGE_SPEC)} + node_start = time.monotonic() + jira = JiraClient() agent = ForgeAgent() @@ -188,6 +226,11 @@ async def regenerate_spec_with_feedback(state: WorkflowState) -> WorkflowState: }, ) + # Record token usage (estimated from content length) + input_tokens = _estimate_tokens(original_spec) + _estimate_tokens(feedback) + output_tokens = _estimate_tokens(new_spec) + state = {**state, **record_tokens(state, STAGE_SPEC, input_tokens, output_tokens)} + # Store updated spec in Jira (comment or custom field based on config) settings = get_settings() if settings.jira_store_in_comments: @@ -225,9 +268,14 @@ async def regenerate_spec_with_feedback(state: WorkflowState) -> WorkflowState: logger.info(f"Spec regenerated for {ticket_key} ({len(new_spec)} chars)") + # Record stage end with elapsed wall-clock time + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_SPEC, machine_time) + return update_state_timestamp( { **state, + **end_stats, "spec_content": new_spec, "feedback_comment": None, "revision_requested": False, @@ -241,8 +289,11 @@ async def regenerate_spec_with_feedback(state: WorkflowState) -> WorkflowState: from forge.workflow.nodes.error_handler import notify_error await notify_error(state, str(e), "regenerate_spec") + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_SPEC, machine_time) return { **state, + **end_stats, "last_error": str(e), 
"current_node": "regenerate_spec", "retry_count": state.get("retry_count", 0) + 1, diff --git a/tests/unit/workflow/nodes/test_prd_spec_stats.py b/tests/unit/workflow/nodes/test_prd_spec_stats.py new file mode 100644 index 00000000..807bf30d --- /dev/null +++ b/tests/unit/workflow/nodes/test_prd_spec_stats.py @@ -0,0 +1,738 @@ +"""Unit tests for stats recording in PRD and Spec generation nodes.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.models.workflow import TicketType +from forge.workflow.feature.state import create_initial_feature_state +from forge.workflow.stats import STAGE_PRD, STAGE_SPEC + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def create_mock_jira( + description: str = "Raw requirements text", + summary: str = "Test Feature", + project_key: str = "TEST", +) -> MagicMock: + """Return a JiraClient mock with default async methods.""" + mock = MagicMock() + mock.close = AsyncMock() + mock.update_description = AsyncMock() + mock.add_structured_comment = AsyncMock() + mock.set_workflow_label = AsyncMock() + mock.get_prd_proposals_repo = AsyncMock(return_value=None) + mock.add_comment = AsyncMock() + mock.get_issue = AsyncMock( + return_value=MagicMock( + summary=summary, + description=description, + project_key=project_key, + ) + ) + return mock + + +def create_mock_agent( + prd_content: str = "# Generated PRD\n\nContent here.", + spec_content: str = "# Generated Spec\n\nAcceptance criteria here.", +) -> MagicMock: + """Return a ForgeAgent mock with default async methods.""" + mock = MagicMock() + mock.close = AsyncMock() + mock.generate_prd = AsyncMock(return_value=prd_content) + mock.generate_spec = AsyncMock(return_value=spec_content) + mock.regenerate_with_feedback = AsyncMock(return_value="# Revised content") + return mock + + +def _get_stage(result: dict, stage_name: str) -> 
dict: + """Extract a stage entry from result state, or {} if absent.""" + return (result.get("stats_stages") or {}).get(stage_name, {}) + + +# --------------------------------------------------------------------------- +# PRD generation stats tests +# --------------------------------------------------------------------------- + + +class TestGeneratePrdStatsRecording: + """Tests for stats recording in generate_prd node.""" + + @pytest.mark.asyncio + async def test_records_stage_start_on_entry(self): + """generate_prd should initialise the PRD stage with a started_at timestamp.""" + from forge.workflow.nodes.prd_generation import generate_prd + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + ) + + with ( + patch("forge.workflow.nodes.prd_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.prd_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.prd_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_prd(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage, "stats_stages[STAGE_PRD] should be populated" + assert stage.get("started_at") is not None, "started_at must be set" + + @pytest.mark.asyncio + async def test_records_stage_end_with_machine_time(self): + """generate_prd should populate ended_at and positive machine_time_seconds.""" + from forge.workflow.nodes.prd_generation import generate_prd + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + ) + + with ( + patch("forge.workflow.nodes.prd_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.prd_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.prd_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): 
+ result = await generate_prd(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("ended_at") is not None, "ended_at must be set on success" + assert stage.get("machine_time_seconds", 0.0) >= 0.0, "machine_time must be non-negative" + + @pytest.mark.asyncio + async def test_records_tokens_from_llm_response(self): + """generate_prd should record non-zero token counts after LLM call.""" + from forge.workflow.nodes.prd_generation import generate_prd + + mock_jira = create_mock_jira(description="A" * 400) # 100 estimated tokens + mock_agent = create_mock_agent(prd_content="B" * 800) # 200 estimated tokens + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + ) + + with ( + patch("forge.workflow.nodes.prd_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.prd_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.prd_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_prd(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("input_tokens", 0) > 0, "input_tokens should be positive" + assert stage.get("output_tokens", 0) > 0, "output_tokens should be positive" + + @pytest.mark.asyncio + async def test_stats_recorded_on_missing_requirements(self): + """generate_prd should record stage_end even when requirements are empty.""" + from forge.workflow.nodes.prd_generation import generate_prd + + mock_jira = create_mock_jira(description="") + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + ) + + with ( + patch("forge.workflow.nodes.prd_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.prd_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.prd_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_prd(state) + + stage = 
_get_stage(result, STAGE_PRD) + assert stage.get("started_at") is not None + assert stage.get("ended_at") is not None + + @pytest.mark.asyncio + async def test_stats_recorded_on_exception(self): + """generate_prd should record stage_end even when an exception is raised.""" + from forge.workflow.nodes.prd_generation import generate_prd + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + mock_agent.generate_prd = AsyncMock(side_effect=RuntimeError("LLM failure")) + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + ) + + with ( + patch("forge.workflow.nodes.prd_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.prd_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.prd_generation.post_status_comment", + new_callable=AsyncMock, + ), + patch( + "forge.workflow.nodes.error_handler.notify_error", + new_callable=AsyncMock, + ), + ): + result = await generate_prd(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("started_at") is not None + assert stage.get("ended_at") is not None + assert result.get("last_error") is not None + + +# --------------------------------------------------------------------------- +# PRD regeneration stats tests +# --------------------------------------------------------------------------- + + +class TestRegeneratePrdStatsRecording: + """Tests for stats recording in regenerate_prd_with_feedback node.""" + + @pytest.mark.asyncio + async def test_increments_revision_on_feedback(self): + """regenerate_prd_with_feedback should increment iteration_count by 1.""" + from forge.workflow.nodes.prd_generation import regenerate_prd_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + prd_content="# Original PRD", + feedback_comment="! 
Please add more detail about authentication", + ) + + with ( + patch( + "forge.workflow.nodes.prd_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.prd_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_prd_with_feedback(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("iteration_count", 0) >= 1, "iteration_count must be incremented" + + @pytest.mark.asyncio + async def test_records_stage_start_on_feedback(self): + """regenerate_prd_with_feedback should set started_at on re-entry.""" + from forge.workflow.nodes.prd_generation import regenerate_prd_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + prd_content="# Original PRD", + feedback_comment="! Needs more detail", + ) + + with ( + patch( + "forge.workflow.nodes.prd_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.prd_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_prd_with_feedback(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("started_at") is not None + + @pytest.mark.asyncio + async def test_records_stage_end_on_feedback(self): + """regenerate_prd_with_feedback should record ended_at and machine_time.""" + from forge.workflow.nodes.prd_generation import regenerate_prd_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + prd_content="# Original PRD", + feedback_comment="! 
Add more context", + ) + + with ( + patch( + "forge.workflow.nodes.prd_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.prd_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_prd_with_feedback(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("ended_at") is not None + assert stage.get("machine_time_seconds", 0.0) >= 0.0 + + @pytest.mark.asyncio + async def test_records_tokens_on_feedback(self): + """regenerate_prd_with_feedback should record tokens for the revision.""" + from forge.workflow.nodes.prd_generation import regenerate_prd_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + mock_agent.regenerate_with_feedback = AsyncMock(return_value="D" * 800) + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + prd_content="C" * 400, + feedback_comment="! " + "E" * 40, + ) + + with ( + patch( + "forge.workflow.nodes.prd_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.prd_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_prd_with_feedback(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("input_tokens", 0) > 0 + assert stage.get("output_tokens", 0) > 0 + + @pytest.mark.asyncio + async def test_no_feedback_returns_unchanged_state(self): + """regenerate_prd_with_feedback with no feedback should return state unchanged.""" + from forge.workflow.nodes.prd_generation import regenerate_prd_with_feedback + + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + prd_content="# Original PRD", + ) + + result = await regenerate_prd_with_feedback(state) + + # State returned unchanged — no stats_stages mutation + assert result is state + + @pytest.mark.asyncio + async def test_stats_recorded_on_exception(self): + """regenerate_prd_with_feedback records stage_end even on 
exception.""" + from forge.workflow.nodes.prd_generation import regenerate_prd_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + mock_agent.regenerate_with_feedback = AsyncMock(side_effect=RuntimeError("API error")) + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + prd_content="# Original PRD", + feedback_comment="! Add more detail", + ) + + with ( + patch( + "forge.workflow.nodes.prd_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.prd_generation.ForgeAgent", + return_value=mock_agent, + ), + patch( + "forge.workflow.nodes.error_handler.notify_error", + new_callable=AsyncMock, + ), + ): + result = await regenerate_prd_with_feedback(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("ended_at") is not None + assert result.get("last_error") is not None + + +# --------------------------------------------------------------------------- +# Spec generation stats tests +# --------------------------------------------------------------------------- + + +class TestGenerateSpecStatsRecording: + """Tests for stats recording in generate_spec node.""" + + @pytest.mark.asyncio + async def test_records_stage_start_on_entry(self): + """generate_spec should initialise the SPEC stage with a started_at timestamp.""" + from forge.workflow.nodes.spec_generation import generate_spec + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + prd_content="# Approved PRD", + ) + + with ( + patch("forge.workflow.nodes.spec_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.spec_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.spec_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_spec(state) + + stage = _get_stage(result, STAGE_SPEC) + 
assert stage, "stats_stages[STAGE_SPEC] should be populated" + assert stage.get("started_at") is not None + + @pytest.mark.asyncio + async def test_records_stage_end_with_machine_time(self): + """generate_spec should populate ended_at and machine_time_seconds.""" + from forge.workflow.nodes.spec_generation import generate_spec + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + prd_content="# Approved PRD", + ) + + with ( + patch("forge.workflow.nodes.spec_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.spec_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.spec_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_spec(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("ended_at") is not None + assert stage.get("machine_time_seconds", 0.0) >= 0.0 + + @pytest.mark.asyncio + async def test_records_tokens_from_llm_response(self): + """generate_spec should record non-zero token counts after LLM call.""" + from forge.workflow.nodes.spec_generation import generate_spec + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent(spec_content="F" * 800) + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + prd_content="G" * 400, + ) + + with ( + patch("forge.workflow.nodes.spec_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.spec_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.spec_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_spec(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("input_tokens", 0) > 0 + assert stage.get("output_tokens", 0) > 0 + + @pytest.mark.asyncio + async def test_stats_recorded_on_missing_prd(self): + """generate_spec should 
record stage_end even when PRD content is empty.""" + from forge.workflow.nodes.spec_generation import generate_spec + + # No prd_content in state, and Jira returns empty description + mock_jira = create_mock_jira(description="") + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + ) + + with ( + patch("forge.workflow.nodes.spec_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.spec_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.spec_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_spec(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("started_at") is not None + assert stage.get("ended_at") is not None + + @pytest.mark.asyncio + async def test_stats_recorded_on_exception(self): + """generate_spec should record stage_end even when an exception is raised.""" + from forge.workflow.nodes.spec_generation import generate_spec + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + mock_agent.generate_spec = AsyncMock(side_effect=RuntimeError("Spec LLM failure")) + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + prd_content="# Approved PRD", + ) + + with ( + patch("forge.workflow.nodes.spec_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.spec_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.spec_generation.post_status_comment", + new_callable=AsyncMock, + ), + patch( + "forge.workflow.nodes.error_handler.notify_error", + new_callable=AsyncMock, + ), + ): + result = await generate_spec(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("started_at") is not None + assert stage.get("ended_at") is not None + assert result.get("last_error") is not None + + +# 
--------------------------------------------------------------------------- +# Spec regeneration stats tests +# --------------------------------------------------------------------------- + + +class TestRegenerateSpecStatsRecording: + """Tests for stats recording in regenerate_spec_with_feedback node.""" + + @pytest.mark.asyncio + async def test_increments_revision_on_feedback(self): + """regenerate_spec_with_feedback should increment iteration_count.""" + from forge.workflow.nodes.spec_generation import regenerate_spec_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + spec_content="# Original Spec", + feedback_comment="! Please add more Given/When/Then scenarios", + ) + + with ( + patch( + "forge.workflow.nodes.spec_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.spec_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_spec_with_feedback(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("iteration_count", 0) >= 1 + + @pytest.mark.asyncio + async def test_records_stage_start_on_feedback(self): + """regenerate_spec_with_feedback should set started_at on re-entry.""" + from forge.workflow.nodes.spec_generation import regenerate_spec_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + spec_content="# Original Spec", + feedback_comment="! 
Needs more detail", + ) + + with ( + patch( + "forge.workflow.nodes.spec_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.spec_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_spec_with_feedback(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("started_at") is not None + + @pytest.mark.asyncio + async def test_records_stage_end_on_feedback(self): + """regenerate_spec_with_feedback should record ended_at and machine_time.""" + from forge.workflow.nodes.spec_generation import regenerate_spec_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + spec_content="# Original Spec", + feedback_comment="! Add edge cases", + ) + + with ( + patch( + "forge.workflow.nodes.spec_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.spec_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_spec_with_feedback(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("ended_at") is not None + assert stage.get("machine_time_seconds", 0.0) >= 0.0 + + @pytest.mark.asyncio + async def test_records_tokens_on_feedback(self): + """regenerate_spec_with_feedback should record tokens for the revision.""" + from forge.workflow.nodes.spec_generation import regenerate_spec_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + mock_agent.regenerate_with_feedback = AsyncMock(return_value="H" * 800) + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + spec_content="I" * 400, + feedback_comment="! 
" + "J" * 40, + ) + + with ( + patch( + "forge.workflow.nodes.spec_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.spec_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_spec_with_feedback(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("input_tokens", 0) > 0 + assert stage.get("output_tokens", 0) > 0 + + @pytest.mark.asyncio + async def test_no_feedback_returns_unchanged_state(self): + """regenerate_spec_with_feedback with no feedback should return state unchanged.""" + from forge.workflow.nodes.spec_generation import regenerate_spec_with_feedback + + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + spec_content="# Original Spec", + ) + + result = await regenerate_spec_with_feedback(state) + + assert result is state + + @pytest.mark.asyncio + async def test_stats_recorded_on_exception(self): + """regenerate_spec_with_feedback records stage_end even on exception.""" + from forge.workflow.nodes.spec_generation import regenerate_spec_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + mock_agent.regenerate_with_feedback = AsyncMock(side_effect=RuntimeError("API error")) + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + spec_content="# Original Spec", + feedback_comment="! 
Add more detail", + ) + + with ( + patch( + "forge.workflow.nodes.spec_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.spec_generation.ForgeAgent", + return_value=mock_agent, + ), + patch( + "forge.workflow.nodes.error_handler.notify_error", + new_callable=AsyncMock, + ), + ): + result = await regenerate_spec_with_feedback(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("ended_at") is not None + assert result.get("last_error") is not None + + +# --------------------------------------------------------------------------- +# Token estimation helper tests +# --------------------------------------------------------------------------- + + +class TestEstimateTokens: + """Tests for the _estimate_tokens helper.""" + + def test_empty_string_returns_one(self): + from forge.workflow.nodes.prd_generation import _estimate_tokens + + assert _estimate_tokens("") == 1 + + def test_four_chars_returns_one(self): + from forge.workflow.nodes.prd_generation import _estimate_tokens + + assert _estimate_tokens("abcd") == 1 + + def test_estimate_scales_with_length(self): + from forge.workflow.nodes.prd_generation import _estimate_tokens + + assert _estimate_tokens("a" * 400) == 100 + + def test_spec_module_helper_matches(self): + from forge.workflow.nodes.prd_generation import _estimate_tokens as prd_est + from forge.workflow.nodes.spec_generation import _estimate_tokens as spec_est + + text = "Hello world test" + assert prd_est(text) == spec_est(text) From 93f3b492fb664dbc765ce084e0b5ae7e2ffc9870 Mon Sep 17 00:00:00 2001 From: Forge Date: Wed, 24 Jun 2026 07:39:37 +0000 Subject: [PATCH 07/68] [AISOS-1894] Implement Stats Summary Formatter Module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Detailed description: - Converted src/forge/workflow/stats.py to a package (src/forge/workflow/stats/__init__.py) so that formatter.py can live under the stats/ namespace; all existing imports 
(forge.workflow.stats.StatsState etc.) continue to work without changes. - Created src/forge/workflow/stats/formatter.py with the public format_stats_summary(stats, outcome, outcome_detail=None) -> str function that transforms StatsState data into Jira wiki markup: * Stage metrics table (||Stage||Iterations||Machine Time||Human Time||Input Tokens||Output Tokens||) * One row per feature stage using ALL_FEATURE_STAGES; unexecuted stages show em-dash (—) not zeros * Aggregate token totals row (*Total* row with bold input/output sums) * PR links section (omitted when stats_pr_urls is empty) * CI Cycles field * Outcome field (Completed / Blocked: <reason> / Failed: <reason>) * Outcome/block/failure reasons truncated at 200 chars with '...' suffix - Created tests/unit/workflow/stats/test_formatter.py with 64 unit tests achieving 100% branch coverage across all helpers and the public API. Closes: AISOS-1894 --- .../workflow/{stats.py => stats/__init__.py} | 0 src/forge/workflow/stats/formatter.py | 197 ++++++++ tests/unit/workflow/stats/__init__.py | 0 tests/unit/workflow/stats/test_formatter.py | 444 ++++++++++++++++++ 4 files changed, 641 insertions(+) rename src/forge/workflow/{stats.py => stats/__init__.py} (100%) create mode 100644 src/forge/workflow/stats/formatter.py create mode 100644 tests/unit/workflow/stats/__init__.py create mode 100644 tests/unit/workflow/stats/test_formatter.py diff --git a/src/forge/workflow/stats.py b/src/forge/workflow/stats/__init__.py similarity index 100% rename from src/forge/workflow/stats.py rename to src/forge/workflow/stats/__init__.py diff --git a/src/forge/workflow/stats/formatter.py b/src/forge/workflow/stats/formatter.py new file mode 100644 index 00000000..47ca2e94 --- /dev/null +++ b/src/forge/workflow/stats/formatter.py @@ -0,0 +1,197 @@ +"""Jira wiki markup formatter for workflow statistics summaries.
+ +This module transforms StatsState data into Jira wiki markup suitable for +posting as a comment on the associated Jira ticket at the end of a workflow run. +""" + +from forge.workflow.stats import ( + ALL_FEATURE_STAGES, + StageStats, + StatsState, +) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +#: Maximum length for outcome_detail before truncation. +_MAX_DETAIL_LEN = 200 + +#: Display labels for each stage key, in the order they appear in the table. +_STAGE_LABELS: dict[str, str] = { + "prd": "PRD", + "spec": "Spec", + "epics": "Epics", + "tasks": "Tasks", + "implementation": "Implementation", + "ci": "CI", + "review": "Review", + # Bug workflow stages (if needed in future extensions) + "triage": "Triage", + "rca": "RCA", + "planning": "Planning", +} + +#: Em-dash used when a stage was never executed. +_DASH = "\u2014" + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _truncate(text: str, max_len: int = _MAX_DETAIL_LEN) -> str: + """Return *text* truncated to *max_len* characters with '...' suffix. + + If *text* is already within the limit it is returned unchanged. + """ + if len(text) <= max_len: + return text + return text[:max_len] + "..." + + +def _fmt_seconds(seconds: float) -> str: + """Format a duration in seconds to a human-readable string (e.g. 
'1h 23m 45s').""" + total = int(seconds) + hours, remainder = divmod(total, 3600) + minutes, secs = divmod(remainder, 60) + if hours: + return f"{hours}h {minutes}m {secs}s" + if minutes: + return f"{minutes}m {secs}s" + return f"{secs}s" + + +def _fmt_tokens(count: int) -> str: + """Format a token count with thousands separators.""" + return f"{count:,}" + + +def _build_stage_row(label: str, stage: StageStats | None) -> str: + """Return a single Jira table row for a workflow stage. + + If *stage* is None (never executed), all metric columns show '—'. + """ + if stage is None: + return f"|{label}|{_DASH}|{_DASH}|{_DASH}|{_DASH}|{_DASH}|" + + iterations = stage.get("iteration_count", 0) + machine_time = _fmt_seconds(stage.get("machine_time_seconds", 0.0)) + human_time = _fmt_seconds(stage.get("human_time_seconds", 0.0)) + input_tok = _fmt_tokens(stage.get("input_tokens", 0)) + output_tok = _fmt_tokens(stage.get("output_tokens", 0)) + + return f"|{label}|{iterations}|{machine_time}|{human_time}|{input_tok}|{output_tok}|" + + +def _build_totals_row(stages: dict[str, StageStats]) -> str: + """Return the aggregate token totals row summed across all stages.""" + total_input = sum(s.get("input_tokens", 0) for s in stages.values()) + total_output = sum(s.get("output_tokens", 0) for s in stages.values()) + return f"|*Total*|—|—|—|*{_fmt_tokens(total_input)}*|*{_fmt_tokens(total_output)}*|" + + +def _build_outcome_str(outcome: str, outcome_detail: str | None) -> str: + """Construct the formatted outcome string for display. + + Supported outcome values: + ``"completed"`` → ``"Completed"`` + ``"blocked"`` → ``"Blocked: <detail>"`` + ``"failed"`` → ``"Failed: <detail>"`` + + The *outcome* parameter is matched case-insensitively. Any detail longer + than 200 characters is truncated with '...' suffix.
+ """ + key = outcome.lower() + if key == "completed": + return "Completed" + detail = _truncate(outcome_detail or "") if outcome_detail else "" + if key == "blocked": + if detail: + return f"Blocked: {detail}" + return "Blocked" + if key == "failed": + if detail: + return f"Failed: {detail}" + return "Failed" + # Fallback for unknown outcome values — display as-is with optional detail. + if detail: + return f"{outcome}: {detail}" + return outcome + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def format_stats_summary( + stats: StatsState, + outcome: str, + outcome_detail: str | None = None, +) -> str: + """Format a StatsState snapshot into a Jira wiki markup comment. + + The generated comment includes: + * A stage-by-stage metrics table (iterations, machine time, human time, + input tokens, output tokens). + * An aggregate token totals row. + * A PR links section (omitted when no PRs were created). + * A CI cycles line. + * A final outcome field. + + Args: + stats: The workflow statistics state to format. + outcome: Outcome category — one of ``"completed"``, ``"blocked"``, or + ``"failed"`` (matched case-insensitively). + outcome_detail: Optional elaboration on the outcome (e.g. the blocking + reason or error message). Truncated to 200 characters if longer. + + Returns: + A Jira wiki markup string ready to post as a ticket comment. + """ + stages: dict[str, StageStats] = stats.get("stats_stages") or {} + pr_urls: list[str] = stats.get("stats_pr_urls") or [] + ci_cycles: int = stats.get("stats_ci_cycles") or 0 + + lines: list[str] = [] + + # ------------------------------------------------------------------ + # Stage metrics table + # ------------------------------------------------------------------ + lines.append("h3. 
Workflow Statistics") + lines.append("") + lines.append("||Stage||Iterations||Machine Time||Human Time||Input Tokens||Output Tokens||") + + for stage_key in ALL_FEATURE_STAGES: + label = _STAGE_LABELS.get(stage_key, stage_key.title()) + stage_data = stages.get(stage_key) + lines.append(_build_stage_row(label, stage_data)) + + # Aggregate totals row (always shown, even when no stages ran) + lines.append(_build_totals_row(stages)) + + # ------------------------------------------------------------------ + # PR links section (omitted when no PRs) + # ------------------------------------------------------------------ + if pr_urls: + lines.append("") + lines.append("*Pull Requests*") + for url in pr_urls: + lines.append(f"* [{url}|{url}]") + + # ------------------------------------------------------------------ + # CI cycles + # ------------------------------------------------------------------ + lines.append("") + lines.append(f"*CI Cycles:* {ci_cycles}") + + # ------------------------------------------------------------------ + # Outcome + # ------------------------------------------------------------------ + lines.append("") + outcome_str = _build_outcome_str(outcome, outcome_detail) + lines.append(f"*Outcome:* {outcome_str}") + + return "\n".join(lines) diff --git a/tests/unit/workflow/stats/__init__.py b/tests/unit/workflow/stats/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/workflow/stats/test_formatter.py b/tests/unit/workflow/stats/test_formatter.py new file mode 100644 index 00000000..f92e5f4e --- /dev/null +++ b/tests/unit/workflow/stats/test_formatter.py @@ -0,0 +1,444 @@ +"""Unit tests for forge.workflow.stats.formatter. + +All tests target format_stats_summary() and its internal helpers. +The suite is designed to achieve 100% branch coverage. 
# ---------------------------------------------------------------------------
# Helpers / fixtures
# ---------------------------------------------------------------------------


def _make_stage(
    *,
    stage_name: str = "prd",
    iteration_count: int = 1,
    machine_time_seconds: float = 60.0,
    human_time_seconds: float = 30.0,
    input_tokens: int = 1000,
    output_tokens: int = 500,
    started_at: str | None = "2024-01-01T00:00:00+00:00",
    ended_at: str | None = "2024-01-01T00:01:00+00:00",
) -> dict:
    """Return a StageStats-like dict; every field can be overridden by keyword."""
    return {
        "stage_name": stage_name,
        "iteration_count": iteration_count,
        "machine_time_seconds": machine_time_seconds,
        "human_time_seconds": human_time_seconds,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "started_at": started_at,
        "ended_at": ended_at,
    }


def _minimal_stats(**overrides) -> dict:
    """Return a minimal StatsState-like dict."""
    base = {
        "stats_stages": {},
        "stats_pr_urls": [],
        "stats_ci_cycles": 0,
        "stats_outcome": None,
        "stats_outcome_reason": None,
        "stats_comment_posted": False,
    }
    base.update(overrides)
    return base


# ---------------------------------------------------------------------------
# _truncate
# ---------------------------------------------------------------------------


class TestTruncate:
    """_truncate() cuts at max_len and appends '...' only when it cut."""

    def test_short_string_unchanged(self):
        assert _truncate("hello") == "hello"

    def test_exactly_max_len_unchanged(self):
        text = "x" * 200
        assert _truncate(text) == text

    def test_one_over_max_len_truncated(self):
        text = "x" * 201
        result = _truncate(text)
        assert result == "x" * 200 + "..."
        assert len(result) == 203  # 200 chars + "..."

    def test_much_longer_text_truncated(self):
        text = "a" * 500
        result = _truncate(text)
        assert result.endswith("...")
        assert len(result) == 203

    def test_custom_max_len(self):
        result = _truncate("hello world", max_len=5)
        assert result == "hello..."

    def test_empty_string(self):
        assert _truncate("") == ""


# ---------------------------------------------------------------------------
# _fmt_seconds
# ---------------------------------------------------------------------------


class TestFmtSeconds:
    """_fmt_seconds() renders s / m s / h m s depending on magnitude."""

    def test_seconds_only(self):
        assert _fmt_seconds(45.0) == "45s"

    def test_zero_seconds(self):
        assert _fmt_seconds(0.0) == "0s"

    def test_minutes_and_seconds(self):
        assert _fmt_seconds(90.0) == "1m 30s"

    def test_exact_minutes(self):
        assert _fmt_seconds(120.0) == "2m 0s"

    def test_hours_minutes_seconds(self):
        assert _fmt_seconds(3661.0) == "1h 1m 1s"

    def test_exact_hour(self):
        assert _fmt_seconds(3600.0) == "1h 0m 0s"

    def test_fractional_seconds_truncated(self):
        # Float fractions are discarded (int conversion)
        assert _fmt_seconds(90.9) == "1m 30s"

    def test_multiple_hours(self):
        assert _fmt_seconds(7322.0) == "2h 2m 2s"


# ---------------------------------------------------------------------------
# _fmt_tokens
# ---------------------------------------------------------------------------


class TestFmtTokens:
    """_fmt_tokens() inserts thousands separators."""

    def test_zero(self):
        assert _fmt_tokens(0) == "0"

    def test_small_number(self):
        assert _fmt_tokens(999) == "999"

    def test_thousands(self):
        assert _fmt_tokens(1000) == "1,000"

    def test_millions(self):
        assert _fmt_tokens(1_500_000) == "1,500,000"


# ---------------------------------------------------------------------------
# _build_stage_row
# ---------------------------------------------------------------------------


class TestBuildStageRow:
    """_build_stage_row() renders metrics, or em-dashes for an unexecuted stage."""

    def test_none_stage_shows_dashes(self):
        row = _build_stage_row("PRD", None)
        # Should show em-dash in all metric columns
        assert row == "|PRD|—|—|—|—|—|"

    def test_executed_stage_shows_metrics(self):
        stage = _make_stage(
            iteration_count=2,
            machine_time_seconds=90.0,
            human_time_seconds=60.0,
            input_tokens=1000,
            output_tokens=500,
        )
        row = _build_stage_row("PRD", stage)
        assert row == "|PRD|2|1m 30s|1m 0s|1,000|500|"

    def test_stage_with_zero_times(self):
        stage = _make_stage(
            iteration_count=1,
            machine_time_seconds=0.0,
            human_time_seconds=0.0,
            input_tokens=0,
            output_tokens=0,
        )
        row = _build_stage_row("Spec", stage)
        assert row == "|Spec|1|0s|0s|0|0|"


# ---------------------------------------------------------------------------
# _build_totals_row
# ---------------------------------------------------------------------------


class TestBuildTotalsRow:
    """_build_totals_row() sums input/output tokens over all stages."""

    def test_empty_stages(self):
        row = _build_totals_row({})
        assert row == "|*Total*|—|—|—|*0*|*0*|"

    def test_single_stage(self):
        stages = {"prd": _make_stage(input_tokens=100, output_tokens=50)}
        row = _build_totals_row(stages)
        assert row == "|*Total*|—|—|—|*100*|*50*|"

    def test_multiple_stages_summed(self):
        stages = {
            "prd": _make_stage(input_tokens=1000, output_tokens=500),
            "spec": _make_stage(input_tokens=2000, output_tokens=800),
        }
        row = _build_totals_row(stages)
        assert row == "|*Total*|—|—|—|*3,000*|*1,300*|"


# ---------------------------------------------------------------------------
# _build_outcome_str
# ---------------------------------------------------------------------------


class TestBuildOutcomeStr:
    """_build_outcome_str() maps outcome + optional detail to display text."""

    def test_completed_no_detail(self):
        assert _build_outcome_str("completed", None) == "Completed"

    def test_completed_case_insensitive(self):
        assert _build_outcome_str("Completed", None) == "Completed"
        assert _build_outcome_str("COMPLETED", None) == "Completed"

    def test_completed_ignores_detail(self):
        # For 'completed', outcome_detail should be ignored
        assert _build_outcome_str("completed", "some detail") == "Completed"

    def test_blocked_with_reason(self):
        result = _build_outcome_str("blocked", "Waiting for security review")
        assert result == "Blocked: Waiting for security review"

    def test_blocked_without_reason(self):
        assert _build_outcome_str("blocked", None) == "Blocked"

    def test_blocked_with_empty_reason(self):
        assert _build_outcome_str("blocked", "") == "Blocked"

    def test_blocked_truncates_long_reason(self):
        long_reason = "x" * 201
        result = _build_outcome_str("blocked", long_reason)
        assert result == "Blocked: " + "x" * 200 + "..."

    def test_failed_with_error(self):
        result = _build_outcome_str("failed", "Database connection timeout")
        assert result == "Failed: Database connection timeout"

    def test_failed_without_error(self):
        assert _build_outcome_str("failed", None) == "Failed"

    def test_failed_with_empty_error(self):
        assert _build_outcome_str("failed", "") == "Failed"

    def test_failed_truncates_long_error(self):
        long_error = "e" * 300
        result = _build_outcome_str("failed", long_error)
        assert result.startswith("Failed: ")
        assert result.endswith("...")
        # detail portion is 200 chars
        assert len(result) == len("Failed: ") + 200 + 3

    def test_unknown_outcome_no_detail(self):
        result = _build_outcome_str("aborted", None)
        assert result == "aborted"

    def test_unknown_outcome_with_detail(self):
        result = _build_outcome_str("aborted", "some reason")
        assert result == "aborted: some reason"


# ---------------------------------------------------------------------------
# format_stats_summary — structural / content tests
# ---------------------------------------------------------------------------


class TestFormatStatsSummaryStructure:
    """The comment always contains the header, table, totals, CI and outcome."""

    def test_returns_string(self):
        result = format_stats_summary(_minimal_stats(), "completed")
        assert isinstance(result, str)

    def test_contains_header(self):
        result = format_stats_summary(_minimal_stats(), "completed")
        assert "h3. Workflow Statistics" in result

    def test_contains_table_header_row(self):
        result = format_stats_summary(_minimal_stats(), "completed")
        assert (
            "||Stage||Iterations||Machine Time||Human Time||Input Tokens||Output Tokens||" in result
        )

    def test_contains_all_feature_stages(self):
        result = format_stats_summary(_minimal_stats(), "completed")
        for label in ["PRD", "Spec", "Epics", "Tasks", "Implementation", "CI", "Review"]:
            # Match the table cell, not a bare substring: "CI" alone would be
            # satisfied by the "*CI Cycles:*" line even if the CI row vanished.
            assert f"|{label}|" in result

    def test_never_executed_stages_show_dash(self):
        result = format_stats_summary(_minimal_stats(), "completed")
        # All stages are unexecuted; each row should have em-dashes
        lines = result.splitlines()
        stage_rows = [
            line
            for line in lines
            if line.startswith("|")
            and not line.startswith("||")
            and not line.startswith("|*Total*")
        ]
        assert len(stage_rows) == 7  # 7 feature stages
        for row in stage_rows:
            assert "—" in row

    def test_contains_totals_row(self):
        result = format_stats_summary(_minimal_stats(), "completed")
        assert "|*Total*|" in result

    def test_contains_ci_cycles(self):
        stats = _minimal_stats(stats_ci_cycles=3)
        result = format_stats_summary(stats, "completed")
        assert "*CI Cycles:* 3" in result

    def test_contains_outcome(self):
        result = format_stats_summary(_minimal_stats(), "completed")
        assert "*Outcome:* Completed" in result


class TestFormatStatsSummaryPRLinks:
    """The Pull Requests section appears only when PR URLs were recorded."""

    def test_no_prs_omits_section(self):
        result = format_stats_summary(_minimal_stats(), "completed")
        assert "Pull Requests" not in result

    def test_single_pr_included(self):
        stats = _minimal_stats(stats_pr_urls=["https://github.com/org/repo/pull/1"])
        result = format_stats_summary(stats, "completed")
        assert "*Pull Requests*" in result
        assert "* [https://github.com/org/repo/pull/1|https://github.com/org/repo/pull/1]" in result

    def test_multiple_prs_all_included(self):
        urls = [
            "https://github.com/org/repo/pull/1",
            "https://github.com/org/repo/pull/2",
        ]
        stats = _minimal_stats(stats_pr_urls=urls)
        result = format_stats_summary(stats, "completed")
        assert "*Pull Requests*" in result
        for url in urls:
            assert f"* [{url}|{url}]" in result


class TestFormatStatsSummaryStageData:
    """Stage rows and the totals row reflect the recorded stage metrics."""

    def test_executed_stage_shows_metrics(self):
        stage = _make_stage(
            stage_name="prd",
            iteration_count=3,
            machine_time_seconds=3661.0,
            human_time_seconds=120.0,
            input_tokens=5000,
            output_tokens=1500,
        )
        stats = _minimal_stats(stats_stages={"prd": stage})
        result = format_stats_summary(stats, "completed")
        assert "|PRD|3|1h 1m 1s|2m 0s|5,000|1,500|" in result

    def test_unexecuted_stage_shows_dashes(self):
        stats = _minimal_stats()
        result = format_stats_summary(stats, "completed")
        assert "|PRD|—|—|—|—|—|" in result

    def test_totals_sum_across_stages(self):
        stages = {
            "prd": _make_stage(input_tokens=1000, output_tokens=500),
            "spec": _make_stage(input_tokens=2000, output_tokens=800),
            "implementation": _make_stage(input_tokens=10000, output_tokens=4000),
        }
        stats = _minimal_stats(stats_stages=stages)
        result = format_stats_summary(stats, "completed")
        assert "|*Total*|—|—|—|*13,000*|*5,300*|" in result

    def test_empty_stages_totals_zero(self):
        result = format_stats_summary(_minimal_stats(), "completed")
        assert "|*Total*|—|—|—|*0*|*0*|" in result


class TestFormatStatsSummaryOutcome:
    """The outcome line combines the outcome category and optional detail."""

    def test_completed_outcome(self):
        result = format_stats_summary(_minimal_stats(), "completed")
        assert "*Outcome:* Completed" in result

    def test_blocked_outcome_with_reason(self):
        result = format_stats_summary(
            _minimal_stats(),
            "blocked",
            outcome_detail="Waiting for approval",
        )
        assert "*Outcome:* Blocked: Waiting for approval" in result

    def test_blocked_outcome_no_reason(self):
        result = format_stats_summary(_minimal_stats(), "blocked")
        assert "*Outcome:* Blocked" in result

    def test_failed_outcome_with_error(self):
        result = format_stats_summary(
            _minimal_stats(),
            "failed",
            outcome_detail="Unhandled exception",
        )
        assert "*Outcome:* Failed: Unhandled exception" in result

    def test_failed_outcome_no_error(self):
        result = format_stats_summary(_minimal_stats(), "failed")
        assert "*Outcome:* Failed" in result

    def test_long_detail_truncated(self):
        long_reason = "z" * 300
        result = format_stats_summary(
            _minimal_stats(),
            "blocked",
            outcome_detail=long_reason,
        )
        expected_detail = "z" * 200 + "..."
        assert f"*Outcome:* Blocked: {expected_detail}" in result

    def test_exactly_200_char_detail_not_truncated(self):
        reason = "a" * 200
        result = format_stats_summary(_minimal_stats(), "blocked", outcome_detail=reason)
        assert f"*Outcome:* Blocked: {reason}" in result
        assert "..." not in result

    def test_outcome_case_insensitive(self):
        result = format_stats_summary(_minimal_stats(), "Completed")
        assert "*Outcome:* Completed" in result


class TestFormatStatsSummaryMissingFields:
    """Ensure the formatter handles states with missing optional fields gracefully."""

    def test_empty_state_dict(self):
        """A completely empty dict should produce valid output without errors."""
        result = format_stats_summary({}, "completed")
        assert isinstance(result, str)
        assert "*CI Cycles:* 0" in result
        assert "*Outcome:* Completed" in result

    def test_none_stats_stages(self):
        stats = _minimal_stats(stats_stages=None)
        result = format_stats_summary(stats, "completed")
        assert "|*Total*|—|—|—|*0*|*0*|" in result

    def test_none_pr_urls(self):
        stats = _minimal_stats(stats_pr_urls=None)
        result = format_stats_summary(stats, "completed")
        assert "Pull Requests" not in result

    def test_none_ci_cycles(self):
        stats = _minimal_stats(stats_ci_cycles=None)
        result = format_stats_summary(stats, "completed")
        assert "*CI Cycles:* 0" in result
Comment Posting Service Detailed description: - Added src/forge/workflow/stats/poster.py implementing post_stats_comment() async function that formats and posts workflow statistics as a Jira comment - Exponential backoff retry logic: up to 3 attempts with 1s/2s delays - 5-minute SLA enforcement via asyncio.wait_for() with _OPERATION_TIMEOUT_SECONDS=300 - Non-blocking on failure: all exceptions are caught and logged; False returned - JiraClient is instantiated per attempt and always closed in a finally block - Added tests/unit/workflow/stats/test_poster.py with 22 unit tests covering: success path, API failure (graceful degradation), retry logic (backoff/sleep call counts, per-attempt client creation), timeout scenarios, and comment content verification via formatter mock Closes: AISOS-1895 --- src/forge/workflow/stats/poster.py | 144 +++++++++ tests/unit/workflow/stats/test_poster.py | 389 +++++++++++++++++++++++ 2 files changed, 533 insertions(+) create mode 100644 src/forge/workflow/stats/poster.py create mode 100644 tests/unit/workflow/stats/test_poster.py diff --git a/src/forge/workflow/stats/poster.py b/src/forge/workflow/stats/poster.py new file mode 100644 index 00000000..7fb89144 --- /dev/null +++ b/src/forge/workflow/stats/poster.py @@ -0,0 +1,144 @@ +"""Stats comment posting service for Jira tickets. + +This module provides a non-blocking async function that formats and posts +workflow statistics as a comment to the associated Jira ticket at the end +of a workflow run. +""" + +import asyncio +import logging + +from forge.integrations.jira.client import JiraClient +from forge.workflow.stats import StatsState +from forge.workflow.stats.formatter import format_stats_summary + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Retry configuration +# --------------------------------------------------------------------------- + +#: Maximum number of posting attempts (1 initial + 2 retries). 
_MAX_ATTEMPTS = 3

#: Initial backoff delay in seconds before the first retry.
_INITIAL_BACKOFF_SECONDS = 1.0

#: Maximum allowed backoff delay (caps exponential growth).
_MAX_BACKOFF_SECONDS = 16.0

#: Overall timeout for the entire post_stats_comment operation (5-minute SLA).
_OPERATION_TIMEOUT_SECONDS = 300.0


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


async def post_stats_comment(
    ticket_key: str,
    stats: StatsState,
    outcome: str,
    outcome_detail: str | None = None,
) -> bool:
    """Post a formatted stats summary comment to a Jira ticket.

    Formats the workflow statistics contained in *stats* into Jira wiki markup
    and posts it as a comment on *ticket_key*. The operation uses exponential
    backoff and retries up to :data:`_MAX_ATTEMPTS` times before giving up.
    The entire operation is bounded by :data:`_OPERATION_TIMEOUT_SECONDS`.

    This function is *non-blocking on failure*: any exception is caught,
    logged, and ``False`` is returned so that callers are not disrupted.

    Args:
        ticket_key: The Jira issue key to comment on (e.g. ``"PROJ-123"``).
        stats: The workflow statistics state to format and post.
        outcome: Outcome category — one of ``"completed"``, ``"blocked"``, or
            ``"failed"`` (matched case-insensitively by the formatter).
        outcome_detail: Optional elaboration on the outcome.

    Returns:
        ``True`` if the comment was successfully posted, ``False`` otherwise.
    """
    try:
        return await asyncio.wait_for(
            _post_with_retry(ticket_key, stats, outcome, outcome_detail),
            timeout=_OPERATION_TIMEOUT_SECONDS,
        )
    except TimeoutError:
        # On Python 3.11+ asyncio.wait_for raises the builtin TimeoutError.
        # On older versions asyncio.TimeoutError is a separate class and is
        # handled by the generic branch below, which also returns False.
        logger.error(
            "post_stats_comment timed out after %.0fs for ticket %s",
            _OPERATION_TIMEOUT_SECONDS,
            ticket_key,
        )
        return False
    except Exception:
        # Broad catch: we must never let stats posting crash the caller.
        logger.exception(
            "Unexpected error posting stats comment for ticket %s",
            ticket_key,
        )
        return False


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


async def _close_quietly(jira: JiraClient, ticket_key: str) -> None:
    """Close *jira*, logging (not raising) any error from ``close()``.

    A cleanup failure must neither turn an already-posted comment into a
    reported failure nor abort the remaining retry attempts.
    """
    try:
        await jira.close()
    except Exception:
        logger.warning(
            "Failed to close Jira client after posting to %s",
            ticket_key,
            exc_info=True,
        )


async def _post_with_retry(
    ticket_key: str,
    stats: StatsState,
    outcome: str,
    outcome_detail: str | None,
) -> bool:
    """Attempt to post the stats comment with exponential backoff on failure.

    A fresh JiraClient is created for every attempt and closed before the
    backoff sleep, so no client is held open while waiting. Errors raised by
    the client constructor count as a failed attempt and are retried as well.

    Args:
        ticket_key: Jira issue key.
        stats: Workflow statistics state.
        outcome: Outcome string passed to the formatter.
        outcome_detail: Optional detail string passed to the formatter.

    Returns:
        ``True`` if the comment was posted successfully, ``False`` after all
        attempts are exhausted.
    """
    # Formatting errors are not retried; they propagate to post_stats_comment.
    comment_body = format_stats_summary(stats, outcome, outcome_detail)
    backoff = _INITIAL_BACKOFF_SECONDS

    for attempt in range(1, _MAX_ATTEMPTS + 1):
        jira: JiraClient | None = None
        try:
            jira = JiraClient()
            await jira.add_comment(ticket_key, comment_body)
        except Exception as exc:
            posted = False
            logger.warning(
                "Failed to post stats comment to %s (attempt %d/%d): %s",
                ticket_key,
                attempt,
                _MAX_ATTEMPTS,
                exc,
            )
        else:
            posted = True
            logger.info(
                "Posted stats comment to %s (attempt %d/%d)",
                ticket_key,
                attempt,
                _MAX_ATTEMPTS,
            )
        finally:
            # Also runs on cancellation (timeout), which is not an Exception.
            if jira is not None:
                await _close_quietly(jira, ticket_key)

        if posted:
            return True

        if attempt < _MAX_ATTEMPTS:
            wait = min(backoff, _MAX_BACKOFF_SECONDS)
            logger.debug("Retrying in %.1fs…", wait)
            await asyncio.sleep(wait)
            backoff *= 2

    logger.error(
        "Gave up posting stats comment to %s after %d attempts",
        ticket_key,
        _MAX_ATTEMPTS,
    )
    return False
# ---------------------------------------------------------------------------
# Helpers / fixtures
# ---------------------------------------------------------------------------

TICKET_KEY = "PROJ-42"
OUTCOME = "completed"
OUTCOME_DETAIL = None


def _minimal_stats(**overrides) -> dict:
    """Return a minimal StatsState-like dict; keyword overrides replace fields."""
    base = {
        "stats_stages": {},
        "stats_pr_urls": [],
        "stats_ci_cycles": 0,
        "stats_outcome": None,
        "stats_outcome_reason": None,
        "stats_comment_posted": False,
    }
    base.update(overrides)
    return base


def _make_jira_mock(side_effect=None) -> MagicMock:
    """Return a mock JiraClient instance with add_comment and close as coroutines."""
    mock = MagicMock()
    if side_effect is not None:
        mock.add_comment = AsyncMock(side_effect=side_effect)
    else:
        mock.add_comment = AsyncMock(return_value=MagicMock())
    mock.close = AsyncMock()
    return mock


# ---------------------------------------------------------------------------
# Success scenario
# ---------------------------------------------------------------------------


class TestPostStatsCommentSuccess:
    """post_stats_comment() returns True when the comment is posted successfully."""

    @pytest.mark.asyncio
    async def test_returns_true_on_success(self):
        mock_jira = _make_jira_mock()
        with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira):
            result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        assert result is True

    @pytest.mark.asyncio
    async def test_calls_add_comment_with_correct_ticket(self):
        mock_jira = _make_jira_mock()
        with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira):
            await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        mock_jira.add_comment.assert_called_once()
        args, _ = mock_jira.add_comment.call_args
        assert args[0] == TICKET_KEY

    @pytest.mark.asyncio
    async def test_comment_body_contains_outcome(self):
        """The comment body produced by the formatter should mention 'Completed'."""
        mock_jira = _make_jira_mock()
        with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira):
            await post_stats_comment(TICKET_KEY, _minimal_stats(), "completed")

        args, _ = mock_jira.add_comment.call_args
        comment_body = args[1]
        assert "Completed" in comment_body

    @pytest.mark.asyncio
    async def test_comment_body_contains_outcome_detail(self):
        mock_jira = _make_jira_mock()
        detail = "deployment succeeded"
        with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira):
            await post_stats_comment(TICKET_KEY, _minimal_stats(), "blocked", detail)

        args, _ = mock_jira.add_comment.call_args
        comment_body = args[1]
        assert detail in comment_body

    @pytest.mark.asyncio
    async def test_jira_client_closed_on_success(self):
        mock_jira = _make_jira_mock()
        with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira):
            await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        mock_jira.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_only_one_attempt_on_success(self):
        mock_jira = _make_jira_mock()
        with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira):
            await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        assert mock_jira.add_comment.call_count == 1


# ---------------------------------------------------------------------------
# Jira API failure scenarios
# ---------------------------------------------------------------------------


class TestPostStatsCommentApiFailure:
    """post_stats_comment() is non-blocking: logs errors and returns False."""

    @pytest.mark.asyncio
    async def test_returns_false_on_persistent_failure(self):
        mock_jira = _make_jira_mock(side_effect=Exception("API down"))
        with (
            patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira),
            patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock),
        ):
            result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        assert result is False

    @pytest.mark.asyncio
    async def test_does_not_raise_on_api_error(self):
        """post_stats_comment must never propagate exceptions to callers."""
        mock_jira = _make_jira_mock(side_effect=RuntimeError("connection refused"))
        with (
            patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira),
            patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock),
        ):
            # Should not raise
            result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        assert result is False

    @pytest.mark.asyncio
    async def test_jira_client_closed_on_failure(self):
        """JiraClient.close() must be called even when add_comment raises."""
        mock_jira = _make_jira_mock(side_effect=Exception("API down"))
        with (
            patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira),
            patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock),
        ):
            await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        # close() is called once per attempt
        assert mock_jira.close.call_count == _MAX_ATTEMPTS

    @pytest.mark.asyncio
    async def test_http_status_error_returns_false(self):
        import httpx

        mock_request = MagicMock(spec=httpx.Request)
        mock_response = MagicMock(spec=httpx.Response)
        mock_response.status_code = 500
        http_error = httpx.HTTPStatusError(
            "Internal Server Error", request=mock_request, response=mock_response
        )

        mock_jira = _make_jira_mock(side_effect=http_error)
        with (
            patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira),
            patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock),
        ):
            result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        assert result is False


# ---------------------------------------------------------------------------
# Retry logic
# ---------------------------------------------------------------------------


class TestRetryLogic:
    """Verify exponential backoff and retry behaviour."""

    @pytest.mark.asyncio
    async def test_retries_up_to_max_attempts_on_failure(self):
        mock_jira = _make_jira_mock(side_effect=Exception("transient"))
        with (
            patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira),
            patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock),
        ):
            await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        assert mock_jira.add_comment.call_count == _MAX_ATTEMPTS

    @pytest.mark.asyncio
    async def test_succeeds_on_second_attempt(self):
        """Returns True when the first attempt fails but the second succeeds."""
        mock_jira = MagicMock()
        mock_jira.add_comment = AsyncMock(side_effect=[Exception("transient"), MagicMock()])
        mock_jira.close = AsyncMock()

        with (
            patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira),
            patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock),
        ):
            result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        assert result is True
        assert mock_jira.add_comment.call_count == 2

    @pytest.mark.asyncio
    async def test_exponential_backoff_sleep_calls(self):
        """sleep() is called between retries with exponentially increasing delays."""
        mock_jira = _make_jira_mock(side_effect=Exception("transient"))
        mock_sleep = AsyncMock()

        with (
            patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira),
            patch("forge.workflow.stats.poster.asyncio.sleep", mock_sleep),
        ):
            await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        # With _MAX_ATTEMPTS=3 there are 2 sleeps (after attempt 1 and 2)
        expected_sleep_count = _MAX_ATTEMPTS - 1
        assert mock_sleep.call_count == expected_sleep_count

        # Verify delays grow (first < second for default backoff)
        if expected_sleep_count >= 2:
            delays = [c.args[0] for c in mock_sleep.call_args_list]
            assert delays[1] > delays[0], "Second backoff should be larger than first"

    @pytest.mark.asyncio
    async def test_initial_backoff_value(self):
        """First retry uses _INITIAL_BACKOFF_SECONDS as the wait duration."""
        mock_jira = _make_jira_mock(
            side_effect=[Exception("fail"), Exception("fail"), Exception("fail")]
        )
        mock_sleep = AsyncMock()

        with (
            patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira),
            patch("forge.workflow.stats.poster.asyncio.sleep", mock_sleep),
        ):
            await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        first_delay = mock_sleep.call_args_list[0].args[0]
        assert first_delay == _INITIAL_BACKOFF_SECONDS

    @pytest.mark.asyncio
    async def test_jira_client_instantiated_per_attempt(self):
        """A fresh JiraClient is created for each attempt."""
        mock_jira = _make_jira_mock(side_effect=Exception("transient"))
        mock_cls = MagicMock(return_value=mock_jira)

        with (
            patch("forge.workflow.stats.poster.JiraClient", mock_cls),
            patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock),
        ):
            await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        assert mock_cls.call_count == _MAX_ATTEMPTS

    @pytest.mark.asyncio
    async def test_no_sleep_after_last_attempt(self):
        """No sleep is issued after the final (exhausted) attempt."""
        mock_jira = _make_jira_mock(side_effect=Exception("transient"))
        mock_sleep = AsyncMock()

        with (
            patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira),
            patch("forge.workflow.stats.poster.asyncio.sleep", mock_sleep),
        ):
            await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        # sleeps = attempts - 1
        assert mock_sleep.call_count == _MAX_ATTEMPTS - 1


# ---------------------------------------------------------------------------
# Timeout scenario
# ---------------------------------------------------------------------------


class TestTimeoutHandling:
    """post_stats_comment() respects the 5-minute SLA timeout."""

    @pytest.mark.asyncio
    async def test_returns_false_on_timeout(self):
        async def slow_add_comment(*_args, **_kwargs):
            await asyncio.sleep(999)

        mock_jira = MagicMock()
        mock_jira.add_comment = slow_add_comment
        mock_jira.close = AsyncMock()

        with (
            patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira),
            patch(
                "forge.workflow.stats.poster._OPERATION_TIMEOUT_SECONDS",
                0.05,  # Use a very short timeout for the test
            ),
        ):
            result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        assert result is False

    @pytest.mark.asyncio
    async def test_does_not_raise_on_timeout(self):
        """TimeoutError must be swallowed and False returned."""

        async def slow_add_comment(*_args, **_kwargs):
            await asyncio.sleep(999)

        mock_jira = MagicMock()
        mock_jira.add_comment = slow_add_comment
        mock_jira.close = AsyncMock()

        with (
            patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira),
            patch(
                "forge.workflow.stats.poster._OPERATION_TIMEOUT_SECONDS",
                0.05,
            ),
        ):
            # Should not raise TimeoutError
            result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

        assert result is False


# ---------------------------------------------------------------------------
# Comment content
# ---------------------------------------------------------------------------


class TestCommentContent:
    """Verify the formatted comment body is constructed from stats correctly."""

    @pytest.mark.asyncio
    async def test_comment_includes_workflow_statistics_header(self):
        mock_jira = _make_jira_mock()
        with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira):
            await post_stats_comment(TICKET_KEY, _minimal_stats(), "completed")

        args, _ = mock_jira.add_comment.call_args
        assert "Workflow Statistics" in args[1]

    @pytest.mark.asyncio
    async def test_comment_includes_ci_cycles(self):
        stats = _minimal_stats(stats_ci_cycles=3)
        mock_jira = _make_jira_mock()
        with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira):
            await post_stats_comment(TICKET_KEY, stats, "completed")

        args, _ = mock_jira.add_comment.call_args
        # Match the whole CI line: a bare "3" is always present in the
        # "h3. Workflow Statistics" header and would make this test vacuous.
        assert "*CI Cycles:* 3" in args[1]

    @pytest.mark.asyncio
    async def test_comment_failed_outcome_with_detail(self):
        mock_jira = _make_jira_mock()
        with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira):
            await post_stats_comment(TICKET_KEY, _minimal_stats(), "failed", "disk full")

        args, _ = mock_jira.add_comment.call_args
        body = args[1]
        assert "Failed" in body
        assert "disk full" in body

    @pytest.mark.asyncio
    async def test_format_stats_summary_called_with_correct_args(self):
        """Ensure the formatter is invoked with the right stats, outcome, and detail."""
        # Imported locally so the real formatter can back the spy below.
        from forge.workflow.stats.formatter import format_stats_summary

        mock_jira = _make_jira_mock()
        stats = _minimal_stats(stats_ci_cycles=1)
        detail = "some detail"

        with (
            patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira),
            patch(
                "forge.workflow.stats.poster.format_stats_summary",
                wraps=format_stats_summary,
            ) as mock_fmt,
        ):
            await post_stats_comment(TICKET_KEY, stats, "blocked", detail)

        mock_fmt.assert_called_once_with(stats, "blocked", detail)
Implement idempotency guard for stats comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Detailed description: - Created src/forge/workflow/stats/idempotency.py with: - has_stats_been_posted(ticket_key, run_id) async function — checks Redis for an existing idempotency marker (returns True if duplicate) - mark_stats_posted(ticket_key, run_id) async function — stores a marker in Redis with a 7-day TTL (604 800 seconds) - build_run_marker(run_id) — builds the hidden HTML comment to embed in the comment body (<!-- forge:stats:<run_id> -->) - _make_key(ticket_key, run_id) — constructs the Redis key in the format forge:stats:posted:<ticket_key>:<run_id> - STATS_IDEMPOTENCY_TTL_SECONDS = 604 800 (7 days) constant - Added workflow_run_id: str field to StatsState TypedDict to carry the unique run identifier through workflow state - Updated create_initial_feature_state() and create_initial_bug_state() to generate a UUID4 workflow_run_id at workflow initialization - Integrated idempotency guard into post_stats_comment(): - Pre-check: skips posting and returns True if already posted for run_id - Post-mark: writes marker to Redis after a successful post - Failure resilience: Redis errors do not block posting (log + continue) - run_id is resolved from the explicit arg or stats['workflow_run_id'] - Updated _post_with_retry() to accept run_id and append the HTML marker to the comment body when run_id is present - Created tests/unit/workflow/stats/test_idempotency.py — 32 unit tests with mocked Redis covering all functions and edge cases - Created tests/unit/workflow/stats/test_stats_idempotency_integration.py — 5 integration tests demonstrating end-to-end duplicate prevention using an in-memory FakeRedis stub Closes: AISOS-1896 --- src/forge/workflow/bug/state.py | 2 + src/forge/workflow/feature/state.py | 2 + src/forge/workflow/stats/__init__.py | 4 + src/forge/workflow/stats/idempotency.py | 135 +++++ src/forge/workflow/stats/poster.py | 77 ++- 
tests/unit/workflow/stats/test_idempotency.py | 468 ++++++++++++++++++ .../test_stats_idempotency_integration.py | 196 ++++++++ 7 files changed, 881 insertions(+), 3 deletions(-) create mode 100644 src/forge/workflow/stats/idempotency.py create mode 100644 tests/unit/workflow/stats/test_idempotency.py create mode 100644 tests/unit/workflow/stats/test_stats_idempotency_integration.py diff --git a/src/forge/workflow/bug/state.py b/src/forge/workflow/bug/state.py index a8e5f81a..6406024f 100644 --- a/src/forge/workflow/bug/state.py +++ b/src/forge/workflow/bug/state.py @@ -1,5 +1,6 @@ """Bug workflow state definition.""" +import uuid from datetime import datetime from typing import Any @@ -148,6 +149,7 @@ def create_initial_bug_state(ticket_key: str, **kwargs: Any) -> BugState: "stats_outcome": None, "stats_outcome_reason": None, "stats_comment_posted": False, + "workflow_run_id": str(uuid.uuid4()), } # Merge with kwargs, letting kwargs override defaults diff --git a/src/forge/workflow/feature/state.py b/src/forge/workflow/feature/state.py index dbaae49d..09522905 100644 --- a/src/forge/workflow/feature/state.py +++ b/src/forge/workflow/feature/state.py @@ -1,5 +1,6 @@ """Feature workflow state definition.""" +import uuid from datetime import datetime from typing import Any @@ -135,6 +136,7 @@ def create_initial_feature_state(ticket_key: str, **kwargs: Any) -> FeatureState "stats_outcome": None, "stats_outcome_reason": None, "stats_comment_posted": False, + "workflow_run_id": str(uuid.uuid4()), } # Merge with kwargs, letting kwargs override defaults diff --git a/src/forge/workflow/stats/__init__.py b/src/forge/workflow/stats/__init__.py index b72e348c..b648af1e 100644 --- a/src/forge/workflow/stats/__init__.py +++ b/src/forge/workflow/stats/__init__.py @@ -114,6 +114,9 @@ class StatsState(TypedDict, total=False): the blocking reason or error message), or None when not applicable. 
stats_comment_posted: True once the summary statistics comment has been posted to the Jira ticket (prevents double-posting on retries). + workflow_run_id: A unique identifier for this specific workflow run + (UUID4 string). Used as the idempotency key when posting the stats + comment to prevent duplicate posts across retries or re-invocations. """ stats_stages: dict[str, StageStats] @@ -122,3 +125,4 @@ class StatsState(TypedDict, total=False): stats_outcome: str | None stats_outcome_reason: str | None stats_comment_posted: bool + workflow_run_id: str diff --git a/src/forge/workflow/stats/idempotency.py b/src/forge/workflow/stats/idempotency.py new file mode 100644 index 00000000..0bc5264f --- /dev/null +++ b/src/forge/workflow/stats/idempotency.py @@ -0,0 +1,135 @@ +"""Idempotency guard for stats comment posting. + +Prevents duplicate stats comments from being posted to the same Jira ticket +for the same workflow run. Markers are stored in Redis with a 7-day TTL, +which is more than sufficient for any workflow to complete. + +Usage:: + + from forge.workflow.stats.idempotency import has_stats_been_posted, mark_stats_posted + + if not await has_stats_been_posted(ticket_key, run_id): + # … post comment … + await mark_stats_posted(ticket_key, run_id) +""" + +import logging + +import redis.asyncio as redis + +from forge.orchestrator.checkpointer import get_redis_client + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +#: Redis key prefix for stats-posted idempotency markers. +_KEY_PREFIX = "forge:stats:posted:" + +#: Time-to-live for idempotency markers (7 days in seconds). 
+STATS_IDEMPOTENCY_TTL_SECONDS = 7 * 24 * 60 * 60 # 604 800 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_key(ticket_key: str, run_id: str) -> str: + """Return the Redis key for a given ticket / run combination. + + Args: + ticket_key: The Jira issue key (e.g. ``"PROJ-123"``). + run_id: The unique workflow run identifier (UUID4 string). + + Returns: + Redis key string in the form ``forge:stats:posted::``. + """ + return f"{_KEY_PREFIX}{ticket_key}:{run_id}" + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +async def has_stats_been_posted( + ticket_key: str, + run_id: str, + *, + redis_client: redis.Redis | None = None, +) -> bool: + """Check whether a stats comment has already been posted for this run. + + Args: + ticket_key: The Jira issue key (e.g. ``"PROJ-123"``). + run_id: The unique workflow run identifier stored in + ``StatsState.workflow_run_id``. + redis_client: Optional Redis client to use. A shared client is + obtained via :func:`~forge.orchestrator.checkpointer.get_redis_client` + when not provided. + + Returns: + ``True`` if the marker exists in Redis (comment already posted), + ``False`` otherwise. + """ + client = redis_client if redis_client is not None else await get_redis_client() + key = _make_key(ticket_key, run_id) + exists = await client.exists(key) + posted = bool(exists) + if posted: + logger.debug( + "Stats comment already posted for ticket=%s run_id=%s (key=%s)", + ticket_key, + run_id, + key, + ) + return posted + + +async def mark_stats_posted( + ticket_key: str, + run_id: str, + *, + redis_client: redis.Redis | None = None, +) -> None: + """Record that a stats comment has been posted for this run. 
+ + Stores a marker in Redis with a 7-day TTL so that subsequent calls to + :func:`has_stats_been_posted` return ``True`` for the same combination. + + Args: + ticket_key: The Jira issue key (e.g. ``"PROJ-123"``). + run_id: The unique workflow run identifier stored in + ``StatsState.workflow_run_id``. + redis_client: Optional Redis client to use. A shared client is + obtained via :func:`~forge.orchestrator.checkpointer.get_redis_client` + when not provided. + """ + client = redis_client if redis_client is not None else await get_redis_client() + key = _make_key(ticket_key, run_id) + await client.setex(key, STATS_IDEMPOTENCY_TTL_SECONDS, "1") + logger.debug( + "Marked stats comment as posted for ticket=%s run_id=%s (TTL=%ds)", + ticket_key, + run_id, + STATS_IDEMPOTENCY_TTL_SECONDS, + ) + + +def build_run_marker(run_id: str) -> str: + """Return the hidden HTML comment marker to embed in the posted comment. + + Including this marker in the Jira comment body allows independent + verification that a comment was posted for a specific run — useful + for debugging and for future tooling that inspects comment bodies. + + Args: + run_id: The unique workflow run identifier. + + Returns: + HTML comment string of the form ````. + """ + return f"" diff --git a/src/forge/workflow/stats/poster.py b/src/forge/workflow/stats/poster.py index 7fb89144..ea1c4b8b 100644 --- a/src/forge/workflow/stats/poster.py +++ b/src/forge/workflow/stats/poster.py @@ -3,6 +3,15 @@ This module provides a non-blocking async function that formats and posts workflow statistics as a comment to the associated Jira ticket at the end of a workflow run. + +Idempotency +----------- +``post_stats_comment`` checks Redis before posting and skips the comment if +one has already been recorded for the given ``run_id``. After a successful +post the marker is written to Redis with a 7-day TTL via +:func:`~forge.workflow.stats.idempotency.mark_stats_posted`. 
A hidden HTML +comment (````) is also embedded in the comment +body for independent verification. """ import asyncio @@ -11,6 +20,11 @@ from forge.integrations.jira.client import JiraClient from forge.workflow.stats import StatsState from forge.workflow.stats.formatter import format_stats_summary +from forge.workflow.stats.idempotency import ( + build_run_marker, + has_stats_been_posted, + mark_stats_posted, +) logger = logging.getLogger(__name__) @@ -41,6 +55,7 @@ async def post_stats_comment( stats: StatsState, outcome: str, outcome_detail: str | None = None, + run_id: str | None = None, ) -> bool: """Post a formatted stats summary comment to a Jira ticket. @@ -49,6 +64,12 @@ async def post_stats_comment( backoff and retries up to :data:`_MAX_ATTEMPTS` times before giving up. The entire operation is bounded by a 5-minute timeout. + **Idempotency**: when *run_id* is provided (or can be read from + ``stats["workflow_run_id"]``), the function checks Redis before posting + and returns ``True`` immediately if the comment has already been posted for + this run. A hidden HTML comment is embedded in the body and a Redis + marker is written after a successful post. + This function is *non-blocking on failure*: any exception is caught, logged, and ``False`` is returned so that callers are not disrupted. @@ -58,13 +79,38 @@ async def post_stats_comment( outcome: Outcome category — one of ``"completed"``, ``"blocked"``, or ``"failed"`` (matched case-insensitively by the formatter). outcome_detail: Optional elaboration on the outcome. + run_id: Unique workflow run identifier for idempotency. Falls back to + ``stats.get("workflow_run_id")`` when not given explicitly. Returns: - ``True`` if the comment was successfully posted, ``False`` otherwise. + ``True`` if the comment was successfully posted (or was already + posted for this run), ``False`` otherwise. """ + # Resolve the run identifier from the explicit argument or from state. 
+ effective_run_id: str | None = run_id or stats.get("workflow_run_id") # type: ignore[call-overload] + + # --- Idempotency pre-check ------------------------------------------- + if effective_run_id: + try: + if await has_stats_been_posted(ticket_key, effective_run_id): + logger.info( + "Stats comment already posted for ticket=%s run_id=%s — skipping", + ticket_key, + effective_run_id, + ) + return True + except Exception: + # Redis check failures must not block posting. + logger.warning( + "Idempotency pre-check failed for ticket=%s run_id=%s; proceeding with post", + ticket_key, + effective_run_id, + exc_info=True, + ) + try: - return await asyncio.wait_for( - _post_with_retry(ticket_key, stats, outcome, outcome_detail), + posted = await asyncio.wait_for( + _post_with_retry(ticket_key, stats, outcome, outcome_detail, effective_run_id), timeout=_OPERATION_TIMEOUT_SECONDS, ) except TimeoutError: @@ -82,6 +128,22 @@ async def post_stats_comment( ) return False + # --- Idempotency post-mark ------------------------------------------- + if posted and effective_run_id: + try: + await mark_stats_posted(ticket_key, effective_run_id) + except Exception: + # Marker write failures are non-fatal — the comment is already + # posted; we just risk a harmless duplicate on the next retry. + logger.warning( + "Failed to write idempotency marker for ticket=%s run_id=%s", + ticket_key, + effective_run_id, + exc_info=True, + ) + + return posted + # --------------------------------------------------------------------------- # Internal helpers @@ -93,6 +155,7 @@ async def _post_with_retry( stats: StatsState, outcome: str, outcome_detail: str | None, + run_id: str | None = None, ) -> bool: """Attempt to post the stats comment with exponential backoff on failure. @@ -101,12 +164,20 @@ async def _post_with_retry( stats: Workflow statistics state. outcome: Outcome string passed to the formatter. outcome_detail: Optional detail string passed to the formatter. 
+ run_id: Unique workflow run identifier. When provided, a hidden HTML + marker is appended to the comment body for verification. Returns: ``True`` if the comment was posted successfully, ``False`` after all attempts are exhausted. """ comment_body = format_stats_summary(stats, outcome, outcome_detail) + + # Append the idempotency marker so readers can verify which run produced + # this comment without querying Redis. + if run_id: + comment_body = f"{comment_body}\n{build_run_marker(run_id)}" + backoff = _INITIAL_BACKOFF_SECONDS for attempt in range(1, _MAX_ATTEMPTS + 1): diff --git a/tests/unit/workflow/stats/test_idempotency.py b/tests/unit/workflow/stats/test_idempotency.py new file mode 100644 index 00000000..1f5d4072 --- /dev/null +++ b/tests/unit/workflow/stats/test_idempotency.py @@ -0,0 +1,468 @@ +"""Unit tests for forge.workflow.stats.idempotency. + +Tests verify: +- has_stats_been_posted() returns False when key does not exist in Redis +- has_stats_been_posted() returns True when key exists in Redis +- mark_stats_posted() stores key with 7-day TTL via setex +- build_run_marker() returns the correct HTML comment string +- Redis key format includes both ticket_key and run_id +- Redis pre-check failures in post_stats_comment are non-fatal +- Idempotency integration: post_stats_comment skips duplicate posts +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.workflow.stats.idempotency import ( + _KEY_PREFIX, + STATS_IDEMPOTENCY_TTL_SECONDS, + _make_key, + build_run_marker, + has_stats_been_posted, + mark_stats_posted, +) + +# --------------------------------------------------------------------------- +# Constants for tests +# --------------------------------------------------------------------------- + +TICKET_KEY = "PROJ-42" +RUN_ID = "550e8400-e29b-41d4-a716-446655440000" + + +# --------------------------------------------------------------------------- +# _make_key +# 
--------------------------------------------------------------------------- + + +class TestMakeKey: + """Internal key construction helper.""" + + def test_includes_prefix(self): + key = _make_key(TICKET_KEY, RUN_ID) + assert key.startswith(_KEY_PREFIX) + + def test_includes_ticket_key(self): + key = _make_key(TICKET_KEY, RUN_ID) + assert TICKET_KEY in key + + def test_includes_run_id(self): + key = _make_key(TICKET_KEY, RUN_ID) + assert RUN_ID in key + + def test_format(self): + key = _make_key("ABC-1", "run-xyz") + assert key == f"{_KEY_PREFIX}ABC-1:run-xyz" + + def test_different_tickets_produce_different_keys(self): + key1 = _make_key("PROJ-1", RUN_ID) + key2 = _make_key("PROJ-2", RUN_ID) + assert key1 != key2 + + def test_different_run_ids_produce_different_keys(self): + key1 = _make_key(TICKET_KEY, "run-1") + key2 = _make_key(TICKET_KEY, "run-2") + assert key1 != key2 + + +# --------------------------------------------------------------------------- +# build_run_marker +# --------------------------------------------------------------------------- + + +class TestBuildRunMarker: + """HTML comment marker for embedding in comment body.""" + + def test_returns_html_comment(self): + marker = build_run_marker(RUN_ID) + assert marker.startswith("") + + def test_includes_run_id(self): + marker = build_run_marker(RUN_ID) + assert RUN_ID in marker + + def test_contains_forge_stats_prefix(self): + marker = build_run_marker(RUN_ID) + assert "forge:stats:" in marker + + def test_format(self): + marker = build_run_marker("abc-123") + assert marker == "" + + def test_different_run_ids_produce_different_markers(self): + assert build_run_marker("run-1") != build_run_marker("run-2") + + +# --------------------------------------------------------------------------- +# TTL constant +# --------------------------------------------------------------------------- + + +class TestTtlConstant: + """Verify the 7-day TTL value.""" + + def test_seven_days_in_seconds(self): + assert 
STATS_IDEMPOTENCY_TTL_SECONDS == 7 * 24 * 60 * 60 + + def test_is_integer(self): + assert isinstance(STATS_IDEMPOTENCY_TTL_SECONDS, int) + + +# --------------------------------------------------------------------------- +# has_stats_been_posted +# --------------------------------------------------------------------------- + + +class TestHasStatsBeenPosted: + """has_stats_been_posted() checks Redis for the marker key.""" + + @pytest.mark.asyncio + async def test_returns_false_when_key_absent(self): + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + + result = await has_stats_been_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + assert result is False + + @pytest.mark.asyncio + async def test_returns_true_when_key_present(self): + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=1) + + result = await has_stats_been_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + assert result is True + + @pytest.mark.asyncio + async def test_calls_exists_with_correct_key(self): + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + + await has_stats_been_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + expected_key = _make_key(TICKET_KEY, RUN_ID) + mock_redis.exists.assert_called_once_with(expected_key) + + @pytest.mark.asyncio + async def test_uses_shared_client_when_none_provided(self): + """When redis_client is None, get_redis_client() is called.""" + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + + with patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ): + result = await has_stats_been_posted(TICKET_KEY, RUN_ID) + + assert result is False + mock_redis.exists.assert_called_once() + + @pytest.mark.asyncio + async def test_truthy_redis_value_returns_true(self): + """Any non-zero integer from exists() is treated as True.""" + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=2) + + result = 
await has_stats_been_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + assert result is True + + +# --------------------------------------------------------------------------- +# mark_stats_posted +# --------------------------------------------------------------------------- + + +class TestMarkStatsPosted: + """mark_stats_posted() writes the marker key with correct TTL.""" + + @pytest.mark.asyncio + async def test_calls_setex(self): + mock_redis = AsyncMock() + mock_redis.setex = AsyncMock() + + await mark_stats_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + mock_redis.setex.assert_called_once() + + @pytest.mark.asyncio + async def test_setex_uses_correct_key(self): + mock_redis = AsyncMock() + mock_redis.setex = AsyncMock() + + await mark_stats_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + call_args = mock_redis.setex.call_args + key = call_args.args[0] + assert key == _make_key(TICKET_KEY, RUN_ID) + + @pytest.mark.asyncio + async def test_setex_uses_correct_ttl(self): + mock_redis = AsyncMock() + mock_redis.setex = AsyncMock() + + await mark_stats_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + call_args = mock_redis.setex.call_args + ttl = call_args.args[1] + assert ttl == STATS_IDEMPOTENCY_TTL_SECONDS + + @pytest.mark.asyncio + async def test_setex_stores_truthy_value(self): + mock_redis = AsyncMock() + mock_redis.setex = AsyncMock() + + await mark_stats_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + call_args = mock_redis.setex.call_args + value = call_args.args[2] + assert value # any truthy value is fine + + @pytest.mark.asyncio + async def test_uses_shared_client_when_none_provided(self): + mock_redis = AsyncMock() + mock_redis.setex = AsyncMock() + + with patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ): + await mark_stats_posted(TICKET_KEY, RUN_ID) + + mock_redis.setex.assert_called_once() + + @pytest.mark.asyncio + async def test_returns_none(self): + 
mock_redis = AsyncMock() + mock_redis.setex = AsyncMock() + + result = await mark_stats_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + assert result is None + + +# --------------------------------------------------------------------------- +# Integration with post_stats_comment +# --------------------------------------------------------------------------- + + +class TestPostStatsCommentIdempotency: + """post_stats_comment() integrates idempotency guard correctly.""" + + def _minimal_stats(self, **overrides) -> dict: + base = { + "stats_stages": {}, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "stats_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, + "workflow_run_id": RUN_ID, + } + base.update(overrides) + return base + + def _make_jira_mock(self, side_effect=None) -> MagicMock: + mock = MagicMock() + if side_effect is not None: + mock.add_comment = AsyncMock(side_effect=side_effect) + else: + mock.add_comment = AsyncMock(return_value=MagicMock()) + mock.close = AsyncMock() + return mock + + @pytest.mark.asyncio + async def test_skips_posting_when_already_posted(self): + """Returns True immediately without calling Jira when Redis marker exists.""" + from forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock() + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=1) # already posted + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + result = await post_stats_comment( + TICKET_KEY, self._minimal_stats(), "completed", run_id=RUN_ID + ) + + assert result is True + mock_jira.add_comment.assert_not_called() + + @pytest.mark.asyncio + async def test_posts_and_marks_when_not_yet_posted(self): + """Posts the comment and writes the marker when Redis key is absent.""" + from forge.workflow.stats.poster import post_stats_comment 
+ + mock_jira = self._make_jira_mock() + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) # not yet posted + mock_redis.setex = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + result = await post_stats_comment( + TICKET_KEY, self._minimal_stats(), "completed", run_id=RUN_ID + ) + + assert result is True + mock_jira.add_comment.assert_called_once() + mock_redis.setex.assert_called_once() + + @pytest.mark.asyncio + async def test_comment_body_includes_run_marker(self): + """The posted comment body contains the hidden HTML marker.""" + from forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock() + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + mock_redis.setex = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + await post_stats_comment(TICKET_KEY, self._minimal_stats(), "completed", run_id=RUN_ID) + + args, _ = mock_jira.add_comment.call_args + comment_body = args[1] + assert f"" in comment_body + + @pytest.mark.asyncio + async def test_uses_workflow_run_id_from_stats_when_no_explicit_run_id(self): + """Falls back to stats['workflow_run_id'] when run_id not passed explicitly.""" + from forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock() + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + mock_redis.setex = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + # Note: no explicit run_id — should pick up workflow_run_id from stats 
+ result = await post_stats_comment(TICKET_KEY, self._minimal_stats(), "completed") + + assert result is True + args, _ = mock_jira.add_comment.call_args + comment_body = args[1] + assert f"" in comment_body + + @pytest.mark.asyncio + async def test_redis_check_failure_does_not_block_post(self): + """If the Redis pre-check raises, the comment is still attempted.""" + from forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock() + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(side_effect=ConnectionError("redis down")) + mock_redis.setex = AsyncMock(side_effect=ConnectionError("redis down")) + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + result = await post_stats_comment( + TICKET_KEY, self._minimal_stats(), "completed", run_id=RUN_ID + ) + + # Comment should still be posted even if Redis is unavailable + assert result is True + mock_jira.add_comment.assert_called_once() + + @pytest.mark.asyncio + async def test_marker_write_failure_does_not_affect_return_value(self): + """If the Redis marker write fails after a successful post, True is still returned.""" + from forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock() + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + mock_redis.setex = AsyncMock(side_effect=ConnectionError("redis down")) + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + result = await post_stats_comment( + TICKET_KEY, self._minimal_stats(), "completed", run_id=RUN_ID + ) + + assert result is True + + @pytest.mark.asyncio + async def test_no_marker_when_run_id_absent(self): + """When no run_id is available, the comment body has no HTML 
marker.""" + from forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock() + # Stats without workflow_run_id + stats = { + "stats_stages": {}, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "stats_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, + } + + with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira): + await post_stats_comment(TICKET_KEY, stats, "completed") + + args, _ = mock_jira.add_comment.call_args + comment_body = args[1] + assert "forge:stats:" not in comment_body + + @pytest.mark.asyncio + async def test_does_not_mark_when_post_fails(self): + """Redis marker is NOT written if the Jira post fails.""" + from forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock(side_effect=Exception("API down")) + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + mock_redis.setex = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + result = await post_stats_comment( + TICKET_KEY, self._minimal_stats(), "completed", run_id=RUN_ID + ) + + assert result is False + mock_redis.setex.assert_not_called() diff --git a/tests/unit/workflow/stats/test_stats_idempotency_integration.py b/tests/unit/workflow/stats/test_stats_idempotency_integration.py new file mode 100644 index 00000000..0f84e634 --- /dev/null +++ b/tests/unit/workflow/stats/test_stats_idempotency_integration.py @@ -0,0 +1,196 @@ +"""Integration test demonstrating stats comment duplicate prevention. + +This test shows the full idempotency flow end-to-end: + +1. First call to post_stats_comment() — Redis has no marker → posts comment + and writes the marker. +2. 
Second call to post_stats_comment() with the same run_id — Redis marker + present → skips posting entirely. + +The test uses an in-memory dict backed fake Redis to avoid requiring a +running Redis instance. This is an integration-level test because it +exercises the interaction between poster.py and idempotency.py rather than +testing each module in isolation. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Fake Redis implementation (in-memory dict — no real Redis required) +# --------------------------------------------------------------------------- + + +class FakeRedis: + """Minimal in-memory Redis stub supporting exists() and setex().""" + + def __init__(self): + self._store: dict[str, str] = {} + + async def exists(self, key: str) -> int: + return 1 if key in self._store else 0 + + async def setex(self, key: str, _ttl: int, value: str) -> None: + self._store[key] = value + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +TICKET_KEY = "INTTEST-99" +RUN_ID = "aabbccdd-1234-5678-abcd-000000000001" +OUTCOME = "completed" + + +def _minimal_stats(run_id: str = RUN_ID) -> dict: + return { + "stats_stages": {}, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "stats_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, + "workflow_run_id": run_id, + } + + +def _make_jira_mock() -> MagicMock: + mock = MagicMock() + mock.add_comment = AsyncMock(return_value=MagicMock()) + mock.close = AsyncMock() + return mock + + +# --------------------------------------------------------------------------- +# Integration tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_first_call_posts_comment_and_marks_redis(): + """First invocation posts the 
comment and records the marker in Redis.""" + from forge.workflow.stats.poster import post_stats_comment + + fake_redis = FakeRedis() + mock_jira = _make_jira_mock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=fake_redis), + ), + ): + result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + assert result is True + mock_jira.add_comment.assert_called_once() + + # Marker must now be present in our fake Redis (key format: forge:stats:posted:<ticket_key>:<run_id>) + assert await fake_redis.exists(f"forge:stats:posted:{TICKET_KEY}:{RUN_ID}") == 1 + + +@pytest.mark.asyncio +async def test_second_call_skips_posting(): + """Second invocation with the same run_id skips Jira entirely.""" + from forge.workflow.stats.poster import post_stats_comment + + fake_redis = FakeRedis() + mock_jira = _make_jira_mock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=fake_redis), + ), + ): + # First call — should post + result_first = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + # Second call — should skip + result_second = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + assert result_first is True + assert result_second is True # still "successful" — just a no-op + # Jira was only called once despite two invocations + assert mock_jira.add_comment.call_count == 1 + + +@pytest.mark.asyncio +async def test_different_run_ids_each_post_independently(): + """Two calls with different run_ids each result in a Jira post.""" + from forge.workflow.stats.poster import post_stats_comment + + fake_redis = FakeRedis() + mock_jira = _make_jira_mock() + run_id_a = "aaaaaaaa-0000-0000-0000-000000000001" + run_id_b = "bbbbbbbb-0000-0000-0000-000000000002" + + with ( +
patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=fake_redis), + ), + ): + result_a = await post_stats_comment(TICKET_KEY, _minimal_stats(run_id_a), OUTCOME) + result_b = await post_stats_comment(TICKET_KEY, _minimal_stats(run_id_b), OUTCOME) + + assert result_a is True + assert result_b is True + assert mock_jira.add_comment.call_count == 2 + + +@pytest.mark.asyncio +async def test_comment_body_contains_unique_marker(): + """The posted comment embeds the hidden HTML marker for the run_id.""" + from forge.workflow.stats.poster import post_stats_comment + + fake_redis = FakeRedis() + mock_jira = _make_jira_mock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=fake_redis), + ), + ): + await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + args, _ = mock_jira.add_comment.call_args + comment_body = args[1] + assert f"" in comment_body + + +@pytest.mark.asyncio +async def test_same_ticket_different_runs_are_independent(): + """Same ticket key but different run IDs behave as independent posts.""" + from forge.workflow.stats.poster import post_stats_comment + + fake_redis = FakeRedis() + mock_jira_1 = _make_jira_mock() + mock_jira_2 = _make_jira_mock() + run_id_1 = "run-11111111-0000-0000-0000-000000000001" + run_id_2 = "run-22222222-0000-0000-0000-000000000002" + + with ( + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=fake_redis), + ), + ): + # First run on the same ticket + with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira_1): + r1 = await post_stats_comment(TICKET_KEY, _minimal_stats(run_id_1), OUTCOME) + + # Second run (new run_id) on the same ticket — should also post + with patch("forge.workflow.stats.poster.JiraClient", 
return_value=mock_jira_2): + r2 = await post_stats_comment(TICKET_KEY, _minimal_stats(run_id_2), OUTCOME) + + assert r1 is True + assert r2 is True + mock_jira_1.add_comment.assert_called_once() + mock_jira_2.add_comment.assert_called_once() From 575838a5778aa9e4abbfe43f74ef19e878d00285 Mon Sep 17 00:00:00 2001 From: Forge Date: Wed, 24 Jun 2026 08:10:17 +0000 Subject: [PATCH 10/68] [AISOS-1897] Implement re-post mechanism for final stats comment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Detailed description: - Added ensure_stats_is_final_comment() async function to poster.py that guarantees the stats comment is always the last Forge comment on a ticket - Added _is_stats_comment() internal helper that detects stats comments by the embedded HTML marker (<!-- forge:stats:<run_id> -->) in the comment body - Added _STATS_BODY_MARKER constant for stats comment identification - Added jira_service_account_id setting to config.py for identifying which Jira comments were authored by the Forge service account - The function fetches all comments, filters by service account ID (if configured), checks if the most recent Forge comment is a stats comment, and re-posts if not — making it safe to call multiple times (idempotent) - Created 24 unit tests covering: stats detection, no-forge-comments case, idempotency when stats is already final, re-post logic, service account filtering, resource management, and error handling Closes: AISOS-1897 --- src/forge/config.py | 10 + src/forge/workflow/stats/poster.py | 118 +++- .../workflow/stats/test_ensure_stats_final.py | 508 ++++++++++++++++++ 3 files changed, 635 insertions(+), 1 deletion(-) create mode 100644 tests/unit/workflow/stats/test_ensure_stats_final.py diff --git a/src/forge/config.py b/src/forge/config.py index c50fbfc9..ee826a9c 100644 --- a/src/forge/config.py +++ b/src/forge/config.py @@ -58,6 +58,16 @@ def jira_domain_resolved(self) -> str: default="", description="Custom field ID for Specification
storage (optional)", ) + jira_service_account_id: str = Field( + default="", + description=( + "Jira account ID of the Forge service account used to post comments. " + "When set, only comments authored by this account are treated as Forge " + "comments when checking whether the stats comment is the final comment " + "on a ticket (see ensure_stats_is_final_comment). " + "Set via JIRA_SERVICE_ACCOUNT_ID environment variable." + ), + ) # Jira workflow configuration jira_use_labels: bool = Field( diff --git a/src/forge/workflow/stats/poster.py b/src/forge/workflow/stats/poster.py index ea1c4b8b..a3fdb261 100644 --- a/src/forge/workflow/stats/poster.py +++ b/src/forge/workflow/stats/poster.py @@ -1,6 +1,6 @@ """Stats comment posting service for Jira tickets. -This module provides a non-blocking async function that formats and posts +This module provides non-blocking async functions that format and post workflow statistics as a comment to the associated Jira ticket at the end of a workflow run. @@ -12,11 +12,19 @@ :func:`~forge.workflow.stats.idempotency.mark_stats_posted`. A hidden HTML comment (````) is also embedded in the comment body for independent verification. + +Re-Post Mechanism +----------------- +``ensure_stats_is_final_comment`` guarantees the stats comment is always the +*last* Forge comment on the ticket. It fetches all comments, identifies the +most recent one posted by the Forge service account, and re-posts the stats +summary if a non-stats comment was added after the most recent stats comment. """ import asyncio import logging +from forge.config import get_settings from forge.integrations.jira.client import JiraClient from forge.workflow.stats import StatsState from forge.workflow.stats.formatter import format_stats_summary @@ -44,6 +52,12 @@ #: Overall timeout for the entire post_stats_comment operation (5-minute SLA). _OPERATION_TIMEOUT_SECONDS = 300.0 +#: Prefix embedded in all stats comment bodies for identification. 
+#: This substring is present in every comment posted by post_stats_comment / +#: ensure_stats_is_final_comment and is used by _is_stats_comment() to +#: distinguish stats comments from other Forge comments. +_STATS_BODY_MARKER = "``) + that :func:`post_stats_comment` embeds in every comment it posts. + + Args: + body: The raw text body of a Jira comment. + + Returns: + ``True`` when the body contains the stats marker, ``False`` otherwise. + """ + return _STATS_BODY_MARKER in body diff --git a/tests/unit/workflow/stats/test_ensure_stats_final.py b/tests/unit/workflow/stats/test_ensure_stats_final.py new file mode 100644 index 00000000..7069377d --- /dev/null +++ b/tests/unit/workflow/stats/test_ensure_stats_final.py @@ -0,0 +1,508 @@ +"""Unit tests for ensure_stats_is_final_comment() in forge.workflow.stats.poster. + +Tests verify: +- No Forge comments exist → posts new stats comment +- Most recent Forge comment IS a stats comment → no re-post (returns True) +- Most recent Forge comment is NOT a stats comment → re-posts stats +- Service account ID filtering: only Forge comments are considered +- When service_account_id is empty, all comments are treated as Forge comments +- JiraClient.get_comments() failure → returns False gracefully +- JiraClient is always closed after fetching comments +- _is_stats_comment() correctly identifies stats comments by marker +""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.workflow.stats.poster import ( + _STATS_BODY_MARKER, + _is_stats_comment, + ensure_stats_is_final_comment, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + +TICKET_KEY = "PROJ-99" +OUTCOME = "completed" +SERVICE_ACCOUNT_ID = "forge-bot-123" + +# A body that looks like a stats comment (contains the marker) +STATS_BODY = f"h2. 
Workflow Stats\n...\n{_STATS_BODY_MARKER}run-abc -->" + +# A body that does NOT look like a stats comment +OTHER_BODY = "This is a regular error notification comment." + + +def _minimal_stats(**overrides) -> dict: + base = { + "stats_stages": {}, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "stats_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, + } + base.update(overrides) + return base + + +def _make_comment( + comment_id: str, + body: str, + author_id: str = SERVICE_ACCOUNT_ID, +) -> MagicMock: + """Build a mock JiraComment with the given attributes.""" + comment = MagicMock() + comment.id = comment_id + comment.body = body + comment.author_id = author_id + comment.created = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC) + return comment + + +def _make_jira_mock(comments: list) -> MagicMock: + """Return a mock JiraClient with get_comments returning *comments*.""" + mock = MagicMock() + mock.get_comments = AsyncMock(return_value=comments) + mock.add_comment = AsyncMock(return_value=MagicMock()) + mock.close = AsyncMock() + return mock + + +def _patch_service_account(account_id: str = SERVICE_ACCOUNT_ID): + """Context manager that patches get_settings to return account_id.""" + mock_settings = MagicMock() + mock_settings.jira_service_account_id = account_id + return patch("forge.workflow.stats.poster.get_settings", return_value=mock_settings) + + +# --------------------------------------------------------------------------- +# _is_stats_comment() helper +# --------------------------------------------------------------------------- + + +class TestIsStatsComment: + """Unit tests for the _is_stats_comment() detection helper.""" + + def test_returns_true_for_body_with_marker(self): + assert _is_stats_comment(STATS_BODY) is True + + def test_returns_true_for_minimal_marker(self): + assert _is_stats_comment("") is True + + def test_returns_false_for_plain_comment(self): + assert _is_stats_comment("Just a regular comment.") is False + + 
def test_returns_false_for_empty_body(self): + assert _is_stats_comment("") is False + + def test_returns_false_for_similar_but_wrong_marker(self): + # Must match the exact prefix _STATS_BODY_MARKER + assert _is_stats_comment("") is False + assert _is_stats_comment("") is False + + def test_marker_constant_starts_with_expected_prefix(self): + assert _STATS_BODY_MARKER == "