diff --git a/.env.example b/.env.example index 0782e2a4..6fa3835f 100644 --- a/.env.example +++ b/.env.example @@ -245,3 +245,21 @@ CI_FIX_MAX_RETRIES=5 CI_IGNORED_CHECKS=tide # Webhook acknowledgment timeout in seconds WEBHOOK_ACK_TIMEOUT=0.5 + +# ============================================================================= +# Stats Cost Alert Configuration +# ============================================================================= +# Enable cost alerting in workflow stats summaries. When enabled and aggregate +# token usage (input + output across all stages) exceeds the threshold, the +# stats summary will include a cost alert. +STATS_ALERT_ENABLED=true +# Total token count threshold that triggers a cost alert (default: 1,000,000). +# Applies to aggregate token usage across all workflow stages. +STATS_ALERT_THRESHOLD_TOKENS=1000000 +# Dollar cost threshold for cost alerts. When set, compares total dollar cost against +# this value instead of using the token-based threshold above. +# STATS_ALERT_THRESHOLD_COST=10.00 +# LLM pricing table as a JSON-encoded string mapping model name substrings to +# per-million-token rates (input and output in $/MTok). Longest key match wins. +# Default rates are pre-populated; override only if prices change. 
+# LLM_PRICING={"claude-opus-4":{"input":15.00,"output":75.00},"claude-sonnet-4":{"input":3.00,"output":15.00}} diff --git a/.gitignore b/.gitignore index f4906988..fb4eb685 100644 --- a/.gitignore +++ b/.gitignore @@ -42,6 +42,7 @@ ENV/ # Testing .pytest_cache/ +.mypy_cache/ .coverage htmlcov/ *.cover diff --git a/CLAUDE.md b/CLAUDE.md index e5d53bd0..5f27c2a4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -122,6 +122,8 @@ podman rm $(podman ps -a --filter name=forge- -q) | `!` | Revision request — triggers regeneration with feedback | | `?` or `@forge ask` | Question — triggers Q&A answer | | `>option N` | RCA option selection (RCA Option Gate only) | +| `/forge stats` | Post current workflow statistics as a Jira comment (read-only) | +| `/forge stats retry` | Re-post stats comment, forcing a fresh calculation | | _(no prefix)_ | Informational — workflow ignores it | ## GitHub PR Comment Commands diff --git a/containers/entrypoint.py b/containers/entrypoint.py index 29022265..77b9f82f 100644 --- a/containers/entrypoint.py +++ b/containers/entrypoint.py @@ -457,6 +457,54 @@ async def run_agent_task( else: result = await agent.ainvoke(initial_message, config=config) + # Extract and aggregate tokens from usage_metadata + try: + total_input_tokens = 0 + total_output_tokens = 0 + messages = result.get("messages", []) if isinstance(result, dict) else [] + for message in messages: + msg_type = type(message).__name__ + if msg_type in ("AIMessage", "AIMessageChunk"): + usage = getattr(message, "usage_metadata", None) + if not usage: + resp_metadata = getattr(message, "response_metadata", {}) + if isinstance(resp_metadata, dict): + usage = resp_metadata.get("token_usage") or resp_metadata.get("usage") + + if isinstance(usage, dict): + total_input_tokens += ( + usage.get("input_tokens", 0) or usage.get("prompt_tokens", 0) or 0 + ) + total_output_tokens += ( + usage.get("output_tokens", 0) or usage.get("completion_tokens", 0) or 0 + ) + elif usage is not None: + 
total_input_tokens += ( + getattr(usage, "input_tokens", 0) + or getattr(usage, "prompt_tokens", 0) + or 0 + ) + total_output_tokens += ( + getattr(usage, "output_tokens", 0) + or getattr(usage, "completion_tokens", 0) + or 0 + ) + + metrics_dir = workspace / ".forge" + metrics_dir.mkdir(parents=True, exist_ok=True) + metrics_file = metrics_dir / "metrics.json" + metrics_file.write_text( + json.dumps( + {"input_tokens": total_input_tokens, "output_tokens": total_output_tokens}, + indent=2, + ) + ) + logger.info( + f"Saved container metrics to {metrics_file}: input_tokens={total_input_tokens}, output_tokens={total_output_tokens}" + ) + except Exception as e: + logger.warning(f"Failed to record token usage inside sandbox: {e}") + # Flush Langfuse traces before exit if langfuse_enabled: try: diff --git a/docs/guide/bug-workflow.md b/docs/guide/bug-workflow.md index 253890eb..534539fb 100644 --- a/docs/guide/bug-workflow.md +++ b/docs/guide/bug-workflow.md @@ -119,6 +119,7 @@ At any approval gate, Forge classifies your comment by its prefix: - **`!` prefix** — revision request: Forge regenerates the current artifact with your feedback - **`?` prefix or `@forge ask`** — question: Forge answers and stays paused - **`>option N`** — RCA option selection (RCA Option Gate only) +- **`/forge stats`** — posts current workflow statistics as a Jira comment (read-only) - **No prefix** — informational: ignored by the workflow --- diff --git a/docs/guide/feature-workflow.md b/docs/guide/feature-workflow.md index 8abffdac..d9fb643b 100644 --- a/docs/guide/feature-workflow.md +++ b/docs/guide/feature-workflow.md @@ -185,7 +185,7 @@ Start a comment with `!` followed by your feedback. Forge regenerates the curren ``` !!! note - Comments without a recognized prefix (`!`, `?`, `@forge ask`) are treated as informational and ignored by the workflow. Only `!`-prefixed comments trigger regeneration. 
+ Comments without a recognized prefix (`!`, `?`, `@forge ask`, `/forge stats`) are treated as informational and ignored by the workflow. Only `!`-prefixed comments trigger regeneration. ## Handling Failures @@ -199,6 +199,52 @@ To retry, add the `forge:retry` label. Forge resumes from the exact node that fa !!! tip "CI retries" If CI fix attempts are exhausted, `forge:retry` resets the attempt counter for a fresh budget of retries. +## Workflow Statistics + +At the end of a workflow execution (when the ticket reaches a terminal state, including **Completed**, **Blocked**, or **Failed**), Forge aggregates execution data and automatically posts a comprehensive summary on the Jira ticket. This ensures that even when a workflow is blocked or fails, stakeholders can inspect the resource usage and performance metrics up to that point. This helps teams track efficiency, analyze execution bottlenecks, and monitor LLM token costs. + +### Summary Format + +The summary is generated as a Markdown table with the following columns: + +| Column | Description | +|---|---| +| **Stage** | The name of the pipeline stage (e.g., PRD, Spec, Epics, Tasks, Implementation, CI, Review). | +| **Iterations** | The number of attempts or iterations executed during that stage. | +| **Machine Time** | Monotonic duration of active processing by Forge during that stage (formatted as `1h 2m 3s`). | +| **Input Tokens** | Estimated number of LLM input tokens consumed during that stage. | +| **Output Tokens** | Estimated number of LLM output tokens consumed during that stage. | +| **Cost** | Calculated cost based on the stage's token consumption and LLM pricing mappings. | + +At the bottom of the table, a **Total** rollup row displays sum totals across all executed stages. + +### Cost Alerting + +If the cumulative resource consumption exceeds specified safety thresholds, a prominent warning alert is appended to the statistics summary comment. 
+ +Alert thresholds are defined globally (or can be customized in the configuration): +- **Token Threshold:** Triggers if cumulative input + output tokens exceed a specified value (default: `1,000,000` tokens). +- **Dollar Threshold:** Triggers if cumulative calculated cost exceeds a specified monetary value (default: disabled/`None`). When a dollar threshold is set, it is used instead of the token threshold. + +When triggered, a cost warning similar to the following is displayed directly below the summary table: + +```text +⚠️ WARNING: This workflow run exceeded the configured cost/token limits! +Please review the resource usage details above for potential optimizations. +``` + +--- + +## On-Demand Stats Commands + +In addition to the automatic summary posted when a workflow reaches a terminal state (Completed, Blocked, or Failed), team members can request or force-refresh stats at any time using Jira comment commands. + +| Command | Action | Description | +|---|---|---| +| `/forge stats` | Request Stats | Generates the current statistics table and posts it as a comment on the Jira ticket, reflecting metrics up to the current stage of execution. | +| `/forge stats retry` | Refresh Stats | Forces a fresh recalculation of statistics and re-posts the summary table. This ensures the stats comment remains updated as the final comment on the Jira issue. | + + ## Labels Summary See [Jira Labels](labels.md) for the complete reference. diff --git a/docs/guide/labels.md b/docs/guide/labels.md index 16d7461c..26a519e9 100644 --- a/docs/guide/labels.md +++ b/docs/guide/labels.md @@ -42,7 +42,7 @@ These labels advance the pipeline. Forge watches for label changes via Jira webh **Asking questions:** Start a comment with `?` or `@forge ask`. Forge answers without advancing or regenerating. -**Informational comments:** Comments without a recognized prefix (`!`, `?`, `@forge ask`, `>option`) are ignored by the workflow — use them for team discussion without triggering Forge. 
+**Informational comments:** Comments without a recognized prefix (`!`, `?`, `@forge ask`, `>option`, `/forge stats`) are ignored by the workflow — use them for team discussion without triggering Forge. **Handling failures:** When `forge:blocked` appears, read the Forge comment for the error. Fix the underlying issue if needed, then add `forge:retry`. diff --git a/docs/guide/weekly-reporting.md b/docs/guide/weekly-reporting.md new file mode 100644 index 00000000..ab346ca0 --- /dev/null +++ b/docs/guide/weekly-reporting.md @@ -0,0 +1,63 @@ +# Weekly Reporting System + +Forge includes an automated, weekly aggregation and reporting system that compiles and publishes metrics across all managed tickets for a specific Jira project. This documentation explains how the reporting system operates behind the scenes. + +## Quick Start + +Generate a weekly report for your project (e.g., `PROJ`) with the following command: + +```bash +forge weekly-report --project PROJ +``` + +> **Note:** The `forge weekly-report` command requires active Redis access and must be run from the Forge project directory containing `.env` to load configurations. + +## Aggregation Logic + +When you run `forge weekly-report` (or trigger it via automated schedules), the reporting system performs the following steps: + +1. **Query Active/Historical Checkpoints:** Forge scans the Redis event and state checkpoints for the specified project (`PROJECT_KEY`). It uses a key scanning pattern `checkpoint:{PROJECT_KEY}-*` to find all state checkpoints. +2. **Filter by Sliding Window:** Metrics are collected and filtered based on a sliding window of `N` days (by default, `7` days). A checkpoint falls within the reporting window if its `updated_at` timestamp or any stage `started_at`/`ended_at` timestamp is greater than or equal to the cutoff (`now - N days`). +3. 
**Aggregate Stats per Stage:** Data is aggregated across all feature and bug workflows, tracking: + - **Ticket Rollups:** Total numbers of active, completed, or blocked workflows. + - **Machine Time:** Cumulative active machine processing time (monotonic durations) across all stages. + - **LLM Token Costs:** Sum of all input and output tokens consumed, translating them into actual dollar costs based on LLM pricing mappings. + - **Feature Rollups:** Metrics aggregated per epic-linked ticket and feature. Ancestry traversal resolves the parent/grandparent Feature for each ticket in Jira up to two hops (e.g., ticket -> Epic -> Feature). Tickets without a resolved Feature are grouped under the "Unassigned" bucket. + - **Bottleneck Analysis:** Identifies the slowest stage by average duration, ranks stages by iteration count, and calculates the CI fix rate. + +## Idempotency & Ticket Publishing + +To avoid cluttering Jira with duplicate reports every week, the reporting system is designed to be completely **idempotent** when publishing to Jira via the `--create-ticket` flag. + +- **Ticket Naming Convention:** The ticket summary is formatted dynamically based on the project key and current date: + ```text + Forge Weekly Report - {PROJECT} - Week of {date} + ``` + Where `{PROJECT}` is the project key, and `{date}` is the first day of the reporting week (i.e. `today - N + 1 days`). +- **Label Identification:** The system uses the special `forge:weekly-report` and `forge:generated` labels to identify and tag report tickets. +- **Idempotency Guard:** + - When `--create-ticket` is run, Forge first searches Jira using the following JQL: + ```jql + project = "{PROJECT}" AND labels = "forge:weekly-report" AND summary ~ "Week of {date}" + ``` + - If a matching ticket is found, Forge updates that existing ticket's description with the newly compiled statistics instead of creating a new one. 
+ - If no matching ticket exists, Forge creates a new Jira Task issue, assigns the `forge:weekly-report` and `forge:generated` labels, and sets the description. + +## Stakeholder Notifications + +When using the `--notify` option alongside `--create-ticket`, Forge automatically mentions and notifies designated stakeholders. + +### Notification List Compilation + +The notification list is compiled hierarchically to allow easy overriding (highest priority first): + +1. **Jira Project Property (Highest Priority):** Forge attempts to read the `forge.weekly-report.notify` project property from Jira. This property must contain a JSON array of Jira Account IDs (e.g., `["account-id-1", "account-id-2"]`) or a comma-separated string of account IDs. +2. **Environment Variable (Global Fallback):** If no project-specific property is set, Forge falls back to the `FORGE_WEEKLY_REPORT_NOTIFY` environment variable in `.env`. This variable should contain a comma-separated list of Jira Account IDs or the keyword `"project-leads"`. The special value `"project-leads"` instructs Forge to query the per-project Jira property. +3. **No Recipients:** If neither is configured, no notifications are triggered. + +### How Notifications are Delivered + +Once the recipient account IDs are resolved: +- Forge posts a comment directly on the generated weekly report Jira ticket. +- The comment mentions each stakeholder using Jira's native `[~accountid:{id}]` mention syntax. +- This triggers email and/or Slack notifications based on the users' individual Atlassian notification preferences, ensuring visibility to project leads and management. 
diff --git a/docs/index.md b/docs/index.md index b03712b7..0cfba8dd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -22,6 +22,8 @@ graph TD - [Getting Started](getting-started.md) — Set up Forge in 10 minutes - [Feature Workflow](guide/feature-workflow.md) — How features flow through Forge +- [Weekly Reporting Guide](guide/weekly-reporting.md) — Automated project-wide metrics and notifications +- [CLI Reference](reference/cli.md) — Command-line interface documentation - [Developer Guide](developer-guide.md) — Full local development reference - [Skills System](skills/index.md) — Customize Forge for your stack - [Contributing](dev/contributing.md) — How to contribute diff --git a/docs/reference/api.md b/docs/reference/api.md index cd1983ed..817df4d2 100644 --- a/docs/reference/api.md +++ b/docs/reference/api.md @@ -44,7 +44,7 @@ Receives Jira webhook events. Validates the signature and enqueues the event for - `jira:issue_created` — triggers new workflow if `forge:managed` label is present - `jira:issue_updated` — handles label changes (approvals, retry) -- `jira:issue_commented` — handles Q&A and revision requests +- `jira:issue_commented` — handles Q&A, revision requests, and `/forge stats` commands Returns HTTP 200 immediately. Processing is asynchronous. diff --git a/docs/reference/cli.md b/docs/reference/cli.md new file mode 100644 index 00000000..b781116a --- /dev/null +++ b/docs/reference/cli.md @@ -0,0 +1,160 @@ +# CLI Reference + +Forge provides a command-line interface (CLI) to manage workflows, inspect system health, trigger manual interventions, and view statistics or generate weekly reports. + +## Stats Commands + +### `forge stats TICKET` + +Display workflow statistics and execution metrics for a specific Jira ticket. This command retrieves the recorded metrics from the Redis checkpoint and formats them for display. 
+ +#### Arguments and Flags + +| Argument/Flag | Type | Description | +|---|---|---| +| `ticket` | Positional | The Jira ticket key (e.g., `AISOS-123`). This argument is required. | +| `--json` | Flag | Output the raw statistics in JSON format instead of a formatted ASCII table. | + +#### Examples + +##### 1. Displaying Stats as an ASCII Table + +```bash +forge stats AISOS-123 +``` + +**Output:** + +```text +================================================================================ +Workflow Statistics Summary for AISOS-123 +================================================================================ +Outcome: Completed + +| Stage | Iterations | Machine Time | Tokens In | Tokens Out | +|-------|------------|--------------|-----------|------------| +| PRD | 1 | 45s | 12,500 | 4,200 | +| Spec | 1 | 1m 15s | 18,300 | 6,100 | +| Epics | 1 | 30s | 9,800 | 3,100 | +| Tasks | 1 | 25s | 8,500 | 2,800 | +| Implementation | 2 | 4m 10s | 45,000 | 12,500 | +| CI | 2 | 8m 15s | 25,000 | 4,500 | +| Review | 1 | 1m 5s | 15,200 | 4,800 | +|-------|------------|--------------|-----------|------------| +| TOTAL | | 17m 0s | 134,300 | 38,000 | +================================================================================ +``` + +##### 2. 
Exporting Stats in JSON Format + +```bash +forge stats AISOS-123 --json +``` + +**Output:** + +```json +{ + "ticket": "AISOS-123", + "outcome": "Completed", + "outcome_detail": null, + "ci_cycles": 2, + "pr_urls": [ + "https://github.com/my-org/my-repo/pull/42" + ], + "stages": { + "prd": { + "stage_name": "prd", + "iteration_count": 1, + "machine_time_seconds": 45.0, + "input_tokens": 12500, + "output_tokens": 4200 + }, + "spec": { + "stage_name": "spec", + "iteration_count": 1, + "machine_time_seconds": 75.0, + "input_tokens": 18300, + "output_tokens": 6100 + }, + "epics": { + "stage_name": "epics", + "iteration_count": 1, + "machine_time_seconds": 30.0, + "input_tokens": 9800, + "output_tokens": 3100 + }, + "tasks": { + "stage_name": "tasks", + "iteration_count": 1, + "machine_time_seconds": 25.0, + "input_tokens": 8500, + "output_tokens": 2800 + }, + "implementation": { + "stage_name": "implementation", + "iteration_count": 2, + "machine_time_seconds": 250.0, + "input_tokens": 45000, + "output_tokens": 12500 + }, + "ci": { + "stage_name": "ci", + "iteration_count": 2, + "machine_time_seconds": 495.0, + "input_tokens": 25000, + "output_tokens": 4500 + }, + "review": { + "stage_name": "review", + "iteration_count": 1, + "machine_time_seconds": 65.0, + "input_tokens": 15200, + "output_tokens": 4800 + } + } +} +``` + +--- + +## Weekly Reporting Commands + +### `forge weekly-report` + +Generate a weekly aggregated report of workflow activity and resources consumed across all managed tickets under a specified Jira project. + +> **Note:** The `forge weekly-report` command requires active Redis access and must be run from the Forge project directory containing `.env` to load configurations. + +The report aggregates data across a sliding window of `N` days, detailing completed, in-progress, and blocked workflows, as well as total machine execution time, token usage, and costs. 
+ +#### Options and Flags + +| Option/Flag | Description | +|---|---| +| `--project PROJECT_KEY` | **Required.** The Jira project key to scope the report (e.g., `PROJ`). | +| `--days N` | The reporting window in days (default: `7`). | +| `--output FILE` | File path to write the report to instead of standard output (`stdout`). | +| `--format FORMAT` | Output format: `text` (default), `markdown`, or `json`. | +| `--create-ticket` | Enable idempotent creation or update of a Jira weekly report issue. The ticket summary follows the pattern `Forge Weekly Report - {PROJECT} - Week of {date}` and carries the `forge:weekly-report` label. Running this command multiple times is idempotent — the existing ticket is updated with the latest content instead of creating duplicates. | +| `--notify` | Post a notification comment on the report ticket mentioning configured stakeholders. Requires `--create-ticket` to have been specified. Stakeholder account IDs are resolved from the per-project Jira property `forge.weekly-report.notify` or the `FORGE_WEEKLY_REPORT_NOTIFY` environment variable. | + +#### Examples + +##### 1. Generate text report to stdout for the last 7 days + +```bash +forge weekly-report --project PROJ +``` + +##### 2. Generate markdown report for the last 14 days and save it to a file + +```bash +forge weekly-report --project PROJ --days 14 --output report.md --format markdown +``` + +##### 3. 
Generate report, create/update Jira ticket, and notify stakeholders + +```bash +forge weekly-report --project PROJ --create-ticket --notify +``` diff --git a/docs/reference/config.md b/docs/reference/config.md index 72f94b5d..b8078685 100644 --- a/docs/reference/config.md +++ b/docs/reference/config.md @@ -125,6 +125,43 @@ These variables are used by `docker-compose.yml`, `devtools/docker-compose.dev.y | `REDIS_HOST` | Redis host for standalone Grafana compose | | `REDIS_PORT` | Redis port for standalone Grafana compose | +## Workflow Statistics and Weekly Reporting + +These settings configure resource tracking, cost metrics, cost alerting, and automated weekly reporting features within the Forge orchestrator. + +### Environment Variables and Pydantic Properties + +| Environment Variable | Settings Property | Type | Default Value | Description | +|----------------------|-------------------|------|---------------|-------------| +| `STATS_ALERT_ENABLED` | `stats_alert_enabled` | `bool` | `True` | Toggle to enable/disable cost alerts if token or dollar thresholds are exceeded. | +| `STATS_ALERT_THRESHOLD_TOKENS` | `stats_alert_threshold_tokens` | `int` | `1,000,000` | Cumulative token limit threshold (input + output across all stages) for triggering warnings. | +| `STATS_ALERT_THRESHOLD_COST` | `stats_alert_threshold_cost` | `float \| None` | `None` | Optional monetary threshold in USD for triggering cost warnings. If set, cost warnings are triggered based on calculated costs instead of token counts. | +| `LLM_PRICING` | `llm_pricing` | `dict[str, dict[str, float]]` | (JSON) | Pricing structure mapping LLM models or model substrings (longest match wins) to input and output token rates per million tokens. Configured as a JSON-encoded string when set via environment variables. | +| `FORGE_WEEKLY_REPORT_NOTIFY` | `weekly_report_notify` | `str` | `""` | Global fallback notification recipients. Set to a comma-separated list of Jira account IDs (e.g. 
`abc123,def456`) or the special value `project-leads` to defer to the per-project property `forge.weekly-report.notify`. | +| `JIRA_SERVICE_ACCOUNT_ID` | `jira_service_account_id` | `str` | `""` | Jira account ID of the Forge service account used to post comments. This is optional and auto-detected by default via the `/myself` API endpoint. When set, only comments authored by this account are treated as Forge comments when checking whether the stats comment is the final comment on a ticket (see `ensure_stats_is_final_comment`). | + +The default JSON structure for `LLM_PRICING` rates (USD per million tokens) is as follows: + +```json +{ + "claude-opus-4": {"input": 15.00, "output": 75.00}, + "claude-sonnet-4": {"input": 3.00, "output": 15.00}, + "claude-haiku-3-5": {"input": 0.80, "output": 4.00}, + "gemini-3.5-flash": {"input": 1.50, "output": 9.00}, + "gemini-2.5-pro": {"input": 1.25, "output": 10.00}, + "gemini-2.5-flash": {"input": 0.30, "output": 2.50}, + "gemini-2.0-flash": {"input": 0.10, "output": 0.40} +} +``` + +### Jira Project Properties + +You can customize the notification list for a specific project. When set via the Jira project properties REST API, this property takes precedence over the `FORGE_WEEKLY_REPORT_NOTIFY` setting, and it is also the source Forge reads when that setting is `project-leads`: + +- **Property Name:** `forge.weekly-report.notify` +- **Value:** A JSON array of Jira account IDs to be tagged/notified on weekly reports (e.g., `["account-id-1", "account-id-2"]`). + + ### MCP Servers MCP server configuration lives in `mcp-servers.json`, not `.env`. See the [MCP servers section](https://github.com/forge-sdlc/forge/blob/main/mcp-servers.json) of the repository. 
diff --git a/src/forge/cli.py b/src/forge/cli.py index fcd0f109..86d66e2d 100644 --- a/src/forge/cli.py +++ b/src/forge/cli.py @@ -356,7 +356,7 @@ async def cmd_list(_args: argparse.Namespace) -> int: while True: cursor, keys = await redis_client.scan( cursor=cursor, - match="langgraph:checkpoint:*", + match="checkpoint:*", count=100, ) @@ -364,8 +364,8 @@ async def cmd_list(_args: argparse.Namespace) -> int: # Extract ticket ID from key key_str = key.decode() if isinstance(key, bytes) else key parts = key_str.split(":") - if len(parts) >= 3: - ticket_id = parts[2] + if len(parts) >= 2: + ticket_id = parts[1] # Get checkpoint data data = await redis_client.get(key) if data: @@ -447,7 +447,7 @@ async def cmd_logs(args: argparse.Namespace) -> int: if not logs: # Try to get checkpoint state for any info - checkpoint_key = f"langgraph:checkpoint:{args.ticket}" + checkpoint_key = f"checkpoint:{args.ticket}" data = await redis_client.get(checkpoint_key) if data: print(f"No logs found, but checkpoint exists for {args.ticket}") @@ -635,6 +635,184 @@ async def cmd_project_setup(args: argparse.Namespace) -> int: await jira.close() +async def cmd_stats(args: argparse.Namespace) -> int: + """Display workflow statistics for a ticket.""" + import json as json_module + from typing import cast + + from forge.orchestrator.checkpointer import get_checkpoint_state + from forge.workflow.stats import StatsState + from forge.workflow.stats.formatter import format_stats_summary + + ticket = args.ticket + + try: + state = await get_checkpoint_state(ticket) + except Exception as e: + print(f"Error retrieving workflow data for {ticket}: {e}", file=sys.stderr) + return 1 + + if state is None: + print(f"No workflow data found for {ticket}") + return 1 + + # stage_timestamps key must be present (even empty dict is valid data) + if "stage_timestamps" not in state: + print(f"No workflow data found for {ticket}") + return 1 + + # Derive outcome from state (same logic as 
worker._handle_stats_command) + if state.get("workflow_outcome"): + outcome = state["workflow_outcome"] + outcome_detail = state.get("stats_outcome_reason") + elif state.get("is_blocked"): + outcome = "Blocked" + outcome_detail = state.get("feedback_comment") + elif state.get("last_error"): + outcome = "Failed" + outcome_detail = state.get("last_error") + else: + outcome = "In Progress" + outcome_detail = None + + if args.json: + stats_stages = state.get("stage_timestamps") or {} + pr_urls = state.get("stats_pr_urls") or [] + ci_cycles = state.get("stats_ci_cycles") or 0 + output = { + "ticket": ticket, + "outcome": outcome, + "outcome_detail": outcome_detail, + "ci_cycles": ci_cycles, + "pr_urls": pr_urls, + "stages": stats_stages, + } + print(json_module.dumps(output, indent=2)) + else: + # Use the Jira formatter for content, then display as plain text + settings = get_settings() + summary = format_stats_summary( + cast(StatsState, state), outcome, outcome_detail, pricing=settings.llm_pricing + ) + print(summary) + + return 0 + + +async def cmd_weekly_report(args: argparse.Namespace) -> int: + """Generate and output the weekly aggregated report for a Jira project.""" + import datetime + + from forge.workflow.stats.weekly_formatter import ( + format_weekly_report_cli, + format_weekly_report_json, + format_weekly_report_markdown, + ) + from forge.workflow.stats.weekly_report import collect_weekly_data + + project: str = args.project + days: int = args.days + output_path: str | None = args.output + fmt: str = args.format + create_ticket: bool = getattr(args, "create_ticket", False) + notify: bool = getattr(args, "notify", False) + + try: + report = await collect_weekly_data(project, days=days) + except Exception as e: + print(f"Error collecting weekly data for project {project!r}: {e}", file=sys.stderr) + return 1 + + # Fail gracefully when there is no data + total_tickets = ( + len(report.completed_tickets) + + len(report.in_progress_tickets) + + 
len(report.blocked_tickets) + ) + if total_tickets == 0: + print( + f"No workflow data found for project {project!r} in the last {days} day(s).", + file=sys.stderr, + ) + return 1 + + # Select formatter + if fmt == "json": + content = format_weekly_report_json(report) + elif fmt == "markdown": + content = format_weekly_report_markdown(report) + else: + content = format_weekly_report_cli(report) + + # Write output + if output_path: + try: + with open(output_path, "w", encoding="utf-8") as fh: + fh.write(content) + print(f"Report written to {output_path}") + except OSError as e: + print(f"Error writing to {output_path!r}: {e}", file=sys.stderr) + return 1 + else: + print(content) + + # Optionally create or update a Jira ticket with the report content. + ticket_key: str | None = None + if create_ticket: + from forge.workflow.stats.report_ticket import ensure_report_ticket + + # Derive the week_start date from the reporting window end (today) minus days. + week_start = datetime.date.today() - datetime.timedelta(days=days - 1) + + # Always use the markdown formatter for the Jira ticket description so the + # content is human-readable regardless of the --format flag chosen for stdout. + report_markdown = format_weekly_report_markdown(report) + + try: + ticket_key = await ensure_report_ticket(project, week_start, report_markdown) + print(f"Report ticket: {ticket_key}") + except Exception as e: + print(f"Error creating/updating report ticket: {e}", file=sys.stderr) + return 1 + + # Optionally send Jira notification mentions to project stakeholders. 
+ if notify: + if not create_ticket or ticket_key is None: + print( + "Warning: --notify requires --create-ticket to have a report ticket to comment on.", + file=sys.stderr, + ) + return 1 + + from forge.workflow.stats.notifications import ( + get_notification_recipients, + notify_report_ready, + ) + + try: + recipients = await get_notification_recipients(project) + except Exception as e: + print(f"Error retrieving notification recipients: {e}", file=sys.stderr) + return 1 + + if not recipients: + print( + f"No notification recipients configured for project {project!r}. " + "Set FORGE_WEEKLY_REPORT_NOTIFY or the forge.weekly-report.notify " + "project property to enable notifications.", + file=sys.stderr, + ) + else: + try: + await notify_report_ready(ticket_key, recipients) + print(f"Notification sent to {len(recipients)} recipient(s).") + except Exception as e: + print(f"Error sending notification: {e}", file=sys.stderr) + return 1 + + return 0 + + async def cmd_health(_args: argparse.Namespace) -> int: """Check system health.""" from forge.orchestrator.checkpointer import get_redis_client @@ -872,6 +1050,90 @@ def main() -> int: ), ) + # weekly-report command + weekly_report_parser = subparsers.add_parser( + "weekly-report", + help="Generate a weekly aggregated report for a Jira project", + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""Generate a weekly aggregated report of workflow activity for a Jira project. 
+ +Examples: + # Output report to stdout in text format + forge weekly-report --project PROJ + + # Adjust the reporting window to 14 days + forge weekly-report --project PROJ --days 14 + + # Write report to a Markdown file + forge weekly-report --project PROJ --output report.md --format markdown + + # Output JSON for scripting + forge weekly-report --project PROJ --format json +""", + ) + weekly_report_parser.add_argument( + "--project", + required=True, + metavar="PROJECT_KEY", + help="Jira project key to scope the report (e.g., PROJ)", + ) + weekly_report_parser.add_argument( + "--days", + type=int, + default=7, + metavar="N", + help="Reporting window in days (default: 7)", + ) + weekly_report_parser.add_argument( + "--output", + metavar="FILE", + default=None, + help="File path to write the report to (stdout if omitted)", + ) + weekly_report_parser.add_argument( + "--format", + choices=["text", "markdown", "json"], + default="text", + metavar="FORMAT", + help="Output format: text (default), markdown, or json", + ) + weekly_report_parser.add_argument( + "--create-ticket", + action="store_true", + default=False, + help=( + "Create or update a Jira ticket storing the weekly report. " + "The ticket summary follows the format: " + "'Forge Weekly Report - {PROJECT} - Week of {date}'. " + "Running the command twice is idempotent — the existing ticket " + "is updated rather than duplicated." + ), + ) + weekly_report_parser.add_argument( + "--notify", + action="store_true", + default=False, + help=( + "Post a notification comment on the report ticket mentioning configured " + "stakeholders. Requires --create-ticket. Recipients are read from the " + "FORGE_WEEKLY_REPORT_NOTIFY env var (comma-separated Jira account IDs " + "or 'project-leads') or from the per-project Jira property " + "'forge.weekly-report.notify'." 
+ ), + ) + + # stats command + stats_parser = subparsers.add_parser( + "stats", + help="Display workflow statistics for a ticket", + ) + stats_parser.add_argument("ticket", help="Jira ticket key (e.g., AISOS-123)") + stats_parser.add_argument( + "--json", + action="store_true", + help="Output stats as JSON", + ) + # project-setup command setup_parser = subparsers.add_parser( "project-setup", @@ -982,6 +1244,8 @@ def main() -> int: "list": cmd_list, "retry": cmd_retry, "logs": cmd_logs, + "stats": cmd_stats, + "weekly-report": cmd_weekly_report, "project-setup": cmd_project_setup, } diff --git a/src/forge/config.py b/src/forge/config.py index e1bc7db9..183b03a0 100644 --- a/src/forge/config.py +++ b/src/forge/config.py @@ -1,10 +1,11 @@ """Configuration management using Pydantic settings.""" +import json import logging from functools import cached_property, lru_cache from typing import TYPE_CHECKING, Literal -from pydantic import Field, SecretStr +from pydantic import Field, SecretStr, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict if TYPE_CHECKING: @@ -58,6 +59,16 @@ def jira_domain_resolved(self) -> str: default="", description="Custom field ID for Specification storage (optional)", ) + jira_service_account_id: str = Field( + default="", + description=( + "Jira account ID of the Forge service account used to post comments. " + "When set, only comments authored by this account are treated as Forge " + "comments when checking whether the stats comment is the final comment " + "on a ticket (see ensure_stats_is_final_comment). " + "Set via JIRA_SERVICE_ACCOUNT_ID environment variable." 
+ ), + ) # Jira workflow configuration jira_use_labels: bool = Field( @@ -346,6 +357,76 @@ def ignored_ci_checks(self) -> list[str]: description="Enable Prometheus metrics endpoint in worker", ) + # Weekly Report Notification Configuration + weekly_report_notify: str = Field( + default="", + alias="forge_weekly_report_notify", + description=( + "Recipients to notify when a weekly report is generated. " + "Accepted values: " + "(1) A comma-separated list of Jira account IDs " + "(e.g. 'abc123,def456') — the listed users are mentioned in a " + "comment posted to the report ticket; " + "(2) The special value 'project-leads' — recipients are read from " + "the per-project Jira property 'forge.weekly-report.notify' instead. " + "When empty (default) no notification comment is posted. " + "Set via FORGE_WEEKLY_REPORT_NOTIFY environment variable." + ), + ) + + # Stats Cost Alert Configuration + stats_alert_enabled: bool = Field( + default=True, + description=( + "Enable cost alerting in workflow stats summaries. " + "When enabled and aggregate token usage exceeds stats_alert_threshold_tokens, " + "the stats summary will include a cost alert." + ), + ) + stats_alert_threshold_tokens: int = Field( + default=1_000_000, + description=( + "Total token count threshold (input + output across all stages) that triggers " + "a cost alert in the workflow stats summary. Only active when " + "stats_alert_enabled is True. Default: 1,000,000 tokens." + ), + ) + stats_alert_threshold_cost: float | None = Field( + default=None, + description=( + "Dollar cost threshold that triggers a cost alert in the workflow stats summary. " + "When set, the alert compares total dollar cost (sum of all stage costs) against " + "this value instead of comparing raw token count against " + "stats_alert_threshold_tokens. Only active when stats_alert_enabled is " + "True. Set via STATS_ALERT_THRESHOLD_COST environment variable." 
+ ), + ) + llm_pricing: dict[str, dict[str, float]] = Field( + default_factory=lambda: { + "claude-opus-4": {"input": 15.00, "output": 75.00}, + "claude-sonnet-4": {"input": 3.00, "output": 15.00}, + "claude-haiku-3-5": {"input": 0.80, "output": 4.00}, + "gemini-3.5-flash": {"input": 1.50, "output": 9.00}, + "gemini-2.5-pro": {"input": 1.25, "output": 10.00}, + "gemini-2.5-flash": {"input": 0.30, "output": 2.50}, + "gemini-2.0-flash": {"input": 0.10, "output": 0.40}, + }, + description=( + "LLM pricing table mapping model name substrings/patterns to per-million-token " + "rates. Keys are model name substrings (longest match wins); values are dicts " + "with 'input' and 'output' keys in $/MTok. " + "Set via LLM_PRICING environment variable as a JSON-encoded string." + ), + ) + + @field_validator("llm_pricing", mode="before") + @classmethod + def parse_llm_pricing(cls, v: object) -> object: + """Parse LLM_PRICING from a JSON string when provided as an env var.""" + if isinstance(v, str): + return json.loads(v) + return v + # OpenTelemetry Configuration otlp_endpoint: str = Field( default="", diff --git a/src/forge/integrations/agents/agent.py b/src/forge/integrations/agents/agent.py index 2b69f5a1..b60fabfb 100644 --- a/src/forge/integrations/agents/agent.py +++ b/src/forge/integrations/agents/agent.py @@ -120,6 +120,8 @@ def __init__(self, settings: Settings | None = None): self._ensure_api_key() self._checkpointer = MemorySaver() self._current_repo: str = "" # Set per-task for dynamic MCP URLs + self.last_input_tokens: int = 0 + self.last_output_tokens: int = 0 # Set prompt version from config set_default_version(self.settings.prompt_version) @@ -573,7 +575,7 @@ async def _run_agent( ticket_key: str | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, - ) -> str: + ) -> tuple[str, int, int]: """Run the agent with the given prompt. Implements exponential backoff retry for rate limit errors. 
@@ -663,10 +665,37 @@ async def _run_agent( response_text = [] messages = result.get("messages", []) if isinstance(result, dict) else [] + total_input_tokens = 0 + total_output_tokens = 0 + for message in messages: # Check if it's an AI/Assistant message (LangChain message object) msg_type = type(message).__name__ if msg_type in ("AIMessage", "AIMessageChunk"): + # Extract and aggregate tokens from usage_metadata (if present) + usage = getattr(message, "usage_metadata", None) + if not usage: + resp_metadata = getattr(message, "response_metadata", {}) + if isinstance(resp_metadata, dict): + usage = resp_metadata.get("token_usage") or resp_metadata.get("usage") + + if isinstance(usage, dict): + total_input_tokens += ( + usage.get("input_tokens", 0) or usage.get("prompt_tokens", 0) or 0 + ) + total_output_tokens += ( + usage.get("output_tokens", 0) or usage.get("completion_tokens", 0) or 0 + ) + elif usage is not None: + total_input_tokens += ( + getattr(usage, "input_tokens", 0) or getattr(usage, "prompt_tokens", 0) or 0 + ) + total_output_tokens += ( + getattr(usage, "output_tokens", 0) + or getattr(usage, "completion_tokens", 0) + or 0 + ) + content = message.content if isinstance(content, str): response_text.append(content) @@ -677,7 +706,7 @@ async def _run_agent( elif hasattr(block, "text"): response_text.append(block.text) - return "\n".join(response_text) + return "\n".join(response_text), total_input_tokens, total_output_tokens @staticmethod def _strip_preamble(text: str) -> str: @@ -765,7 +794,7 @@ async def run_task( } trace_tags, trace_metadata = resolve_trace_fields(trace_state) - result = await self._run_agent( + agent_resp = await self._run_agent( prompt=prompt, system_prompt=system_prompt, include_tools=include_tools, @@ -775,6 +804,12 @@ async def run_task( tags=trace_tags or None, metadata=trace_metadata or None, ) + if isinstance(agent_resp, tuple): + result, in_tokens, out_tokens = agent_resp + else: + result, in_tokens, out_tokens = agent_resp, 0, 
0 + self.last_input_tokens = in_tokens + self.last_output_tokens = out_tokens observe_agent_duration(task_type=task, duration=time.monotonic() - _start) logger.info(f"Task '{task}' completed ({len(result)} chars)") diff --git a/src/forge/integrations/jira/client.py b/src/forge/integrations/jira/client.py index 46abecb4..a43f929f 100644 --- a/src/forge/integrations/jira/client.py +++ b/src/forge/integrations/jira/client.py @@ -23,6 +23,7 @@ # Module-level cache for project properties (persists per worker lifetime) _project_property_cache: dict[tuple[str, str], Any] = {} +_service_account_id_cache: str | None = None class MissingProjectConfig(Exception): @@ -69,6 +70,23 @@ async def close(self) -> None: await self._client.aclose() self._client = None + async def get_service_account_id(self) -> str: + """Fetch the authenticated user's Jira account ID using the /myself endpoint. + + Returns: + The accountId string of the authenticated user. + """ + global _service_account_id_cache + if _service_account_id_cache is not None: + return _service_account_id_cache + + response = await self._request_with_retry("GET", "/myself") + response.raise_for_status() + data = response.json() + account_id = data["accountId"] + _service_account_id_cache = account_id + return account_id + async def _request_with_retry( self, method: str, diff --git a/src/forge/orchestrator/worker.py b/src/forge/orchestrator/worker.py index bc915415..35e74e90 100644 --- a/src/forge/orchestrator/worker.py +++ b/src/forge/orchestrator/worker.py @@ -10,7 +10,7 @@ import uuid from dataclasses import replace as dataclass_replace from pathlib import Path -from typing import Any +from typing import Any, cast from forge.api.routes.metrics import ( record_workflow_completed, @@ -30,6 +30,9 @@ from forge.utils.redaction import redact_secrets from forge.workflow.registry import create_default_router from forge.workflow.router import WorkflowRouter +from forge.workflow.stats import StatsState +from 
forge.workflow.stats.formatter import format_stats_summary +from forge.workflow.stats.poster import ensure_stats_is_final_comment from forge.workflow.utils.comment_classifier import CommentType, classify_comment from forge.workflow.utils.jira_status import post_status_comment @@ -656,6 +659,32 @@ async def _handle_resume_event( comment_body = self._extract_text_from_adf(comment_body) if comment_body.strip(): + # /forge stats [subcommand] — post workflow statistics and return state + # unchanged. This is a read-only command that works regardless of workflow + # state. Supported subcommands: + # (none) — post current stats as a new comment + # retry — force fresh stats re-post via ensure_stats_is_final_comment + # Unknown subcommands are treated as informational (no-op). + if comment_body.strip().lower().startswith("/forge stats"): + # Parse optional subcommand from the remainder of the line. + remainder = comment_body.strip()[len("/forge stats") :].strip().lower() + subcommand = remainder.split()[0] if remainder.split() else "" + + if subcommand == "retry": + logger.info(f"Detected /forge stats retry command for {message.ticket_key}") + await self._handle_stats_retry_command(message.ticket_key, current_state) + elif subcommand == "": + # Base /forge stats — post current stats as a new comment. + logger.info(f"Detected /forge stats command for {message.ticket_key}") + await self._handle_stats_command(message.ticket_key, current_state) + else: + # Unknown subcommand — treat as informational, no-op. 
+ logger.info( + f"Unknown /forge stats subcommand '{subcommand}' for " + f"{message.ticket_key} — treating as informational" + ) + return current_state + # >option N detection for rca_option_gate (runs before general classification) if current_node == "rca_option_gate": option_match = _OPTION_PATTERN.search(comment_body) @@ -903,7 +932,7 @@ async def _handle_resume_event( repo_full = payload.get("repository", {}).get("full_name", "") pr_number = payload.get("pull_request", {}).get("number") review_id = review.get("id") - inline_comments: list[dict[str, Any]] = [] + inline_comments = [] if repo_full and pr_number and review_id: _owner, _repo = repo_full.split("/", 1) gh = GitHubClient() @@ -1277,6 +1306,121 @@ async def _handle_resume_event( return updated_state + async def _handle_stats_command( + self, + ticket_key: str, + current_state: dict[str, Any], + ) -> None: + """Handle a /forge stats Jira comment command. + + Retrieves workflow statistics from the current checkpoint state, + formats them into a Jira wiki markup comment, and posts the comment + to the originating Jira ticket. The command is read-only — it never + modifies workflow state. + + Args: + ticket_key: Jira ticket key to post the stats comment on. + current_state: Current workflow state from the checkpoint. + """ + await self._post_stats_comment(ticket_key, current_state, force_repost=False) + + async def _handle_stats_retry_command( + self, + ticket_key: str, + current_state: dict[str, Any], + ) -> None: + """Handle a /forge stats retry Jira comment command. + + Forces a fresh stats calculation from the current checkpoint state, + bypassing any cached data, and re-posts the stats comment via the + re-post mechanism so that it appears as the final Forge comment. + This is useful when the original stats comment failed to post or + when the data needs to be refreshed. + + Args: + ticket_key: Jira ticket key to post the stats comment on. + current_state: Current workflow state from the checkpoint. 
+ """ + logger.info(f"Retrying stats post for {ticket_key} — forcing fresh stats calculation") + await self._post_stats_comment(ticket_key, current_state, force_repost=True) + + async def _post_stats_comment( + self, + ticket_key: str, + current_state: dict[str, Any], + *, + force_repost: bool = False, + ) -> None: + """Shared helper for posting stats comments. + + Derives the outcome and detail from the current workflow state, + formats the stats summary, and posts (or re-posts) it to Jira. + + Args: + ticket_key: Jira ticket key to post the stats comment on. + current_state: Current workflow state from the checkpoint. + force_repost: When ``True``, use :func:`ensure_stats_is_final_comment` + to re-post the stats comment even if one was previously posted, + ensuring it appears as the final Forge comment (retry scenario). + When ``False``, post a new comment via ``JiraClient.add_comment``. + """ + stats_stages = current_state.get("stage_timestamps") + if not stats_stages and stats_stages != {}: + # No stats data found at all (missing key, not just empty dict) + logger.info(f"No workflow stats data found for {ticket_key}") + try: + jira = JiraClient() + try: + await jira.add_comment(ticket_key, "No workflow data found.") + finally: + await jira.close() + except Exception as e: + logger.warning(f"Failed to post 'no data' stats comment to {ticket_key}: {e}") + return + + # Determine current outcome from state for the on-demand stats view. + # Use pre-set workflow_outcome if available; otherwise derive from state flags. + outcome = current_state.get("workflow_outcome") or ( + "Blocked" + if current_state.get("is_blocked") + else ("Failed" if current_state.get("last_error") else "In Progress") + ) + outcome_detail = current_state.get("stats_outcome_reason") or current_state.get( + "last_error" + ) + + if force_repost: + # Use the re-post mechanism so stats appears as the final Forge comment. 
+ try: + await ensure_stats_is_final_comment( + ticket_key, cast(StatsState, current_state), outcome, outcome_detail + ) + logger.info(f"Re-posted stats comment to {ticket_key} via retry") + except Exception as e: + logger.warning(f"Failed to re-post stats comment to {ticket_key}: {e}") + return + + try: + comment_body = format_stats_summary( + cast(StatsState, current_state), + outcome, + outcome_detail, + pricing=self.settings.llm_pricing, + ) + except Exception as e: + logger.warning(f"Failed to format stats for {ticket_key}: {e}") + comment_body = "Unable to format workflow statistics." + + try: + jira = JiraClient() + try: + await jira.add_comment(ticket_key, comment_body) + logger.info(f"Posted on-demand stats comment to {ticket_key}") + finally: + await jira.close() + except Exception as e: + logger.warning(f"Failed to post stats comment to {ticket_key}: {e}") + async def _post_resume_ack_comment( self, ticket_key: str, diff --git a/src/forge/sandbox/runner.py b/src/forge/sandbox/runner.py index 5d81afa1..99cf385c 100644 --- a/src/forge/sandbox/runner.py +++ b/src/forge/sandbox/runner.py @@ -47,6 +47,8 @@ class ContainerResult: stderr: str tests_passed: bool | None = None # None if tests were skipped error_message: str | None = None + input_tokens: int = 0 + output_tokens: int = 0 @property def tests_failed(self) -> bool: @@ -446,6 +448,20 @@ async def run( logger.info(f"Container exited with code {exit_code}") + # Parse metrics.json if written by entrypoint.py + input_tokens = 0 + output_tokens = 0 + metrics_file = workspace_path / ".forge" / "metrics.json" + if metrics_file.exists(): + try: + metrics_data = json.loads(metrics_file.read_text()) + input_tokens = int(metrics_data.get("input_tokens", 0) or 0) + output_tokens = int(metrics_data.get("output_tokens", 0) or 0) + except Exception as e: + logger.warning(f"Failed to parse metrics.json in sandbox runner: {e}") + finally: + metrics_file.unlink(missing_ok=True) + # Log container output if exit_code != 
EXIT_SUCCESS: # Failure: stderr at INFO, stdout at DEBUG @@ -474,6 +490,8 @@ async def run( stdout=stdout_str, stderr=stderr_str, tests_passed=True, + input_tokens=input_tokens, + output_tokens=output_tokens, ) elif exit_code == EXIT_TESTS_FAILED: return ContainerResult( @@ -483,6 +501,8 @@ async def run( stderr=stderr_str, tests_passed=False, error_message="Tests failed after max retries", + input_tokens=input_tokens, + output_tokens=output_tokens, ) else: return ContainerResult( @@ -491,6 +511,8 @@ async def run( stdout=stdout_str, stderr=stderr_str, error_message=f"Task failed with exit code {exit_code}", + input_tokens=input_tokens, + output_tokens=output_tokens, ) finally: diff --git a/src/forge/stats/__init__.py b/src/forge/stats/__init__.py new file mode 100644 index 00000000..ffe445c2 --- /dev/null +++ b/src/forge/stats/__init__.py @@ -0,0 +1,43 @@ +"""Stats service package for Forge workflow statistics. + +This package provides a unified interface for retrieving and validating +workflow statistics data from LangGraph checkpoints. It is consumed by +both Jira command handlers and CLI commands. + +Public API +---------- +``WorkflowStats`` + Dataclass containing fully-validated stats fields extracted from a + checkpoint. + +``get_workflow_stats(ticket_key)`` + Async function that retrieves stats for a ticket. Returns ``None`` + when no checkpoint or no stats data is found. + +``get_workflow_stats_or_error(ticket_key)`` + Async function that returns ``(stats, error_message)``; never raises. + Suitable for CLI / command-handler callers that need a display-ready + error string instead of an exception. + +``format_stats_table(stats, *, use_color=False)`` + Render a ``WorkflowStats`` as a human-readable ASCII table for terminal + display. + +``format_stats_json(stats)`` + Serialize a ``WorkflowStats`` to a pretty-printed JSON string. 
+""" + +from forge.stats.cli_formatter import format_stats_json, format_stats_table +from forge.stats.retrieval import ( + WorkflowStats, + get_workflow_stats, + get_workflow_stats_or_error, +) + +__all__ = [ + "WorkflowStats", + "format_stats_json", + "format_stats_table", + "get_workflow_stats", + "get_workflow_stats_or_error", +] diff --git a/src/forge/stats/cli_formatter.py b/src/forge/stats/cli_formatter.py new file mode 100644 index 00000000..92246a0f --- /dev/null +++ b/src/forge/stats/cli_formatter.py @@ -0,0 +1,345 @@ +"""CLI formatter for workflow statistics terminal output. + +This module renders ``WorkflowStats`` as human-readable ASCII tables or +pretty-printed JSON, suitable for terminal display via ``forge stats``. + +It complements the Jira wiki markup formatter in +``forge.workflow.stats.formatter`` — that module targets Jira comments +while this one targets terminal output. + +Usage:: + + from forge.stats.cli_formatter import format_stats_table, format_stats_json + + # ASCII table for terminal display + print(format_stats_table(stats)) + + # Pretty-printed JSON for scripting + print(format_stats_json(stats)) +""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + pass + +from forge.stats.retrieval import WorkflowStats +from forge.workflow.stats import ( + ALL_BUG_STAGES, + ALL_FEATURE_STAGES, + StageStats, +) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +#: Em-dash used when a stage was never executed (matches Jira formatter). +_DASH = "\u2014" + +#: Display labels for each stage key. 
_STAGE_LABELS: dict[str, str] = {
    "prd": "PRD",
    "spec": "Spec",
    "epics": "Epics",
    "tasks": "Tasks",
    "implementation": "Implementation",
    "ci": "CI",
    "review": "Review",
    "triage": "Triage",
    "rca": "RCA",
    "planning": "Planning",
}

#: ANSI colour codes used for optional colorized output.
_COLOR_GREEN = "\033[32m"
_COLOR_RED = "\033[31m"
_COLOR_YELLOW = "\033[33m"
_COLOR_BOLD = "\033[1m"
_COLOR_RESET = "\033[0m"

# Column header names.
_HEADERS = ("Stage", "Iterations", "Machine Time", "Tokens In", "Tokens Out")

# ---------------------------------------------------------------------------
# Internal helpers — formatting primitives
# ---------------------------------------------------------------------------


def _fmt_seconds(seconds: float) -> str:
    """Format a duration in seconds as a compact string (e.g. ``'1h 23m 45s'``).

    The value is truncated toward zero to whole seconds. Only *leading*
    zero-value components are elided; once a larger unit is shown, every
    smaller unit is shown too: ``45`` → ``'45s'``, ``60`` → ``'1m 0s'``,
    ``3601`` → ``'1h 0m 1s'``.

    Negative durations are rendered as the magnitude with a leading ``'-'``
    (``-5`` → ``'-5s'``) instead of being passed through floor division,
    which would otherwise yield a misleading ``'-1h 59m 55s'``.

    Args:
        seconds: Duration in seconds.

    Returns:
        The formatted duration string.
    """
    total = int(seconds)
    # Split off the sign first: divmod floors, so negatives must not reach it.
    sign = "-" if total < 0 else ""
    hours, remainder = divmod(abs(total), 3600)
    minutes, secs = divmod(remainder, 60)
    if hours:
        return f"{sign}{hours}h {minutes}m {secs}s"
    if minutes:
        return f"{sign}{minutes}m {secs}s"
    return f"{sign}{secs}s"


def _fmt_tokens(count: int) -> str:
    """Format a token count with thousands separators (e.g. ``'1,234,567'``)."""
    return f"{count:,}"


def _truncate(text: str, max_len: int) -> str:
    """Truncate *text* to at most *max_len* characters.

    Text that already fits is returned unchanged. Otherwise the text is cut
    and ``'...'`` is appended so the result is exactly *max_len* characters.
    When *max_len* is too small to hold any text plus the ellipsis
    (``max_len <= 3``), the result is the ellipsis itself clipped to
    *max_len* characters, so the length bound always holds.

    Args:
        text: The string to shorten.
        max_len: Maximum length of the returned string. Values below zero
            are treated as zero.

    Returns:
        *text* itself, or a shortened form no longer than *max_len*.
    """
    if len(text) <= max_len:
        return text
    if max_len <= 3:
        # A negative slice stop (max_len - 3 < 0) would keep most of the
        # text and overshoot max_len, so clip the ellipsis instead.
        return "..."[: max(max_len, 0)]
    return text[: max_len - 3] + "..."
+ + +def _colorize(text: str, color: str, *, use_color: bool) -> str: + """Wrap *text* in ANSI *color* escape codes if *use_color* is True.""" + if not use_color: + return text + return f"{color}{text}{_COLOR_RESET}" + + +# --------------------------------------------------------------------------- +# Internal helpers — table building +# --------------------------------------------------------------------------- + + +def _stage_row_values(label: str, stage: StageStats | None) -> tuple[str, str, str, str, str]: + """Return the five cell values for a single stage row. + + When *stage* is ``None`` (stage was never executed), all metric cells + contain the em-dash sentinel ``"—"``. + """ + if stage is None: + return (label, _DASH, _DASH, _DASH, _DASH) + + iterations = str(stage.get("iteration_count", 0)) + machine_time = _fmt_seconds(stage.get("machine_time_seconds", 0.0)) + tokens_in = _fmt_tokens(stage.get("input_tokens", 0)) + tokens_out = _fmt_tokens(stage.get("output_tokens", 0)) + return (label, iterations, machine_time, tokens_in, tokens_out) + + +def _totals_row_values(stages: dict[str, StageStats]) -> tuple[str, str, str, str, str]: + """Return the five cell values for the summary totals row.""" + total_machine = sum(s.get("machine_time_seconds", 0.0) or 0.0 for s in stages.values()) + total_in = sum(s.get("input_tokens", 0) or 0 for s in stages.values()) + total_out = sum(s.get("output_tokens", 0) or 0 for s in stages.values()) + return ( + "TOTAL", + "", + _fmt_seconds(total_machine), + _fmt_tokens(total_in), + _fmt_tokens(total_out), + ) + + +def _render_table( + rows: list[tuple[str, ...]], + col_widths: list[int], + *, + header_sep: bool = True, +) -> list[str]: + """Render *rows* as an ASCII table given pre-computed *col_widths*. + + Returns a list of strings (one per line). The first row is always the + header; a separator line is inserted below it when *header_sep* is True. 
+ """ + + def _row_line(cells: tuple[str, ...]) -> str: + padded = [cell.ljust(col_widths[i]) for i, cell in enumerate(cells)] + return "| " + " | ".join(padded) + " |" + + def _sep_line() -> str: + return "+-" + "-+-".join("-" * w for w in col_widths) + "-+" + + lines: list[str] = [] + for i, row in enumerate(rows): + lines.append(_row_line(row)) + if i == 0 and header_sep: + lines.append(_sep_line()) + lines.append(_sep_line()) + return lines + + +def _compute_col_widths( + rows: list[tuple[str, ...]], + max_col_width: int = 20, +) -> list[int]: + """Compute column widths from all rows, capping at *max_col_width*.""" + if not rows: + return [] + n_cols = len(rows[0]) + widths = [0] * n_cols + for row in rows: + for i, cell in enumerate(row): + widths[i] = max(widths[i], min(len(cell), max_col_width)) + return widths + + +def _determine_display_stages(stages: dict[str, StageStats]) -> list[str]: + """Return the ordered list of stage keys to display. + + Uses ``ALL_FEATURE_STAGES`` by default. If the workflow contains any + bug-only stages (``triage``, ``rca``, ``planning``) that are absent from + the feature list, the bug stage ordering is preferred. + """ + bug_only = {"triage", "rca", "planning"} + if any(k in stages for k in bug_only): + return ALL_BUG_STAGES + return ALL_FEATURE_STAGES + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def format_stats_table( + stats: WorkflowStats, + *, + use_color: bool = False, + max_col_width: int = 20, +) -> str: + """Render *stats* as a human-readable ASCII table for terminal display. + + The output includes: + + * A metadata block: ticket key, outcome, CI cycles, workflow run ID. + * A stage-by-stage metrics table with columns: + Stage | Iterations | Machine Time | Tokens In | Tokens Out + * A summary totals row (times and tokens summed across all stages). 
+ * A PR links section (omitted when no PRs were created). + + Stages that were never executed show ``"—"`` in all metric columns, + consistent with the Jira formatter. + + Args: + stats: The ``WorkflowStats`` instance to format. + use_color: When ``True``, ANSI color codes are applied: green for + "Completed", red for "Failed", yellow for "Blocked". + max_col_width: Maximum width of any table column (characters). + Longer values are truncated with ``'...'``. Defaults to 20. + + Returns: + A multi-line string suitable for printing to a terminal. + """ + lines: list[str] = [] + + # ------------------------------------------------------------------ + # Metadata block + # ------------------------------------------------------------------ + outcome_raw = stats.outcome or "In Progress" + outcome_lower = outcome_raw.lower() + + if use_color: + if outcome_lower == "completed": + outcome_display = _colorize(outcome_raw, _COLOR_GREEN, use_color=True) + elif outcome_lower.startswith("failed"): + outcome_display = _colorize(outcome_raw, _COLOR_RED, use_color=True) + elif outcome_lower.startswith("blocked"): + outcome_display = _colorize(outcome_raw, _COLOR_YELLOW, use_color=True) + else: + outcome_display = outcome_raw + else: + outcome_display = outcome_raw + + lines.append(_colorize("Workflow Statistics", _COLOR_BOLD, use_color=use_color)) + lines.append("") + lines.append(f" Ticket: {stats.ticket_key}") + lines.append(f" Outcome: {outcome_display}") + if stats.outcome_reason: + reason = _truncate(stats.outcome_reason, 80) + lines.append(f" Reason: {reason}") + lines.append(f" CI Cycles: {stats.ci_cycles}") + if stats.workflow_run_id: + lines.append(f" Run ID: {stats.workflow_run_id}") + + # Derive created_at / updated_at from stage timestamps. 
+ all_started = [str(s.get("started_at")) for s in stats.stages.values() if s.get("started_at")] + all_ended = [str(s.get("ended_at")) for s in stats.stages.values() if s.get("ended_at")] + if all_started: + lines.append(f" Started: {min(all_started)}") + if all_ended: + lines.append(f" Last Updated: {max(all_ended)}") + + lines.append("") + + # ------------------------------------------------------------------ + # Stage metrics table + # ------------------------------------------------------------------ + display_stages = _determine_display_stages(stats.stages) + + data_rows: list[tuple[str, str, str, str, str]] = [] + for stage_key in display_stages: + label = _STAGE_LABELS.get(stage_key, stage_key.title()) + stage_data = stats.stages.get(stage_key) + data_rows.append(_stage_row_values(label, stage_data)) + + # Totals row (only meaningful when at least one stage ran) + totals = _totals_row_values(stats.stages) + data_rows.append(totals) + + # Truncate cell values to max_col_width before computing widths. + truncated_rows: list[tuple[str, ...]] = [] + for row in data_rows: + truncated_rows.append(tuple(_truncate(cell, max_col_width) for cell in row)) + + all_rows: list[tuple[str, ...]] = [_HEADERS, *truncated_rows] + col_widths = _compute_col_widths(all_rows, max_col_width=max_col_width) + table_lines = _render_table(all_rows, col_widths) + lines.extend(table_lines) + + # ------------------------------------------------------------------ + # PR links section (omitted when no PRs) + # ------------------------------------------------------------------ + if stats.pr_urls: + lines.append("") + lines.append("Pull Requests:") + for url in stats.pr_urls: + lines.append(f" {url}") + + return "\n".join(lines) + + +def format_stats_json(stats: WorkflowStats) -> str: + """Render *stats* as pretty-printed JSON. + + The JSON document includes all ``WorkflowStats`` fields with their + proper Python types serialised to JSON-safe equivalents. 
The output + is indented with 2 spaces and keys are sorted alphabetically for + stable, diff-friendly output. + + Args: + stats: The ``WorkflowStats`` instance to serialise. + + Returns: + A pretty-printed JSON string. + """ + payload: dict[str, Any] = { + "ticket_key": stats.ticket_key, + "outcome": stats.outcome, + "outcome_reason": stats.outcome_reason, + "ci_cycles": stats.ci_cycles, + "comment_posted": stats.comment_posted, + "workflow_run_id": stats.workflow_run_id, + "pr_urls": stats.pr_urls, + "stages": { + stage_key: { + "stage_name": stage_data.get("stage_name", stage_key), + "iteration_count": stage_data.get("iteration_count", 0), + "machine_time_seconds": stage_data.get("machine_time_seconds", 0.0), + "input_tokens": stage_data.get("input_tokens", 0), + "output_tokens": stage_data.get("output_tokens", 0), + "started_at": stage_data.get("started_at"), + "ended_at": stage_data.get("ended_at"), + } + for stage_key, stage_data in stats.stages.items() + }, + } + return json.dumps(payload, indent=2, sort_keys=True) diff --git a/src/forge/stats/retrieval.py b/src/forge/stats/retrieval.py new file mode 100644 index 00000000..41f3c8c7 --- /dev/null +++ b/src/forge/stats/retrieval.py @@ -0,0 +1,186 @@ +"""Stats retrieval service for workflow checkpoints. + +This module provides a unified interface for retrieving and validating +workflow statistics data from LangGraph checkpoints. It is used by both +Jira command handlers and CLI commands. + +Usage:: + + from forge.stats.retrieval import get_workflow_stats, get_workflow_stats_or_error + + stats = await get_workflow_stats("AISOS-123") + if stats is None: + # No checkpoint or no stats data + ... 
+ + # Or, get a result with an error message suitable for display: + stats, error = await get_workflow_stats_or_error("AISOS-123") + if error: + print(error) +""" + +import logging +from dataclasses import dataclass, field +from typing import Any + +from forge.orchestrator.checkpointer import get_checkpoint_state +from forge.workflow.stats import StageStats + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Return type +# --------------------------------------------------------------------------- + + +@dataclass +class WorkflowStats: + """Validated workflow statistics extracted from a checkpoint. + + All fields mirror the corresponding fields in ``StatsState``. The + dataclass is always fully populated — callers do not need to handle + missing keys individually. Fields that were absent in the checkpoint + carry their zero / empty defaults so that partial (in-progress) + workflows are represented cleanly. + + Attributes: + ticket_key: The Jira ticket key this stats snapshot belongs to. + stages: Per-stage metrics, keyed by stage name. + pr_urls: URLs of pull requests opened during the workflow run. + ci_cycles: Number of CI fix-attempt cycles triggered. + outcome: Final outcome string, or ``None`` while the workflow is + still in progress (e.g. ``"Completed"``, ``"Failed: …"``). + outcome_reason: Human-readable elaboration on the outcome, or + ``None`` when not applicable. + comment_posted: Whether the summary stats comment has already been + posted to the Jira ticket. + workflow_run_id: Unique identifier for this workflow run (UUID4). + Empty string when the checkpoint predates idempotency support. 
+ """ + + ticket_key: str + stages: dict[str, StageStats] = field(default_factory=dict) + pr_urls: list[str] = field(default_factory=list) + ci_cycles: int = 0 + outcome: str | None = None + outcome_reason: str | None = None + comment_posted: bool = False + workflow_run_id: str = "" + + +# --------------------------------------------------------------------------- +# Retrieval helpers +# --------------------------------------------------------------------------- + + +def _extract_stats(ticket_key: str, state: dict[str, Any]) -> WorkflowStats | None: + """Extract and validate stats data from a checkpoint state dict. + + Args: + ticket_key: The Jira ticket key for logging context. + state: The raw checkpoint state dict from ``get_checkpoint_state``. + + Returns: + A populated ``WorkflowStats`` instance, or ``None`` when the + checkpoint contains no stats data (e.g. legacy workflows). + """ + if "stage_timestamps" not in state: + logger.debug( + "Checkpoint for %s has no stage_timestamps key (legacy workflow or pre-stats run)", + ticket_key, + ) + return None + + stages = state.get("stage_timestamps") or {} + if not isinstance(stages, dict): + logger.warning( + "Checkpoint for %s has malformed stage_timestamps (expected dict, got %s); " + "treating as empty", + ticket_key, + type(stages).__name__, + ) + stages = {} + + pr_urls = state.get("stats_pr_urls") or [] + if not isinstance(pr_urls, list): + logger.warning( + "Checkpoint for %s has malformed stats_pr_urls (expected list, got %s); " + "treating as empty", + ticket_key, + type(pr_urls).__name__, + ) + pr_urls = [] + + return WorkflowStats( + ticket_key=ticket_key, + stages=stages, + pr_urls=pr_urls, + ci_cycles=state.get("stats_ci_cycles") or 0, + outcome=state.get("workflow_outcome"), + outcome_reason=state.get("stats_outcome_reason"), + comment_posted=bool(state.get("stats_comment_posted", False)), + workflow_run_id=state.get("workflow_run_id", ""), + ) + + +async def get_workflow_stats(ticket_key: str) -> 
WorkflowStats | None: + """Retrieve workflow statistics for a ticket from its checkpoint. + + Looks up the LangGraph checkpoint for *ticket_key* and extracts the + ``StatsState`` fields. The function is intentionally tolerant: + + - Returns ``None`` when no checkpoint exists for the ticket. + - Returns ``None`` when the checkpoint exists but contains no stats + data (legacy workflows that predate stats tracking). + - Returns a partially-populated ``WorkflowStats`` for in-progress + workflows (fields that have not yet been set carry their zero/empty + defaults). + + Args: + ticket_key: The Jira ticket key (e.g. ``"AISOS-123"``). + + Returns: + A ``WorkflowStats`` instance with all available data, or ``None`` + if no checkpoint or no stats data was found. + """ + state = await get_checkpoint_state(ticket_key) + + if state is None: + logger.debug("No checkpoint found for %s", ticket_key) + return None + + return _extract_stats(ticket_key, state) + + +async def get_workflow_stats_or_error( + ticket_key: str, +) -> tuple[WorkflowStats | None, str | None]: + """Retrieve workflow statistics, returning a display-ready error on failure. + + A convenience wrapper around ``get_workflow_stats`` that never raises. + On success the error string is ``None``; on failure the stats object is + ``None`` and the error string contains a human-readable message suitable + for printing to a terminal or posting as a Jira comment. + + Args: + ticket_key: The Jira ticket key (e.g. ``"AISOS-123"``). + + Returns: + A ``(WorkflowStats | None, str | None)`` tuple where exactly one + element is always ``None``: + + - ``(stats, None)`` on success. + - ``(None, error_message)`` when no stats are available or an + exception occurred. 
+ """ + try: + stats = await get_workflow_stats(ticket_key) + except Exception as exc: + logger.error("Failed to retrieve stats for %s: %s", ticket_key, exc) + return None, f"Error retrieving workflow data for {ticket_key}: {exc}" + + if stats is None: + return None, f"No workflow data found for {ticket_key}" + + return stats, None diff --git a/src/forge/workflow/__init__.py b/src/forge/workflow/__init__.py index 4d68775c..67e7587d 100644 --- a/src/forge/workflow/__init__.py +++ b/src/forge/workflow/__init__.py @@ -9,6 +9,7 @@ ) from forge.workflow.registry import create_default_router from forge.workflow.router import WorkflowRouter +from forge.workflow.stats import StageStats, StatsState __all__ = [ "BaseState", @@ -16,6 +17,8 @@ "CIIntegrationState", "PRIntegrationState", "ReviewIntegrationState", + "StageStats", + "StatsState", "WorkflowRouter", "create_default_router", ] diff --git a/src/forge/workflow/base.py b/src/forge/workflow/base.py index 857112c0..3159a0a8 100644 --- a/src/forge/workflow/base.py +++ b/src/forge/workflow/base.py @@ -1,4 +1,16 @@ -"""Base workflow classes and state definitions.""" +"""Base workflow classes and state definitions. + +Mixin TypedDicts +---------------- +Compose workflow states from the following mixins: + +* :class:`PRIntegrationState` — for workflows that open pull requests. +* :class:`CIIntegrationState` — for workflows that run CI checks. +* :class:`ReviewIntegrationState` — for workflows with review stages. +* :class:`~forge.workflow.stats.StatsState` — for workflows that record + execution statistics (iteration counts, token usage, timing, outcome). + Defined in :mod:`forge.workflow.stats`. 
+""" from abc import ABC, abstractmethod from datetime import datetime @@ -8,6 +20,17 @@ from langgraph.graph.message import add_messages from forge.models.workflow import TicketType +from forge.workflow.stats import StageStats, StatsState + +__all__ = [ + "BaseState", + "BaseWorkflow", + "CIIntegrationState", + "PRIntegrationState", + "ReviewIntegrationState", + "StageStats", + "StatsState", +] class BaseState(TypedDict, total=False): diff --git a/src/forge/workflow/bug/graph.py b/src/forge/workflow/bug/graph.py index 01eaf5ae..a2f0329a 100644 --- a/src/forge/workflow/bug/graph.py +++ b/src/forge/workflow/bug/graph.py @@ -45,6 +45,7 @@ route_rca_option, ) from forge.workflow.nodes.rebase import rebase_pr +from forge.workflow.nodes.stats_posting import post_terminal_stats from forge.workflow.nodes.triage import route_triage_gate, triage_check, triage_gate from forge.workflow.nodes.workspace_setup import setup_workspace from forge.workflow.utils import resolve_shared_resume_node @@ -393,10 +394,14 @@ def build_bug_graph() -> StateGraph: 2. Analysis + reflection: analyze_bug ↔ reflect_rca → rca_option_gate (pause) 3. Planning: plan_bug_fix → plan_approval_gate (pause) → decompose_plan → END 4. (Spawned tasks are handled by the task workflow) - 5. Post-merge: human_review_gate → post_merge_summary → END + 5. Post-merge: human_review_gate → post_merge_summary → post_terminal_stats → END Backward-compat implementation/CI/review nodes are preserved for in-flight tickets. + Terminal paths all route through post_terminal_stats before END: + - Success: post_merge_summary → post_terminal_stats → END + - Blocked: escalate_blocked → post_terminal_stats → END + Returns: Configured StateGraph ready for compilation. 
""" @@ -451,6 +456,9 @@ def build_bug_graph() -> StateGraph: graph.add_node("implement_review", implement_review) graph.add_node("review_response_gate", review_response_gate) + # Stats posting node — always the last node before END on terminal paths + graph.add_node("post_terminal_stats", post_terminal_stats) + # ── Set entry point ── graph.set_entry_point("route_entry") @@ -668,7 +676,7 @@ def build_bug_graph() -> StateGraph: "attempt_ci_fix": "escalate_blocked", }, ) - graph.add_edge("escalate_blocked", END) + graph.add_edge("escalate_blocked", "post_terminal_stats") # ── Review flow (merge path → post_merge_summary) ── # "complete_tasks" is the feature-workflow merge return from route_human_review; @@ -733,6 +741,7 @@ def build_bug_graph() -> StateGraph: ) # ── Post-merge terminal ── - graph.add_edge("post_merge_summary", END) + graph.add_edge("post_merge_summary", "post_terminal_stats") + graph.add_edge("post_terminal_stats", END) return graph diff --git a/src/forge/workflow/bug/state.py b/src/forge/workflow/bug/state.py index 3dac40c3..c5086918 100644 --- a/src/forge/workflow/bug/state.py +++ b/src/forge/workflow/bug/state.py @@ -1,5 +1,6 @@ """Bug workflow state definition.""" +import uuid from datetime import datetime from typing import Any @@ -10,11 +11,17 @@ CIIntegrationState, PRIntegrationState, ReviewIntegrationState, + StatsState, ) class BugState( - BaseState, PRIntegrationState, CIIntegrationState, ReviewIntegrationState, total=False + BaseState, + PRIntegrationState, + CIIntegrationState, + ReviewIntegrationState, + StatsState, + total=False, ): """State specific to Bug workflow.""" @@ -134,6 +141,17 @@ def create_initial_bug_state(ticket_key: str, **kwargs: Any) -> BugState: "qualitative_review_failed": False, "reflect_rca_retry_count": 0, "yolo_mode": False, + # Stats fields + "stage_timestamps": {}, + "revision_counts": {}, + "token_usage": {}, + "stage_token_usage": {}, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "workflow_outcome": None, + 
"stats_outcome_reason": None, + "stats_comment_posted": False, + "workflow_run_id": str(uuid.uuid4()), } # Merge with kwargs, letting kwargs override defaults diff --git a/src/forge/workflow/feature/graph.py b/src/forge/workflow/feature/graph.py index a51f010e..40822d19 100644 --- a/src/forge/workflow/feature/graph.py +++ b/src/forge/workflow/feature/graph.py @@ -53,6 +53,7 @@ ) from forge.workflow.nodes.qa_handler import answer_question from forge.workflow.nodes.rebase import rebase_pr +from forge.workflow.nodes.stats_posting import post_terminal_stats from forge.workflow.nodes.task_generation import ( regenerate_all_tasks, regenerate_epic_tasks, @@ -154,17 +155,18 @@ def route_by_ticket_type(state: FeatureState) -> str: def _route_after_generation(state: FeatureState) -> str: """Route based on PRD generation success. - If generation failed (has error and no PRD content), don't advance to approval gate. + If generation failed (has error and no PRD content), route to stats posting + before ending the workflow. Returns: - "prd_approval_gate" on success, END on failure. + "prd_approval_gate" on success, "post_terminal_stats" on unrecoverable failure. """ last_error = state.get("last_error") prd_content = state.get("prd_content", "") if last_error and not prd_content: - logger.error(f"PRD generation failed, workflow paused: {last_error}") - return END + logger.error(f"PRD generation failed, workflow ending: {last_error}") + return "post_terminal_stats" return "prd_approval_gate" @@ -172,17 +174,18 @@ def _route_after_generation(state: FeatureState) -> str: def _route_after_spec_generation(state: FeatureState) -> str: """Route based on spec generation success. - If generation failed (has error and no spec content), don't advance to approval gate. + If generation failed (has error and no spec content), route to stats posting + before ending the workflow. Returns: - "spec_approval_gate" on success, END on failure. 
+ "spec_approval_gate" on success, "post_terminal_stats" on unrecoverable failure. """ last_error = state.get("last_error") spec_content = state.get("spec_content", "") if last_error and not spec_content: - logger.error(f"Spec generation failed, workflow paused: {last_error}") - return END + logger.error(f"Spec generation failed, workflow ending: {last_error}") + return "post_terminal_stats" return "spec_approval_gate" @@ -190,17 +193,18 @@ def _route_after_spec_generation(state: FeatureState) -> str: def _route_after_epic_decomposition(state: FeatureState) -> str: """Route based on epic decomposition success. - If decomposition failed (has error and no epics), don't advance to approval gate. + If decomposition failed (has error and no epics), route to stats posting + before ending the workflow. Returns: - "plan_approval_gate" on success, END ("__end__") on failure. + "plan_approval_gate" on success, "post_terminal_stats" on unrecoverable failure. """ last_error = state.get("last_error") epic_keys = state.get("epic_keys", []) if last_error and not epic_keys: - logger.error(f"Epic decomposition failed, workflow paused: {last_error}") - return END + logger.error(f"Epic decomposition failed, workflow ending: {last_error}") + return "post_terminal_stats" return "plan_approval_gate" @@ -208,17 +212,18 @@ def _route_after_epic_decomposition(state: FeatureState) -> str: def _route_after_task_generation(state: FeatureState) -> str: """Route based on task generation success. - If task generation failed (has error and no tasks), don't advance. + If task generation failed (has error and no tasks), route to stats posting + before ending the workflow. Returns: - "task_approval_gate" on success, END on failure. + "task_approval_gate" on success, "post_terminal_stats" on unrecoverable failure. 
""" last_error = state.get("last_error") task_keys = state.get("task_keys", []) if last_error and not task_keys: - logger.error(f"Task generation failed, workflow paused: {last_error}") - return END + logger.error(f"Task generation failed, workflow ending: {last_error}") + return "post_terminal_stats" return "task_approval_gate" @@ -419,7 +424,12 @@ def build_feature_graph() -> StateGraph: 22. ci_evaluator: checks CI status, attempts autonomous fixes on failure (up to 5 retries) 23. ci_evaluator (passed) -> human_review_gate 24. human_review_gate -> complete_tasks - 25. complete_tasks -> aggregate_epic_status -> aggregate_feature_status -> END + 25. complete_tasks -> aggregate_epic_status -> aggregate_feature_status -> post_terminal_stats -> END + + Terminal paths all route through post_terminal_stats before END: + - Success: aggregate_feature_status -> post_terminal_stats -> END + - Blocked: escalate_blocked -> post_terminal_stats -> END + - Failure: unrecoverable generation errors -> post_terminal_stats -> END Returns: Configured StateGraph ready for compilation. 
@@ -480,6 +490,9 @@ def build_feature_graph() -> StateGraph: graph.add_node("aggregate_epic_status", aggregate_epic_status) graph.add_node("aggregate_feature_status", aggregate_feature_status) + # Stats posting node — always the last node before END on terminal paths + graph.add_node("post_terminal_stats", post_terminal_stats) + # Q&A node graph.add_node("answer_question", answer_question) @@ -539,7 +552,7 @@ def build_feature_graph() -> StateGraph: _route_after_generation, { "prd_approval_gate": "prd_approval_gate", - END: END, + "post_terminal_stats": "post_terminal_stats", # unrecoverable failure }, ) graph.add_conditional_edges( @@ -567,7 +580,7 @@ def build_feature_graph() -> StateGraph: _route_after_spec_generation, { "spec_approval_gate": "spec_approval_gate", - END: END, + "post_terminal_stats": "post_terminal_stats", # unrecoverable failure }, ) graph.add_conditional_edges( @@ -595,7 +608,7 @@ def build_feature_graph() -> StateGraph: _route_after_epic_decomposition, { "plan_approval_gate": "plan_approval_gate", - END: END, # Error state - don't advance + "post_terminal_stats": "post_terminal_stats", # unrecoverable failure }, ) graph.add_conditional_edges( @@ -632,7 +645,7 @@ def build_feature_graph() -> StateGraph: _route_after_task_generation, { "task_approval_gate": "task_approval_gate", - END: END, + "post_terminal_stats": "post_terminal_stats", # unrecoverable failure }, ) graph.add_conditional_edges( @@ -747,7 +760,7 @@ def build_feature_graph() -> StateGraph: "attempt_ci_fix": "escalate_blocked", }, ) - graph.add_edge("escalate_blocked", END) + graph.add_edge("escalate_blocked", "post_terminal_stats") # Human Review flow (US9) graph.add_conditional_edges( @@ -781,7 +794,8 @@ def build_feature_graph() -> StateGraph: ) graph.add_edge("complete_tasks", "aggregate_epic_status") graph.add_edge("aggregate_epic_status", "aggregate_feature_status") - graph.add_edge("aggregate_feature_status", END) + graph.add_edge("aggregate_feature_status", 
"post_terminal_stats") + graph.add_edge("post_terminal_stats", END) # Q&A routing: answer_question returns to the gate it came from graph.add_conditional_edges( diff --git a/src/forge/workflow/feature/state.py b/src/forge/workflow/feature/state.py index d67e84d9..046705c9 100644 --- a/src/forge/workflow/feature/state.py +++ b/src/forge/workflow/feature/state.py @@ -1,5 +1,6 @@ """Feature workflow state definition.""" +import uuid from datetime import datetime from typing import Any @@ -10,11 +11,17 @@ CIIntegrationState, PRIntegrationState, ReviewIntegrationState, + StatsState, ) class FeatureState( - BaseState, PRIntegrationState, CIIntegrationState, ReviewIntegrationState, total=False + BaseState, + PRIntegrationState, + CIIntegrationState, + ReviewIntegrationState, + StatsState, + total=False, ): """State specific to Feature workflow.""" @@ -133,6 +140,17 @@ def create_initial_feature_state(ticket_key: str, **kwargs: Any) -> FeatureState "spec_pr_branch": None, "spec_pr_file_path": None, "yolo_mode": False, + # Stats fields + "stage_timestamps": {}, + "revision_counts": {}, + "token_usage": {}, + "stage_token_usage": {}, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "workflow_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, + "workflow_run_id": str(uuid.uuid4()), } # Merge with kwargs, letting kwargs override defaults diff --git a/src/forge/workflow/nodes/ci_evaluator.py b/src/forge/workflow/nodes/ci_evaluator.py index 3a881426..1bbe8ccc 100644 --- a/src/forge/workflow/nodes/ci_evaluator.py +++ b/src/forge/workflow/nodes/ci_evaluator.py @@ -2,6 +2,7 @@ import io import logging +import time import zipfile from pathlib import Path from typing import Any @@ -17,6 +18,13 @@ from forge.workflow.nodes.code_review import run_post_change_review, sync_pr_description from forge.workflow.nodes.error_handler import notify_error from forge.workflow.nodes.workspace_setup import prepare_workspace +from forge.workflow.stats import STAGE_CI 
+from forge.workflow.stats_utils import ( + increment_revision, + record_stage_end, + record_stage_start, + record_tokens, +) from forge.workflow.utils import update_state_timestamp from forge.workflow.utils.jira_status import ( post_status_comment, @@ -29,6 +37,13 @@ logger = logging.getLogger(__name__) +def _estimate_tokens(text: str) -> int: + """Estimate token count from text length (approx. 4 chars per token).""" + if not text: + return 0 + return max(1, len(text) // 4) + + async def evaluate_ci_status(state: WorkflowState) -> WorkflowState: """Evaluate CI status for the current PR. @@ -49,8 +64,13 @@ async def evaluate_ci_status(state: WorkflowState) -> WorkflowState: ci_fix_max = state.get("ci_fix_max_attempts", 5) settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_CI, model_name=settings.container_model)} + node_start = time.monotonic() + if not pr_urls: logger.info(f"No PRs to evaluate for {ticket_key}") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_CI, machine_time)} return update_state_timestamp( { **state, @@ -148,8 +168,15 @@ def _is_skipped(check: dict) -> bool: } ) + if all_passed or not any_still_running: + from forge.workflow.stats_utils import increment_ci_cycle + + state = {**state, **increment_ci_cycle(state)} + if all_passed: logger.info(f"All CI checks passed for {ticket_key}") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_CI, machine_time)} return update_state_timestamp( { **state, @@ -167,6 +194,8 @@ def _is_skipped(check: dict) -> bool: f"CI partially complete for {ticket_key} " f"({len(failed_checks)} failed, more still running) — waiting" ) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_CI, machine_time)} return update_state_timestamp( { **state, @@ -179,6 +208,8 @@ def _is_skipped(check: dict) -> bool: # This prevents the fix pipeline from firing while 
real CI jobs are in-progress. if not failed_checks: logger.info(f"CI checks still running for {ticket_key}, waiting for completion") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_CI, machine_time)} return update_state_timestamp( { **state, @@ -191,6 +222,8 @@ def _is_skipped(check: dict) -> bool: if ci_fix_attempt >= ci_fix_max: logger.warning(f"CI fix attempt limit ({ci_fix_max}) reached for {ticket_key}") record_ci_fix_attempt(repo=state.get("current_repo", "unknown"), result="exhausted") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_CI, machine_time)} return update_state_timestamp( { **state, @@ -203,6 +236,8 @@ def _is_skipped(check: dict) -> bool: next_attempt = ci_fix_attempt + 1 logger.info(f"CI failed for {ticket_key}, attempt {next_attempt}/{ci_fix_max}") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_CI, machine_time)} return update_state_timestamp( { **state, @@ -216,6 +251,8 @@ def _is_skipped(check: dict) -> bool: except Exception as e: logger.error(f"CI evaluation failed for {ticket_key}: {e}") await notify_error(state, str(e), "ci_evaluator") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_CI, machine_time)} return { **state, "last_error": str(e), @@ -255,6 +292,11 @@ async def attempt_ci_fix(state: WorkflowState) -> WorkflowState: logger.info(f"Attempting CI fix for {ticket_key}") + settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_CI, model_name=settings.container_model)} + state = {**state, **increment_revision(state, STAGE_CI)} + node_start = time.monotonic() + # Post status comment to feature ticket at start of CI fix attempt ci_fix_attempt = state.get("ci_fix_attempt", 0) ci_fix_max = state.get("ci_fix_max_attempts", 5) @@ -278,6 +320,8 @@ async def attempt_ci_fix(state: WorkflowState) -> WorkflowState: except 
Exception as _setup_err: logger.error(f"Workspace setup failed for {ticket_key}: {_setup_err}") await notify_error(state, str(_setup_err), "attempt_ci_fix") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_CI, machine_time)} return { **state, "last_error": str(_setup_err), @@ -311,7 +355,7 @@ async def attempt_ci_fix(state: WorkflowState) -> WorkflowState: ) runner = ContainerRunner(settings) - await runner.run( + result_phase1 = await runner.run( workspace_path=Path(workspace_path), task_summary=f"Analyze CI failures (attempt {attempt})", task_description=analysis_prompt, @@ -320,8 +364,36 @@ async def attempt_ci_fix(state: WorkflowState) -> WorkflowState: repo_name=state.get("current_repo", ""), ) + # Record tokens (using actual container metrics if available, else falling back to heuristic) + if ( + result_phase1 + and isinstance(getattr(result_phase1, "input_tokens", None), int) + and result_phase1.input_tokens > 0 + ): + input_tokens_1 = result_phase1.input_tokens + else: + input_tokens_1 = max(1, _estimate_tokens(analysis_prompt)) + + if ( + result_phase1 + and isinstance(getattr(result_phase1, "output_tokens", None), int) + and result_phase1.output_tokens > 0 + ): + output_tokens_1 = result_phase1.output_tokens + else: + text_for_est_1 = "" + if result_phase1: + text_for_est_1 = (getattr(result_phase1, "stdout", "") or "") + ( + getattr(result_phase1, "stderr", "") or "" + ) + output_tokens_1 = max(1, _estimate_tokens(text_for_est_1)) + + state = {**state, **record_tokens(state, STAGE_CI, input_tokens_1, output_tokens_1)} + if not fix_plan_file.exists(): logger.warning(f"No fix plan written for {ticket_key} — skipping fix phase") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_CI, machine_time)} return update_state_timestamp( { **state, @@ -339,7 +411,7 @@ async def attempt_ci_fix(state: WorkflowState) -> WorkflowState: fix_prompt = load_prompt("fix-ci", 
fix_plan=fix_plan) runner = ContainerRunner(settings) - await runner.run( + result_phase2 = await runner.run( workspace_path=Path(workspace_path), task_summary=f"Apply CI fix plan (attempt {attempt})", task_description=fix_prompt, @@ -348,6 +420,32 @@ async def attempt_ci_fix(state: WorkflowState) -> WorkflowState: repo_name=state.get("current_repo", ""), ) + # Record tokens (using actual container metrics if available, else falling back to heuristic) + if ( + result_phase2 + and isinstance(getattr(result_phase2, "input_tokens", None), int) + and result_phase2.input_tokens > 0 + ): + input_tokens_2 = result_phase2.input_tokens + else: + input_tokens_2 = max(1, _estimate_tokens(fix_prompt)) + + if ( + result_phase2 + and isinstance(getattr(result_phase2, "output_tokens", None), int) + and result_phase2.output_tokens > 0 + ): + output_tokens_2 = result_phase2.output_tokens + else: + text_for_est_2 = "" + if result_phase2: + text_for_est_2 = (getattr(result_phase2, "stdout", "") or "") + ( + getattr(result_phase2, "stderr", "") or "" + ) + output_tokens_2 = max(1, _estimate_tokens(text_for_est_2)) + + state = {**state, **record_tokens(state, STAGE_CI, input_tokens_2, output_tokens_2)} + workspace = Workspace( path=Path(workspace_path), repo_name=state.get("current_repo", ""), @@ -386,6 +484,7 @@ async def attempt_ci_fix(state: WorkflowState) -> WorkflowState: spec_content=state.get("spec_content", ""), guardrails=state.get("context", {}).get("guardrails", ""), label=f"ci-fix-{attempt}", + state=state, ) # Push all commits (CI fix + any review corrections) @@ -409,6 +508,8 @@ async def attempt_ci_fix(state: WorkflowState) -> WorkflowState: attempt=attempt, ) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_CI, machine_time)} return update_state_timestamp( { **state, @@ -420,6 +521,8 @@ async def attempt_ci_fix(state: WorkflowState) -> WorkflowState: except Exception as e: logger.error(f"CI fix failed for {ticket_key}: 
{e}") await notify_error(state, str(e), "attempt_ci_fix") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_CI, machine_time)} return { **state, "last_error": str(e), diff --git a/src/forge/workflow/nodes/code_review.py b/src/forge/workflow/nodes/code_review.py index 95fb7692..55224fea 100644 --- a/src/forge/workflow/nodes/code_review.py +++ b/src/forge/workflow/nodes/code_review.py @@ -7,6 +7,7 @@ """ import logging +import time from pathlib import Path from typing import Any @@ -16,12 +17,21 @@ from forge.integrations.jira.client import JiraClient from forge.prompts import load_prompt from forge.sandbox import ContainerRunner +from forge.workflow.stats import STAGE_REVIEW +from forge.workflow.stats_utils import record_stage_end, record_stage_start, record_tokens from forge.workspace.git_ops import GitOperations from forge.workspace.manager import Workspace logger = logging.getLogger(__name__) +def _estimate_tokens(text: str) -> int: + """Estimate token count from text length (approx. 4 chars per token).""" + if not text: + return 0 + return max(1, len(text) // 4) + + async def run_post_change_review( workspace_path: str, ticket_key: str, @@ -30,6 +40,7 @@ async def run_post_change_review( spec_content: str = "", guardrails: str = "", label: str = "post-change", + state: Any = None, ) -> bool: """Run the local-review container skill after a code-changing step. @@ -45,11 +56,18 @@ async def run_post_change_review( spec_content: Spec to guide the review (optional). guardrails: Repository guidelines (optional). label: Short label for log messages (e.g. "ci-fix", "post-change"). + state: Optional workflow state. Returns: True if the review committed any fixes, False otherwise. 
""" settings = get_settings() + node_start = None + if state is not None: + start_updates = record_stage_start(state, STAGE_REVIEW, model_name=settings.llm_model) + state.setdefault("stage_timestamps", {}).update(start_updates.get("stage_timestamps", {})) + node_start = time.monotonic() + try: task_description = load_prompt( "local-review", @@ -59,7 +77,7 @@ async def run_post_change_review( ) runner = ContainerRunner(settings) - await runner.run( + result = await runner.run( workspace_path=Path(workspace_path), task_summary=f"Post-{label} code review", task_description=task_description, @@ -68,6 +86,35 @@ async def run_post_change_review( repo_name=current_repo, ) + if state is not None: + # Record tokens (using actual container metrics if available, else falling back to heuristic) + if ( + result + and isinstance(getattr(result, "input_tokens", None), int) + and result.input_tokens > 0 + ): + input_tokens = result.input_tokens + else: + input_tokens = _estimate_tokens(task_description) + + if ( + result + and isinstance(getattr(result, "output_tokens", None), int) + and result.output_tokens > 0 + ): + output_tokens = result.output_tokens + else: + output_tokens = _estimate_tokens(result.stdout) if result.stdout else 0 + + token_updates = record_tokens(state, STAGE_REVIEW, input_tokens, output_tokens) + state.setdefault("stage_timestamps", {}).update( + token_updates.get("stage_timestamps", {}) + ) + state.setdefault("stage_token_usage", {}).update( + token_updates.get("stage_token_usage", {}) + ) + state.setdefault("token_usage", {}).update(token_updates.get("token_usage", {})) + git = GitOperations( Workspace( path=Path(workspace_path), @@ -77,17 +124,28 @@ async def run_post_change_review( ) ) + committed = False if git.has_uncommitted_changes(): git.stage_all() git.commit(f"[{ticket_key}] fix: address issues found in {label} review") logger.info(f"Committed {label} review fixes for {ticket_key}") - return True + committed = True + else: + 
logger.info(f"Post-{label} review: no fixes needed for {ticket_key}") - logger.info(f"Post-{label} review: no fixes needed for {ticket_key}") - return False + if state is not None and node_start is not None: + machine_time = time.monotonic() - node_start + end_updates = record_stage_end(state, STAGE_REVIEW, machine_time) + state.setdefault("stage_timestamps", {}).update(end_updates.get("stage_timestamps", {})) + + return committed except Exception as e: logger.warning(f"Post-{label} review failed (non-fatal): {e}") + if state is not None and node_start is not None: + machine_time = time.monotonic() - node_start + end_updates = record_stage_end(state, STAGE_REVIEW, machine_time) + state.setdefault("stage_timestamps", {}).update(end_updates.get("stage_timestamps", {})) return False @@ -116,6 +174,11 @@ async def sync_pr_description( if pr_number is None: return + settings = get_settings() + start_updates = record_stage_start(state, STAGE_REVIEW, model_name=settings.llm_model) + state.setdefault("stage_timestamps", {}).update(start_updates.get("stage_timestamps", {})) + node_start = time.monotonic() + try: commit_log = git._run_git( "log", @@ -127,6 +190,9 @@ async def sync_pr_description( if not commit_log: logger.debug("PR description sync skipped — no commits on branch") + machine_time = time.monotonic() - node_start + end_updates = record_stage_end(state, STAGE_REVIEW, machine_time) + state.setdefault("stage_timestamps", {}).update(end_updates.get("stage_timestamps", {})) return github = GitHubClient() @@ -140,7 +206,7 @@ async def sync_pr_description( current_description=current_body, commit_log=commit_log, ) - agent = ForgeAgent(get_settings()) + agent = ForgeAgent(settings) try: updated_body = await agent.run_task( task="sync-pr-description", @@ -160,6 +226,28 @@ async def sync_pr_description( finally: await agent.close() + # Record tokens (using actual agent metadata if available, else falling back to heuristic) + last_in = getattr(agent, "last_input_tokens", 
0) + last_out = getattr(agent, "last_output_tokens", 0) + if isinstance(last_in, int) and not isinstance(last_in, bool) and last_in > 0: + input_tokens = last_in + else: + input_tokens = _estimate_tokens(prompt) + + if isinstance(last_out, int) and not isinstance(last_out, bool) and last_out > 0: + output_tokens = last_out + else: + output_tokens = _estimate_tokens(updated_body) if updated_body else 0 + + token_updates = record_tokens(state, STAGE_REVIEW, input_tokens, output_tokens) + state.setdefault("stage_timestamps", {}).update( + token_updates.get("stage_timestamps", {}) + ) + state.setdefault("stage_token_usage", {}).update( + token_updates.get("stage_token_usage", {}) + ) + state.setdefault("token_usage", {}).update(token_updates.get("token_usage", {})) + if updated_body: updated_body = agent._strip_preamble(updated_body) if updated_body and updated_body.strip() != current_body.strip(): @@ -177,5 +265,12 @@ async def sync_pr_description( await github.close() await jira.close() + machine_time = time.monotonic() - node_start + end_updates = record_stage_end(state, STAGE_REVIEW, machine_time) + state.setdefault("stage_timestamps", {}).update(end_updates.get("stage_timestamps", {})) + except Exception as e: logger.warning(f"PR description sync failed (non-fatal): {e}") + machine_time = time.monotonic() - node_start + end_updates = record_stage_end(state, STAGE_REVIEW, machine_time) + state.setdefault("stage_timestamps", {}).update(end_updates.get("stage_timestamps", {})) diff --git a/src/forge/workflow/nodes/epic_decomposition.py b/src/forge/workflow/nodes/epic_decomposition.py index 7081ee3b..d2aa4d3c 100644 --- a/src/forge/workflow/nodes/epic_decomposition.py +++ b/src/forge/workflow/nodes/epic_decomposition.py @@ -1,6 +1,7 @@ """Epic decomposition node for LangGraph workflow.""" import logging +import time from typing import Any from forge.config import get_settings @@ -8,6 +9,13 @@ from forge.integrations.jira.client import JiraClient, MissingProjectConfig 
from forge.models.workflow import ForgeLabel from forge.workflow.feature.state import FeatureState as WorkflowState +from forge.workflow.stats import STAGE_EPICS +from forge.workflow.stats_utils import ( + increment_revision, + record_stage_end, + record_stage_start, + record_tokens, +) from forge.workflow.utils import update_state_timestamp from forge.workflow.utils.jira_status import post_status_comment from forge.workflow.utils.qa_summary import post_qa_summary_if_needed @@ -15,6 +23,11 @@ logger = logging.getLogger(__name__) +def _estimate_tokens(text: str) -> int: + """Estimate token count from text length (approx. 4 chars per token).""" + return max(1, len(text) // 4) + + def _missing_repo_config_comment(project_key: str) -> str: return ( f"⚠️ Forge configuration required for project {project_key}\n\n" @@ -50,6 +63,10 @@ async def decompose_epics(state: WorkflowState) -> WorkflowState: logger.info(f"Decomposing spec into Epics for {ticket_key}") + settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_EPICS, model_name=settings.llm_model)} + node_start = time.monotonic() + # Post Q&A summary for spec if any qa_history = state.get("qa_history", []) if qa_history: @@ -104,6 +121,8 @@ async def decompose_epics(state: WorkflowState) -> WorkflowState: ) await jira.add_comment(ticket_key, _missing_repo_config_comment(project_key)) await jira.set_workflow_label(ticket_key, ForgeLabel.BLOCKED) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_EPICS, machine_time)} return {**state, "last_error": str(e), "current_node": "decompose_epics"} logger.warning(f"Project {project_key}: {e} — falling back to GITHUB_KNOWN_REPOS") for repo in settings.known_repos: @@ -128,6 +147,21 @@ async def decompose_epics(state: WorkflowState) -> WorkflowState: # Generate Epic breakdown using Claude - primary operation epics_data = await agent.generate_epics(spec_content, context) + # Record tokens (using actual agent 
metadata if available, else falling back to heuristic) + last_in = getattr(agent, "last_input_tokens", 0) + last_out = getattr(agent, "last_output_tokens", 0) + if isinstance(last_in, int) and not isinstance(last_in, bool) and last_in > 0: + input_tokens = last_in + else: + input_tokens = _estimate_tokens(spec_content) + + if isinstance(last_out, int) and not isinstance(last_out, bool) and last_out > 0: + output_tokens = last_out + else: + output_tokens = _estimate_tokens(str(epics_data)) if epics_data else 0 + + state = {**state, **record_tokens(state, STAGE_EPICS, input_tokens, output_tokens)} + if not epics_data: logger.warning(f"No Epics generated for {ticket_key}") return { @@ -200,6 +234,8 @@ async def decompose_epics(state: WorkflowState) -> WorkflowState: ) generation_context["plan"] = "\n\n".join(plan_summary_parts) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_EPICS, machine_time)} return update_state_timestamp( { **state, @@ -214,6 +250,8 @@ async def decompose_epics(state: WorkflowState) -> WorkflowState: ) else: # No Epics created at all - this is a failure + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_EPICS, machine_time)} return { **state, "last_error": jira_error or "Failed to create any Epics in Jira", @@ -228,6 +266,8 @@ async def decompose_epics(state: WorkflowState) -> WorkflowState: await notify_error(state, str(e), "decompose_epics") # Save any Epics we managed to create + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_EPICS, machine_time)} result_state = { **state, "last_error": str(e), @@ -277,6 +317,7 @@ async def regenerate_all_epics(state: WorkflowState) -> WorkflowState: "epic_keys": [], "feedback_comment": feedback, } + updated_state = {**updated_state, **increment_revision(updated_state, STAGE_EPICS)} # Re-run decomposition (which will use context including feedback) return await 
decompose_epics(updated_state) @@ -314,6 +355,11 @@ async def update_single_epic(state: WorkflowState) -> WorkflowState: logger.info(f"Updating Epic {epic_key} with feedback") + settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_EPICS, model_name=settings.llm_model)} + state = {**state, **increment_revision(state, STAGE_EPICS)} + node_start = time.monotonic() + jira = JiraClient() agent = ForgeAgent() @@ -337,6 +383,11 @@ async def update_single_epic(state: WorkflowState) -> WorkflowState: }, ) + # Record tokens + input_tokens = _estimate_tokens(original_plan) + _estimate_tokens(feedback) + output_tokens = _estimate_tokens(new_plan) + state = {**state, **record_tokens(state, STAGE_EPICS, input_tokens, output_tokens)} + # Update Epic description await jira.update_description(epic_key, new_plan) @@ -348,6 +399,9 @@ async def update_single_epic(state: WorkflowState) -> WorkflowState: logger.info(f"Updated Epic {epic_key} plan") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_EPICS, machine_time)} + return update_state_timestamp( { **state, @@ -361,6 +415,8 @@ async def update_single_epic(state: WorkflowState) -> WorkflowState: except Exception as e: logger.error(f"Epic update failed for {epic_key}: {e}") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_EPICS, machine_time)} return { **state, "last_error": str(e), diff --git a/src/forge/workflow/nodes/implement_review.py b/src/forge/workflow/nodes/implement_review.py index fcf66e90..bdb034c2 100644 --- a/src/forge/workflow/nodes/implement_review.py +++ b/src/forge/workflow/nodes/implement_review.py @@ -1,6 +1,7 @@ """implement_review node — addresses PR review feedback on an existing branch.""" import logging +import time from pathlib import Path from typing import Any @@ -14,10 +15,25 @@ from forge.workflow.feature.state import FeatureState as WorkflowState from 
forge.workflow.nodes.code_review import run_post_change_review, sync_pr_description from forge.workflow.nodes.workspace_setup import prepare_workspace +from forge.workflow.stats import STAGE_REVIEW +from forge.workflow.stats_utils import ( + increment_revision, + record_stage_end, + record_stage_start, + record_tokens, +) from forge.workflow.utils import set_paused, update_state_timestamp logger = logging.getLogger(__name__) + +def _estimate_tokens(text: str) -> int: + """Estimate token count from text length (approx. 4 chars per token).""" + if not text: + return 0 + return max(1, len(text) // 4) + + _REVIEW_COMMENTS_FILE = ".forge/review-comments.md" _REVIEW_PLAN_FILE = ".forge/review-plan.md" _REVIEW_OBJECTIONS_FILE = ".forge/review-objections.md" @@ -122,12 +138,17 @@ async def implement_review(state: WorkflowState) -> WorkflowState: logger.info(f"Implementing PR review feedback for {ticket_key}") settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_REVIEW, model_name=settings.llm_model)} + state = {**state, **increment_revision(state, STAGE_REVIEW)} + node_start = time.monotonic() try: try: workspace_path, git = prepare_workspace(state) state = {**state, "workspace_path": workspace_path} except ValueError as e: + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_REVIEW, machine_time)} return update_state_timestamp( { **state, @@ -171,7 +192,7 @@ async def implement_review(state: WorkflowState) -> WorkflowState: analysis_prompt = load_prompt("implement-review", ticket_key=ticket_key) runner = ContainerRunner(settings) - await runner.run( + result_phase1 = await runner.run( workspace_path=Path(workspace_path), task_summary=f"Analyze PR review feedback for {ticket_key}", task_description=analysis_prompt, @@ -180,6 +201,27 @@ async def implement_review(state: WorkflowState) -> WorkflowState: repo_name=current_repo, ) + # Record tokens (using actual container metrics if available, else 
falling back to heuristic) + if ( + result_phase1 + and isinstance(getattr(result_phase1, "input_tokens", None), int) + and result_phase1.input_tokens > 0 + ): + input_tokens_1 = result_phase1.input_tokens + else: + input_tokens_1 = _estimate_tokens(analysis_prompt) + + if ( + result_phase1 + and isinstance(getattr(result_phase1, "output_tokens", None), int) + and result_phase1.output_tokens > 0 + ): + output_tokens_1 = result_phase1.output_tokens + else: + output_tokens_1 = _estimate_tokens(result_phase1.stdout) if result_phase1.stdout else 0 + + state = {**state, **record_tokens(state, STAGE_REVIEW, input_tokens_1, output_tokens_1)} + # ── Check for objections ────────────────────────────────────────────── objections_path = Path(workspace_path) / _REVIEW_OBJECTIONS_FILE if objections_path.exists(): @@ -193,6 +235,8 @@ async def implement_review(state: WorkflowState) -> WorkflowState: repo=_repo, pr_number=pr_number, ) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_REVIEW, machine_time)} return update_state_timestamp( { **state, @@ -213,7 +257,7 @@ async def implement_review(state: WorkflowState) -> WorkflowState: fix_prompt = load_prompt("implement-review-fix", ticket_key=ticket_key) runner = ContainerRunner(settings) - await runner.run( + result_fix = await runner.run( workspace_path=Path(workspace_path), task_summary=f"Implement PR review plan for {ticket_key}", task_description=fix_prompt, @@ -222,6 +266,28 @@ async def implement_review(state: WorkflowState) -> WorkflowState: repo_name=current_repo, ) + # Record tokens (using actual container metrics if available, else falling back to heuristic) + if ( + result_fix + and isinstance(getattr(result_fix, "input_tokens", None), int) + and result_fix.input_tokens > 0 + ): + input_tokens_2 = result_fix.input_tokens + else: + input_tokens_2 = _estimate_tokens(fix_prompt) + + if ( + result_fix + and isinstance(getattr(result_fix, "output_tokens", None), int) + and 
result_fix.output_tokens > 0 + ): + output_tokens_2 = result_fix.output_tokens + else: + output_tokens_2 = ( + _estimate_tokens(result_fix.stdout) if (result_fix and result_fix.stdout) else 0 + ) + state = {**state, **record_tokens(state, STAGE_REVIEW, input_tokens_2, output_tokens_2)} + # Commit any uncommitted changes the container left if git.has_uncommitted_changes(): git.stage_all() @@ -248,6 +314,7 @@ async def implement_review(state: WorkflowState) -> WorkflowState: spec_content=state.get("spec_content", ""), guardrails=state.get("context", {}).get("guardrails", ""), label="review-impl", + state=state, ) if fork_owner and fork_repo: @@ -271,6 +338,9 @@ async def implement_review(state: WorkflowState) -> WorkflowState: # CI won't re-trigger and wait_for_ci_gate would block forever. next_node = "wait_for_ci_gate" if unpushed else "human_review_gate" + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_REVIEW, machine_time)} + return update_state_timestamp( { **state, @@ -289,6 +359,8 @@ async def implement_review(state: WorkflowState) -> WorkflowState: from forge.workflow.nodes.error_handler import notify_error await notify_error(state, str(e), "implement_review") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_REVIEW, machine_time)} return { **state, "last_error": str(e), diff --git a/src/forge/workflow/nodes/implementation.py b/src/forge/workflow/nodes/implementation.py index 55ae81c5..193d171a 100644 --- a/src/forge/workflow/nodes/implementation.py +++ b/src/forge/workflow/nodes/implementation.py @@ -12,6 +12,7 @@ """ import logging +import time from pathlib import Path from forge.config import get_settings @@ -20,6 +21,13 @@ from forge.sandbox import ContainerRunner from forge.workflow.feature.state import FeatureState as WorkflowState from forge.workflow.nodes.error_handler import notify_error +from forge.workflow.stats import STAGE_IMPLEMENTATION +from 
forge.workflow.stats_utils import ( + increment_revision, + record_stage_end, + record_stage_start, + record_tokens, +) from forge.workflow.utils import update_state_timestamp from forge.workflow.utils.jira_status import post_status_comment from forge.workspace.git_ops import GitOperations @@ -28,6 +36,13 @@ logger = logging.getLogger(__name__) +def _estimate_tokens(text: str) -> int: + """Estimate token count from text length (approx. 4 chars per token).""" + if not text: + return 0 + return max(1, len(text) // 4) + + async def implement_task(state: WorkflowState) -> WorkflowState: """Implement a single Task using container sandbox. @@ -110,6 +125,16 @@ async def implement_task(state: WorkflowState) -> WorkflowState: logger.info(f"Implementing Task {current_task} for {ticket_key}") settings = get_settings() + state = { + **state, + **record_stage_start(state, STAGE_IMPLEMENTATION, model_name=settings.llm_model), + } + state = { + **state, + **increment_revision(state, STAGE_IMPLEMENTATION), + } + node_start = time.monotonic() + jira = JiraClient(settings) try: @@ -151,6 +176,27 @@ async def implement_task(state: WorkflowState) -> WorkflowState: previous_task_keys=implemented_tasks, ) + # Record tokens (using actual container metrics if available, else falling back to heuristic) + if ( + result + and isinstance(getattr(result, "input_tokens", None), int) + and result.input_tokens > 0 + ): + input_tokens = result.input_tokens + else: + input_tokens = _estimate_tokens(full_description) + + if ( + result + and isinstance(getattr(result, "output_tokens", None), int) + and result.output_tokens > 0 + ): + output_tokens = result.output_tokens + else: + output_tokens = _estimate_tokens(result.stdout) if (result and result.stdout) else 0 + + state = {**state, **record_tokens(state, STAGE_IMPLEMENTATION, input_tokens, output_tokens)} + if result.success: logger.info(f"Container completed successfully for {current_task}") @@ -165,6 +211,9 @@ async def implement_task(state: 
WorkflowState) -> WorkflowState: implemented = state.get("implemented_tasks", []) implemented.append(current_task) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_IMPLEMENTATION, machine_time)} + return update_state_timestamp( { **state, @@ -186,6 +235,8 @@ async def implement_task(state: WorkflowState) -> WorkflowState: except Exception as e: logger.error(f"Implementation failed for {current_task}: {e}") await notify_error(state, str(e), "implement_task") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_IMPLEMENTATION, machine_time)} return { **state, "last_error": str(e), diff --git a/src/forge/workflow/nodes/local_reviewer.py b/src/forge/workflow/nodes/local_reviewer.py index ffeef0cb..4df68f9a 100644 --- a/src/forge/workflow/nodes/local_reviewer.py +++ b/src/forge/workflow/nodes/local_reviewer.py @@ -2,6 +2,7 @@ import logging import re +import time from pathlib import Path from forge.config import get_settings @@ -10,6 +11,8 @@ from forge.prompts import load_prompt from forge.sandbox import ContainerRunner from forge.workflow.feature.state import FeatureState as WorkflowState +from forge.workflow.stats import STAGE_REVIEW +from forge.workflow.stats_utils import record_stage_end, record_stage_start, record_tokens from forge.workflow.utils import update_state_timestamp from forge.workflow.utils.jira_status import post_status_comment from forge.workspace.git_ops import GitOperations @@ -17,6 +20,14 @@ logger = logging.getLogger(__name__) + +def _estimate_tokens(text: str) -> int: + """Estimate token count from text length (approx. 
4 chars per token).""" + if not text: + return 0 + return max(1, len(text) // 4) + + MAX_REVIEW_ATTEMPTS = 2 _QUALITATIVE_CAP = 2 _VALID_VERDICTS = {"adequate", "tests_incomplete", "symptom_only"} @@ -115,10 +126,18 @@ async def local_review_changes(state: WorkflowState) -> WorkflowState: logger.info(f"No workspace for local review on {ticket_key}, skipping") return update_state_timestamp({**state, "current_node": "create_pr"}) + settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_REVIEW, model_name=settings.llm_model)} + node_start = time.monotonic() + if ticket_type == TicketType.BUG: - return await _run_bug_review(state) + result_state = await _run_bug_review(state) else: - return await _run_feature_review(state) + result_state = await _run_feature_review(state) + + machine_time = time.monotonic() - node_start + result_state = {**result_state, **record_stage_end(result_state, STAGE_REVIEW, machine_time)} + return result_state async def _run_bug_review(state: WorkflowState) -> WorkflowState: @@ -154,6 +173,27 @@ async def _run_bug_review(state: WorkflowState) -> WorkflowState: repo_name=current_repo, ) + # Record tokens (using actual container metrics if available, else falling back to heuristic) + if ( + result + and isinstance(getattr(result, "input_tokens", None), int) + and result.input_tokens > 0 + ): + input_tokens = result.input_tokens + else: + input_tokens = _estimate_tokens(task_description) + + if ( + result + and isinstance(getattr(result, "output_tokens", None), int) + and result.output_tokens > 0 + ): + output_tokens = result.output_tokens + else: + output_tokens = _estimate_tokens(result.stdout) if result.stdout else 0 + + state = {**state, **record_tokens(state, STAGE_REVIEW, input_tokens, output_tokens)} + git = GitOperations( Workspace( path=Path(workspace_path), @@ -314,6 +354,27 @@ async def _run_feature_review(state: WorkflowState) -> WorkflowState: repo_name=current_repo, ) + # Record tokens (using actual 
container metrics if available, else falling back to heuristic) + if ( + result + and isinstance(getattr(result, "input_tokens", None), int) + and result.input_tokens > 0 + ): + input_tokens = result.input_tokens + else: + input_tokens = _estimate_tokens(task_description) + + if ( + result + and isinstance(getattr(result, "output_tokens", None), int) + and result.output_tokens > 0 + ): + output_tokens = result.output_tokens + else: + output_tokens = _estimate_tokens(result.stdout) if result.stdout else 0 + + state = {**state, **record_tokens(state, STAGE_REVIEW, input_tokens, output_tokens)} + git = GitOperations( Workspace( path=Path(workspace_path), @@ -350,13 +411,15 @@ async def _run_feature_review(state: WorkflowState) -> WorkflowState: f"Could not fix all breaking issues after {MAX_REVIEW_ATTEMPTS} attempts " f"for {ticket_key}, proceeding to PR" ) + next_attempts = review_attempts + 1 else: logger.info(f"Local review passed for {ticket_key}") + next_attempts = 0 return update_state_timestamp( { **state, - "local_review_attempts": 0, + "local_review_attempts": next_attempts, "current_node": "create_pr", "last_error": None, } diff --git a/src/forge/workflow/nodes/plan_bug_fix.py b/src/forge/workflow/nodes/plan_bug_fix.py index e59ad448..7ee7db59 100644 --- a/src/forge/workflow/nodes/plan_bug_fix.py +++ b/src/forge/workflow/nodes/plan_bug_fix.py @@ -5,6 +5,7 @@ import logging import re import tempfile +import time from pathlib import Path from langgraph.graph import END @@ -15,10 +16,25 @@ from forge.prompts import load_prompt from forge.sandbox import ContainerRunner from forge.workflow.bug.state import BugState +from forge.workflow.stats import STAGE_PLANNING +from forge.workflow.stats_utils import ( + increment_revision, + record_stage_end, + record_stage_start, + record_tokens, +) from forge.workflow.utils import set_paused, update_state_timestamp logger = logging.getLogger(__name__) + +def _estimate_tokens(text: str) -> int: + """Estimate token count from 
text length (approx. 4 chars per token).""" + if not text: + return 0 + return max(1, len(text) // 4) + + _MAX_PLAN_RETRIES = 3 _MAX_COMMENT_CHARS = 25_000 _TRUNCATION_NOTE = "*(Plan truncated — full plan available in container logs.)*" @@ -59,6 +75,7 @@ async def regenerate_plan(state: BugState) -> BugState: Returns: Updated state with new plan_content, routed to plan_approval_gate. """ + state = {**state, **increment_revision(state, STAGE_PLANNING)} result = await _run_plan_container(state, "regenerate-plan", retry_node="regenerate_plan") if result["current_node"] == "plan_approval_gate": return { @@ -92,6 +109,11 @@ async def _run_plan_container( original_plan = state.get("plan_content") or "" settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_PLANNING, model_name=settings.llm_model)} + if prompt_name == "regenerate-plan": + state = {**state, **increment_revision(state, STAGE_PLANNING)} + node_start = time.monotonic() + jira = JiraClient() try: @@ -140,6 +162,26 @@ async def _run_plan_container( task_key=f"{ticket_key}-plan", ) + # Record tokens (using actual container metrics if available, else falling back to heuristic) + if ( + result + and isinstance(getattr(result, "input_tokens", None), int) + and result.input_tokens > 0 + ): + input_tokens = result.input_tokens + else: + input_tokens = _estimate_tokens(task_description) + + if ( + result + and isinstance(getattr(result, "output_tokens", None), int) + and result.output_tokens > 0 + ): + output_tokens = result.output_tokens + else: + output_tokens = _estimate_tokens(result.stdout) if (result and result.stdout) else 0 + state = {**state, **record_tokens(state, STAGE_PLANNING, input_tokens, output_tokens)} + if not result.success: raise RuntimeError( f"Container failed with exit_code={result.exit_code}: {result.stderr}" @@ -151,6 +193,9 @@ async def _run_plan_container( await jira.add_comment(ticket_key, comment) await jira.set_workflow_label(ticket_key, ForgeLabel.PLAN_PENDING) + 
machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_PLANNING, machine_time)} + return update_state_timestamp( { **state, @@ -164,6 +209,8 @@ async def _run_plan_container( except Exception as e: logger.error(f"_run_plan_container ({prompt_name}) failed for {ticket_key}: {e}") new_retry = retry_count + 1 + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_PLANNING, machine_time)} return { **state, "last_error": str(e), diff --git a/src/forge/workflow/nodes/pr_creation.py b/src/forge/workflow/nodes/pr_creation.py index 225bed8a..0207b233 100644 --- a/src/forge/workflow/nodes/pr_creation.py +++ b/src/forge/workflow/nodes/pr_creation.py @@ -247,9 +247,14 @@ async def create_pull_request(state: WorkflowState) -> WorkflowState: attempt=0, ) + from forge.workflow.stats_utils import add_pr_url + + stats_updates = add_pr_url(state, pr_url) + return update_state_timestamp( { **state, + **stats_updates, "pr_urls": pr_urls, "current_pr_url": pr_url, "current_pr_number": pr_number, diff --git a/src/forge/workflow/nodes/prd_generation.py b/src/forge/workflow/nodes/prd_generation.py index 61f2a461..b60d2988 100644 --- a/src/forge/workflow/nodes/prd_generation.py +++ b/src/forge/workflow/nodes/prd_generation.py @@ -1,6 +1,7 @@ """PRD generation node for LangGraph workflow.""" import logging +import time from datetime import UTC, datetime from typing import Any @@ -11,12 +12,24 @@ from forge.models.workflow import ForgeLabel from forge.orchestrator.checkpointer import set_pr_ticket_index from forge.workflow.feature.state import FeatureState as WorkflowState +from forge.workflow.stats import STAGE_PRD +from forge.workflow.stats_utils import ( + increment_revision, + record_stage_end, + record_stage_start, + record_tokens, +) from forge.workflow.utils import update_state_timestamp from forge.workflow.utils.jira_status import post_status_comment logger = logging.getLogger(__name__) +def 
_estimate_tokens(text: str) -> int: + """Estimate token count from text length (approx. 4 chars per token).""" + return max(1, len(text) // 4) + + def _normalize_proposals_path(path: str) -> str: """Normalize a proposals base path for GitHub content paths.""" return path.strip("/") @@ -180,6 +193,11 @@ async def generate_prd(state: WorkflowState) -> WorkflowState: ticket_key = state["ticket_key"] logger.info(f"Generating PRD for {ticket_key}") + # Record stage start and begin timing + settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_PRD, model_name=settings.llm_model)} + node_start = time.monotonic() + jira = JiraClient() agent = ForgeAgent() prd_content = None @@ -198,8 +216,11 @@ async def generate_prd(state: WorkflowState) -> WorkflowState: if not raw_requirements.strip(): logger.warning(f"No description found for {ticket_key}") + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_PRD, machine_time) return { **state, + **end_stats, "last_error": "No requirements found in issue description", "current_node": "generate_prd", } @@ -219,6 +240,21 @@ async def generate_prd(state: WorkflowState) -> WorkflowState: # Generate PRD using Claude - primary operation prd_content = await agent.generate_prd(raw_requirements, context) + # Record token usage (using actual agent metadata if available, else falling back to heuristic) + last_in = getattr(agent, "last_input_tokens", 0) + last_out = getattr(agent, "last_output_tokens", 0) + if isinstance(last_in, int) and not isinstance(last_in, bool) and last_in > 0: + input_tokens = last_in + else: + input_tokens = _estimate_tokens(raw_requirements) + + if isinstance(last_out, int) and not isinstance(last_out, bool) and last_out > 0: + output_tokens = last_out + else: + output_tokens = _estimate_tokens(prd_content) + + state = {**state, **record_tokens(state, STAGE_PRD, input_tokens, output_tokens)} + # Publish PRD - either as GitHub PR or Jira update # Per-project 
opt-in: check forge.prd_proposals_repo project property proposals_repo = await _resolve_prd_proposals_repo(issue.project_key, jira) @@ -259,10 +295,15 @@ async def generate_prd(state: WorkflowState) -> WorkflowState: "generated_at": datetime.now(UTC).isoformat(), } + # Record stage end with elapsed wall-clock time + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_PRD, machine_time) + # If publish failed, set a warning but still advance (content exists) result = update_state_timestamp( { **state, + **end_stats, "prd_content": prd_content, "generation_context": generation_context, "current_node": "prd_approval_gate", @@ -279,8 +320,11 @@ async def generate_prd(state: WorkflowState) -> WorkflowState: await notify_error(state, str(e), "generate_prd") # If we have partial content, save it even on failure + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_PRD, machine_time) result_state = { **state, + **end_stats, "last_error": str(e), "current_node": "generate_prd", "retry_count": state.get("retry_count", 0) + 1, @@ -316,6 +360,12 @@ async def regenerate_prd_with_feedback(state: WorkflowState) -> WorkflowState: logger.info(f"Regenerating PRD for {ticket_key} with feedback") + # Record stage re-entry: start timer, increment revision count + settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_PRD, model_name=settings.llm_model)} + state = {**state, **increment_revision(state, STAGE_PRD)} + node_start = time.monotonic() + jira = JiraClient() agent = ForgeAgent() @@ -335,6 +385,11 @@ async def regenerate_prd_with_feedback(state: WorkflowState) -> WorkflowState: }, ) + # Record token usage (estimated from content length) + input_tokens = _estimate_tokens(original_prd) + _estimate_tokens(feedback) + output_tokens = _estimate_tokens(new_prd) + state = {**state, **record_tokens(state, STAGE_PRD, input_tokens, output_tokens)} + # Publish revised PRD if 
state.get("prd_pr_number"): await _update_prd_proposal_pr(ticket_key, new_prd, state) @@ -356,9 +411,14 @@ async def regenerate_prd_with_feedback(state: WorkflowState) -> WorkflowState: logger.info(f"PRD regenerated for {ticket_key} ({len(new_prd)} chars)") + # Record stage end with elapsed wall-clock time + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_PRD, machine_time) + return update_state_timestamp( { **state, + **end_stats, "prd_content": new_prd, "feedback_comment": None, "revision_requested": False, @@ -372,8 +432,11 @@ async def regenerate_prd_with_feedback(state: WorkflowState) -> WorkflowState: from forge.workflow.nodes.error_handler import notify_error await notify_error(state, str(e), "regenerate_prd") + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_PRD, machine_time) return { **state, + **end_stats, "last_error": str(e), "current_node": "regenerate_prd", "retry_count": state.get("retry_count", 0) + 1, diff --git a/src/forge/workflow/nodes/rca_analysis.py b/src/forge/workflow/nodes/rca_analysis.py index a95b6f96..b16c3c90 100644 --- a/src/forge/workflow/nodes/rca_analysis.py +++ b/src/forge/workflow/nodes/rca_analysis.py @@ -3,6 +3,7 @@ import json import logging import tempfile +import time from pathlib import Path from forge.config import get_settings @@ -11,11 +12,21 @@ from forge.prompts import load_prompt from forge.sandbox import ContainerRunner from forge.workflow.bug.state import BugState +from forge.workflow.stats import STAGE_RCA +from forge.workflow.stats_utils import record_stage_end, record_stage_start, record_tokens from forge.workflow.utils import update_state_timestamp from forge.workflow.utils.jira_status import post_status_comment logger = logging.getLogger(__name__) + +def _estimate_tokens(text: str) -> int: + """Estimate token count from text length (approx. 
4 chars per token).""" + if not text: + return 0 + return max(1, len(text) // 4) + + _RCA_REQUIRED_KEYS = { "summary", "code_location", @@ -49,6 +60,9 @@ async def analyze_bug(state: BugState) -> BugState: reflection_critique = state.get("reflection_critique") or "" settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_RCA, model_name=settings.llm_model)} + node_start = time.monotonic() + jira = JiraClient() try: @@ -72,6 +86,8 @@ async def analyze_bug(state: BugState) -> BugState: f"Details: {e}", ) await jira.set_workflow_label(ticket_key, ForgeLabel.BLOCKED) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_RCA, machine_time)} return { **state, "last_error": str(e), @@ -98,6 +114,26 @@ async def analyze_bug(state: BugState) -> BugState: task_key=f"{ticket_key}-analysis", ) + # Record tokens (using actual container metrics if available, else falling back to heuristic) + if ( + result + and isinstance(getattr(result, "input_tokens", None), int) + and result.input_tokens > 0 + ): + input_tokens = result.input_tokens + else: + input_tokens = _estimate_tokens(task_description) + + if ( + result + and isinstance(getattr(result, "output_tokens", None), int) + and result.output_tokens > 0 + ): + output_tokens = result.output_tokens + else: + output_tokens = _estimate_tokens(result.stdout) if (result and result.stdout) else 0 + state = {**state, **record_tokens(state, STAGE_RCA, input_tokens, output_tokens)} + if not result.success: raise RuntimeError( f"Container failed with exit_code={result.exit_code}: {result.stderr}" @@ -105,6 +141,9 @@ async def analyze_bug(state: BugState) -> BugState: data = _harvest_rca_json(workspace_path) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_RCA, machine_time)} + return update_state_timestamp( { **state, @@ -120,6 +159,8 @@ async def analyze_bug(state: BugState) -> BugState: logger.error(f"analyze_bug failed 
for {ticket_key}: {e}") new_retry = retry_count + 1 next_node = "escalate_blocked" if new_retry >= MAX_ANALYSIS_RETRIES else "analyze_bug" + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_RCA, machine_time)} return { **state, "last_error": str(e), @@ -223,6 +264,9 @@ async def reflect_rca(state: BugState) -> BugState: reflect_rca_retry_count = state.get("reflect_rca_retry_count", 0) settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_RCA, model_name=settings.llm_model)} + node_start = time.monotonic() + jira = JiraClient() try: @@ -244,6 +288,26 @@ async def reflect_rca(state: BugState) -> BugState: task_key=task_key, ) + # Record tokens (using actual container metrics if available, else falling back to heuristic) + if ( + result + and isinstance(getattr(result, "input_tokens", None), int) + and result.input_tokens > 0 + ): + input_tokens = result.input_tokens + else: + input_tokens = _estimate_tokens(task_description) + + if ( + result + and isinstance(getattr(result, "output_tokens", None), int) + and result.output_tokens > 0 + ): + output_tokens = result.output_tokens + else: + output_tokens = _estimate_tokens(result.stdout) if (result and result.stdout) else 0 + state = {**state, **record_tokens(state, STAGE_RCA, input_tokens, output_tokens)} + if not result.success: raise RuntimeError( f"Reflection container failed with exit_code={result.exit_code}: {result.stderr}" @@ -252,6 +316,8 @@ async def reflect_rca(state: BugState) -> BugState: verdict = _extract_reflection_verdict(workspace_path, task_key, result.stdout) if verdict.upper().strip() == "VALID": + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_RCA, machine_time)} return update_state_timestamp( { **state, @@ -268,6 +334,8 @@ async def reflect_rca(state: BugState) -> BugState: f"Reflection cap reached — proceeding with best available RCA after " f"{new_reflection_count} 
validation attempts.", ) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_RCA, machine_time)} return update_state_timestamp( { **state, @@ -277,6 +345,8 @@ async def reflect_rca(state: BugState) -> BugState: } ) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_RCA, machine_time)} return update_state_timestamp( { **state, @@ -292,6 +362,8 @@ async def reflect_rca(state: BugState) -> BugState: next_node = ( "escalate_blocked" if new_reflect_retry >= MAX_ANALYSIS_RETRIES else "reflect_rca" ) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_RCA, machine_time)} return { **state, "last_error": str(e), diff --git a/src/forge/workflow/nodes/rca_option_gate.py b/src/forge/workflow/nodes/rca_option_gate.py index a1e766ac..e36dcbf3 100644 --- a/src/forge/workflow/nodes/rca_option_gate.py +++ b/src/forge/workflow/nodes/rca_option_gate.py @@ -181,9 +181,15 @@ async def regenerate_rca(state: BugState) -> BugState: finally: await jira.close() + from forge.workflow.stats import STAGE_RCA + from forge.workflow.stats_utils import increment_revision + + stats_updates = increment_revision(state, STAGE_RCA) + return update_state_timestamp( { **state, + **stats_updates, "reflection_critique": feedback or None, "feedback_comment": None, "revision_requested": False, diff --git a/src/forge/workflow/nodes/spec_generation.py b/src/forge/workflow/nodes/spec_generation.py index 396fe78f..b646f3a4 100644 --- a/src/forge/workflow/nodes/spec_generation.py +++ b/src/forge/workflow/nodes/spec_generation.py @@ -1,6 +1,7 @@ """Specification generation node for LangGraph workflow.""" import logging +import time from datetime import UTC, datetime from typing import Any @@ -16,6 +17,13 @@ _resolve_prd_proposals_repo, _resolve_proposals_path, ) +from forge.workflow.stats import STAGE_SPEC +from forge.workflow.stats_utils import ( + increment_revision, + 
record_stage_end, + record_stage_start, + record_tokens, +) from forge.workflow.utils import update_state_timestamp from forge.workflow.utils.jira_status import post_status_comment from forge.workflow.utils.qa_summary import post_qa_summary_if_needed @@ -23,6 +31,11 @@ logger = logging.getLogger(__name__) +def _estimate_tokens(text: str) -> int: + """Estimate token count from text length (approx. 4 chars per token).""" + return max(1, len(text) // 4) + + async def _create_spec_proposal_pr( ticket_key: str, spec_content: str, @@ -142,6 +155,11 @@ async def generate_spec(state: WorkflowState) -> WorkflowState: logger.info(f"Generating specification for {ticket_key}") + # Record stage start and begin timing + settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_SPEC, model_name=settings.llm_model)} + node_start = time.monotonic() + # Post Q&A summary for PRD if any qa_history = state.get("qa_history", []) if qa_history: @@ -168,8 +186,11 @@ async def generate_spec(state: WorkflowState) -> WorkflowState: if not prd_content.strip(): logger.warning(f"No PRD content found for {ticket_key}") + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_SPEC, machine_time) return { **state, + **end_stats, "last_error": "No PRD content available for spec generation", "current_node": "generate_spec", } @@ -187,6 +208,21 @@ async def generate_spec(state: WorkflowState) -> WorkflowState: # Generate specification using Claude - primary operation spec_content = await agent.generate_spec(prd_content, context) + # Record token usage (using actual agent metadata if available, else falling back to heuristic) + last_in = getattr(agent, "last_input_tokens", 0) + last_out = getattr(agent, "last_output_tokens", 0) + if isinstance(last_in, int) and not isinstance(last_in, bool) and last_in > 0: + input_tokens = last_in + else: + input_tokens = _estimate_tokens(prd_content) + + if isinstance(last_out, int) and not isinstance(last_out, 
bool) and last_out > 0: + output_tokens = last_out + else: + output_tokens = _estimate_tokens(spec_content) + + state = {**state, **record_tokens(state, STAGE_SPEC, input_tokens, output_tokens)} + # Publish spec — either as GitHub PR or Jira update proposals_repo = await _resolve_prd_proposals_repo(issue.project_key, jira) spec_pr_result = None @@ -236,9 +272,14 @@ async def generate_spec(state: WorkflowState) -> WorkflowState: "generated_at": datetime.now(UTC).isoformat(), } + # Record stage end with elapsed wall-clock time + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_SPEC, machine_time) + result = update_state_timestamp( { **state, + **end_stats, "spec_content": spec_content, "generation_context": generation_context, "current_node": "spec_approval_gate", @@ -255,8 +296,11 @@ async def generate_spec(state: WorkflowState) -> WorkflowState: await notify_error(state, str(e), "generate_spec") # If we have partial content, save it even on failure + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_SPEC, machine_time) result_state = { **state, + **end_stats, "last_error": str(e), "current_node": "generate_spec", "retry_count": state.get("retry_count", 0) + 1, @@ -288,6 +332,12 @@ async def regenerate_spec_with_feedback(state: WorkflowState) -> WorkflowState: logger.info(f"Regenerating spec for {ticket_key} with feedback") + # Record stage re-entry: start timer, increment revision count + settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_SPEC, model_name=settings.llm_model)} + state = {**state, **increment_revision(state, STAGE_SPEC)} + node_start = time.monotonic() + jira = JiraClient() agent = ForgeAgent() @@ -307,6 +357,11 @@ async def regenerate_spec_with_feedback(state: WorkflowState) -> WorkflowState: }, ) + # Record token usage (estimated from content length) + input_tokens = _estimate_tokens(original_spec) + _estimate_tokens(feedback) + output_tokens 
= _estimate_tokens(new_spec) + state = {**state, **record_tokens(state, STAGE_SPEC, input_tokens, output_tokens)} + # Publish revised spec if state.get("spec_pr_number"): await _update_spec_proposal_pr(ticket_key, new_spec, state) @@ -343,9 +398,14 @@ async def regenerate_spec_with_feedback(state: WorkflowState) -> WorkflowState: logger.info(f"Spec regenerated for {ticket_key} ({len(new_spec)} chars)") + # Record stage end with elapsed wall-clock time + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_SPEC, machine_time) + return update_state_timestamp( { **state, + **end_stats, "spec_content": new_spec, "feedback_comment": None, "revision_requested": False, @@ -359,8 +419,11 @@ async def regenerate_spec_with_feedback(state: WorkflowState) -> WorkflowState: from forge.workflow.nodes.error_handler import notify_error await notify_error(state, str(e), "regenerate_spec") + machine_time = time.monotonic() - node_start + end_stats = record_stage_end(state, STAGE_SPEC, machine_time) return { **state, + **end_stats, "last_error": str(e), "current_node": "regenerate_spec", "retry_count": state.get("retry_count", 0) + 1, diff --git a/src/forge/workflow/nodes/stats_posting.py b/src/forge/workflow/nodes/stats_posting.py new file mode 100644 index 00000000..8c311d81 --- /dev/null +++ b/src/forge/workflow/nodes/stats_posting.py @@ -0,0 +1,175 @@ +"""Terminal stats posting node for workflow completion. + +Posts a formatted stats summary comment to Jira whenever a workflow reaches a +terminal state (Completed, Blocked, or Failed). This is a *side-effect* node — +it always returns the state unchanged and never fails the workflow, regardless +of whether the Jira posting succeeds. 
+""" + +import logging +from typing import Any + +from forge.workflow.bug.state import BugState +from forge.workflow.feature.state import FeatureState +from forge.workflow.stats.poster import ensure_stats_is_final_comment, post_stats_comment + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Outcome helpers +# --------------------------------------------------------------------------- + + +def _determine_outcome(state: FeatureState | BugState) -> str: + """Return the outcome category string for the terminal state. + + Precedence: + 1. If ``workflow_outcome`` is already set in state, return it directly. + 2. If ``is_blocked`` is True, return ``"Blocked"``. + 3. If ``last_error`` is set, return ``"Failed"``. + 4. Otherwise, return ``"Completed"``. + + Args: + state: Current feature or bug workflow state. + + Returns: + One of ``"Completed"``, ``"Blocked"``, or ``"Failed"``. + """ + # If the workflow has already classified its own outcome, honour that. + existing = state.get("workflow_outcome") + if existing: + return existing + + if state.get("is_blocked"): + return "Blocked" + + if state.get("last_error"): + return "Failed" + + return "Completed" + + +def _extract_outcome_detail( + state: FeatureState | BugState, + outcome: str, +) -> str | None: + """Extract a human-readable detail string for the given outcome. + + For ``"Failed"`` outcomes the ``last_error`` field is used. + For ``"Blocked"`` outcomes the ``stats_outcome_reason`` field is used + (which is expected to contain the block reason set by the blocking node). + ``"Completed"`` outcomes have no detail. + + If ``stats_outcome_reason`` is already set in state it takes precedence + over the derived values for all outcome types. + + Args: + state: Current feature or bug workflow state. + outcome: The outcome category string (e.g. ``"Blocked"``). + + Returns: + A detail string, or ``None`` if no detail is available. 
+ """ + # A reason already recorded in state always takes precedence. + existing_reason = state.get("stats_outcome_reason") + if existing_reason: + return existing_reason + + normalised = outcome.lower() + if normalised == "failed": + return state.get("last_error") + + if normalised == "blocked": + # Block reason may also be in feedback_comment from a blocking gate. + return state.get("feedback_comment") + + return None + + +# --------------------------------------------------------------------------- +# Node function +# --------------------------------------------------------------------------- + + +async def post_terminal_stats(state: FeatureState | BugState) -> dict[str, Any]: + """Post a workflow stats summary comment when a terminal state is reached. + + Determines the outcome type (Completed / Blocked / Failed) from the current + state, extracts any relevant detail (error message or block reason), then: + + 1. Calls :func:`~forge.workflow.stats.poster.post_stats_comment` to post + the formatted summary comment to the Jira ticket. + 2. Calls :func:`~forge.workflow.stats.poster.ensure_stats_is_final_comment` + to guarantee the stats comment is the last Forge comment on the ticket + (re-posting if necessary). + + This node is *non-blocking on failure*: any exception raised by the posting + service is caught and logged, and the original state is returned unchanged + so that the workflow can continue to its true terminal node. + + Handles both :class:`~forge.workflow.feature.state.FeatureState` and + :class:`~forge.workflow.bug.state.BugState` workflows transparently. + + Args: + state: Current feature or bug workflow state at a terminal node. + + Returns: + An empty dict (state is returned unchanged — this is a side-effect node). 
+ """ + ticket_key: str = state.get("ticket_key", "") + if not ticket_key: + logger.warning("post_terminal_stats: no ticket_key in state — skipping stats post") + return {} + + outcome = _determine_outcome(state) + outcome_detail = _extract_outcome_detail(state, outcome) + + logger.info( + "post_terminal_stats: posting stats for ticket=%s outcome=%s", + ticket_key, + outcome, + ) + + try: + posted = await post_stats_comment( + ticket_key=ticket_key, + stats=state, + outcome=outcome, + outcome_detail=outcome_detail, + ) + if posted: + logger.info("post_terminal_stats: stats comment posted for ticket=%s", ticket_key) + else: + logger.warning( + "post_terminal_stats: post_stats_comment returned False for ticket=%s", + ticket_key, + ) + except Exception: + # post_stats_comment is itself non-blocking, but guard defensively. + logger.exception( + "post_terminal_stats: unexpected error calling post_stats_comment for ticket=%s", + ticket_key, + ) + + try: + await ensure_stats_is_final_comment( + ticket_key=ticket_key, + stats=state, + outcome=outcome, + outcome_detail=outcome_detail, + ) + logger.info( + "post_terminal_stats: ensure_stats_is_final_comment completed for ticket=%s", + ticket_key, + ) + except Exception: + # Non-blocking — log and continue. + logger.exception( + "post_terminal_stats: unexpected error calling ensure_stats_is_final_comment " + "for ticket=%s", + ticket_key, + ) + + # Return empty dict — state is unchanged (LangGraph merges this with no-op). 
+ return {} diff --git a/src/forge/workflow/nodes/task_generation.py b/src/forge/workflow/nodes/task_generation.py index 938a3013..41ef16fa 100644 --- a/src/forge/workflow/nodes/task_generation.py +++ b/src/forge/workflow/nodes/task_generation.py @@ -3,6 +3,7 @@ import asyncio import logging import re +import time from typing import Any from forge.config import get_settings @@ -11,12 +12,24 @@ from forge.models.workflow import ForgeLabel from forge.prompts import load_prompt from forge.workflow.feature.state import FeatureState as WorkflowState +from forge.workflow.stats import STAGE_TASKS +from forge.workflow.stats_utils import ( + increment_revision, + record_stage_end, + record_stage_start, + record_tokens, +) from forge.workflow.utils import update_state_timestamp from forge.workflow.utils.jira_status import post_status_comment logger = logging.getLogger(__name__) +def _estimate_tokens(text: str) -> int: + """Estimate token count from text length (approx. 4 chars per token).""" + return max(1, len(text) // 4) + + async def generate_tasks(state: WorkflowState) -> WorkflowState: """Generate implementation Tasks for each approved Epic. 
@@ -35,8 +48,14 @@ async def generate_tasks(state: WorkflowState) -> WorkflowState: ticket_key = state["ticket_key"] epic_keys = state.get("epic_keys", []) + settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_TASKS, model_name=settings.llm_model)} + node_start = time.monotonic() + if not epic_keys: logger.warning(f"No Epics found for task generation on {ticket_key}") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TASKS, machine_time)} return { **state, "last_error": "No Epics available for task generation", @@ -125,7 +144,7 @@ async def generate_tasks(state: WorkflowState) -> WorkflowState: sibling_epics = [e for e in all_epics_details if e["epic_key"] != epic_key] # Generate Tasks using Deep Agents - primary operation - tasks_data = await _generate_tasks_for_epic( + tasks_resp = await _generate_tasks_for_epic( agent, epic_plan, epic_summary, @@ -134,6 +153,11 @@ async def generate_tasks(state: WorkflowState) -> WorkflowState: sibling_epics=sibling_epics if sibling_epics else None, existing_tasks=created_tasks_context if created_tasks_context else None, ) + if isinstance(tasks_resp, tuple): + tasks_data, in_tok, out_tok = tasks_resp + else: + tasks_data, in_tok, out_tok = tasks_resp, 0, 0 + state = {**state, **record_tokens(state, STAGE_TASKS, in_tok, out_tok)} # Create Tasks in Jira - secondary operation for task in tasks_data: @@ -214,6 +238,8 @@ async def generate_tasks(state: WorkflowState) -> WorkflowState: except Exception as e: jira_error = str(e) logger.warning(f"Failed to set workflow label for {ticket_key}: {e}") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TASKS, machine_time)} return update_state_timestamp( { **state, @@ -229,6 +255,8 @@ async def generate_tasks(state: WorkflowState) -> WorkflowState: ) else: # No Tasks created at all - this is a failure + machine_time = time.monotonic() - node_start + state = {**state, 
**record_stage_end(state, STAGE_TASKS, machine_time)} return { **state, "last_error": jira_error or "Failed to create any Tasks in Jira", @@ -242,6 +270,8 @@ async def generate_tasks(state: WorkflowState) -> WorkflowState: await notify_error(state, str(e), "generate_tasks") # Save any Tasks we managed to create + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TASKS, machine_time)} result_state = { **state, "last_error": str(e), @@ -264,7 +294,7 @@ async def _generate_tasks_for_epic( spec_content: str = "", sibling_epics: list[dict[str, str]] | None = None, existing_tasks: list[dict[str, str]] | None = None, -) -> list[dict[str, str]]: +) -> tuple[list[dict[str, str]], int, int]: """Generate Tasks for a single Epic. Args: @@ -277,7 +307,7 @@ async def _generate_tasks_for_epic( existing_tasks: Tasks already created for sibling epics (to avoid duplication). Returns: - List of Task dicts with summary, description, repo. + A tuple of (List of Task dicts, input_tokens, output_tokens). 
""" existing_tasks_section = _format_existing_tasks(existing_tasks) sibling_epics_section = _format_sibling_epics(sibling_epics) @@ -305,7 +335,20 @@ async def _generate_tasks_for_epic( context=context, ) - return _parse_tasks_response(result) + # Record tokens (using actual agent metadata if available, else falling back to heuristic) + last_in = getattr(agent, "last_input_tokens", 0) + last_out = getattr(agent, "last_output_tokens", 0) + if isinstance(last_in, int) and not isinstance(last_in, bool) and last_in > 0: + input_tokens = last_in + else: + input_tokens = _estimate_tokens(prompt) + + if isinstance(last_out, int) and not isinstance(last_out, bool) and last_out > 0: + output_tokens = last_out + else: + output_tokens = _estimate_tokens(result) if result else 0 + + return _parse_tasks_response(result), input_tokens, output_tokens def _format_sibling_epics(sibling_epics: list[dict[str, str]] | None) -> str: @@ -481,6 +524,7 @@ async def regenerate_all_tasks(state: WorkflowState) -> WorkflowState: # Clear task_keys and set feedback for regeneration updated_state = { **state, + **increment_revision(state, STAGE_TASKS), "task_keys": [], "tasks_by_repo": {}, "feedback_comment": feedback, @@ -535,6 +579,10 @@ async def regenerate_epic_tasks(state: WorkflowState) -> WorkflowState: logger.info(f"Regenerating tasks for Epic {epic_key} on {ticket_key} with feedback") settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_TASKS, model_name=settings.llm_model)} + state = {**state, **increment_revision(state, STAGE_TASKS)} + node_start = time.monotonic() + jira = JiraClient() agent = ForgeAgent() @@ -628,7 +676,7 @@ async def _fetch_sibling(ek: str) -> dict[str, str] | None: spec_content = state.get("spec_content", "") - tasks_data = await _generate_tasks_for_epic( + tasks_resp = await _generate_tasks_for_epic( agent, epic_plan, epic_summary, @@ -637,8 +685,15 @@ async def _fetch_sibling(ek: str) -> dict[str, str] | None: 
sibling_epics=sibling_epics if sibling_epics else None, existing_tasks=existing_tasks_ctx if existing_tasks_ctx else None, ) + if isinstance(tasks_resp, tuple): + tasks_data, in_tok, out_tok = tasks_resp + else: + tasks_data, in_tok, out_tok = tasks_resp, 0, 0 + state = {**state, **record_tokens(state, STAGE_TASKS, in_tok, out_tok)} if not tasks_data: + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TASKS, machine_time)} return { **state, "last_error": f"No replacement Tasks generated for Epic {epic_key}", @@ -695,6 +750,8 @@ async def _fetch_sibling(ek: str) -> dict[str, str] | None: logger.warning(f"Failed to create Task '{summary}' for {epic_key}: {e}") if not new_task_keys: + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TASKS, machine_time)} return { **state, "last_error": jira_error @@ -721,6 +778,8 @@ async def _fetch_sibling(ek: str) -> dict[str, str] | None: cleanup_suffix = ( f"; cleanup failures: {'; '.join(cleanup_errors)}" if cleanup_errors else "" ) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TASKS, machine_time)} return { **state, "last_error": ( @@ -746,6 +805,8 @@ async def _fetch_sibling(ek: str) -> dict[str, str] | None: all_task_keys = remaining_task_keys + new_task_keys logger.info(f"Regenerated {len(new_task_keys)} tasks for Epic {epic_key} on {ticket_key}") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TASKS, machine_time)} return update_state_timestamp( { **state, @@ -764,6 +825,8 @@ async def _fetch_sibling(ek: str) -> dict[str, str] | None: from forge.workflow.nodes.error_handler import notify_error await notify_error(state, str(e), "regenerate_epic_tasks") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TASKS, machine_time)} return { **state, "last_error": str(e), @@ -800,6 +863,11 
@@ async def update_single_task(state: WorkflowState) -> WorkflowState: logger.info(f"Updating Task {task_key} with feedback") + settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_TASKS, model_name=settings.llm_model)} + state = {**state, **increment_revision(state, STAGE_TASKS)} + node_start = time.monotonic() + jira = JiraClient() agent = ForgeAgent() @@ -823,6 +891,11 @@ async def update_single_task(state: WorkflowState) -> WorkflowState: }, ) + # Record tokens + input_tokens = _estimate_tokens(original_description) + _estimate_tokens(feedback) + output_tokens = _estimate_tokens(new_description) + state = {**state, **record_tokens(state, STAGE_TASKS, input_tokens, output_tokens)} + # Update Task in Jira await jira.update_description(task_key, new_description) @@ -834,6 +907,9 @@ async def update_single_task(state: WorkflowState) -> WorkflowState: logger.info(f"Task {task_key} updated with feedback") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TASKS, machine_time)} + return update_state_timestamp( { **state, @@ -850,6 +926,8 @@ async def update_single_task(state: WorkflowState) -> WorkflowState: from forge.workflow.nodes.error_handler import notify_error await notify_error(state, str(e), "update_single_task") + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TASKS, machine_time)} return { **state, "last_error": str(e), diff --git a/src/forge/workflow/nodes/triage.py b/src/forge/workflow/nodes/triage.py index 85ae5299..fef63a70 100644 --- a/src/forge/workflow/nodes/triage.py +++ b/src/forge/workflow/nodes/triage.py @@ -6,6 +6,7 @@ import json import logging +import time from langgraph.graph import END @@ -15,10 +16,20 @@ from forge.models.workflow import ForgeLabel from forge.prompts import load_prompt from forge.workflow.bug.state import BugState +from forge.workflow.stats import STAGE_TRIAGE +from forge.workflow.stats_utils import 
record_stage_end, record_stage_start, record_tokens from forge.workflow.utils import set_paused, update_state_timestamp logger = logging.getLogger(__name__) + +def _estimate_tokens(text: str) -> int: + """Estimate token count from text length (approx. 4 chars per token).""" + if not text: + return 0 + return max(1, len(text) // 4) + + _MAX_RETRIES = 3 __all__ = ["triage_check", "triage_gate", "route_triage_gate"] @@ -46,12 +57,17 @@ async def triage_check(state: BugState) -> BugState: is_resume = state.get("current_node") == "triage_gate" settings = get_settings() + state = {**state, **record_stage_start(state, STAGE_TRIAGE, model_name=settings.llm_model)} + node_start = time.monotonic() + jira = JiraClient(settings) agent = ForgeAgent(settings) try: if retry_count >= _MAX_RETRIES: logger.error("triage_check exceeded max retries for %s", ticket_key) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TRIAGE, machine_time)} return {**state, "current_node": "escalate_blocked"} # Step 1: Post acknowledgement on first invocation only (not on resume) @@ -79,6 +95,21 @@ async def triage_check(state: BugState) -> BugState: context={"ticket_key": ticket_key}, ) + # Record tokens (using actual agent metadata if available, else falling back to heuristic) + last_in = getattr(agent, "last_input_tokens", 0) + last_out = getattr(agent, "last_output_tokens", 0) + if isinstance(last_in, int) and not isinstance(last_in, bool) and last_in > 0: + input_tokens = last_in + else: + input_tokens = _estimate_tokens(user_prompt) + + if isinstance(last_out, int) and not isinstance(last_out, bool) and last_out > 0: + output_tokens = last_out + else: + output_tokens = _estimate_tokens(raw_result) + + state = {**state, **record_tokens(state, STAGE_TRIAGE, input_tokens, output_tokens)} + # Step 4: Parse result result_stripped = raw_result.strip() if result_stripped.lower() == "sufficient": @@ -89,6 +120,8 @@ async def triage_check(state: BugState) 
-> BugState: else "Ticket has enough information to proceed. Starting root cause analysis — results will be posted here." ) await jira.add_comment(ticket_key, pass_msg) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TRIAGE, machine_time)} return update_state_timestamp( { **state, @@ -123,6 +156,9 @@ async def triage_check(state: BugState) -> BugState: ) await jira.set_workflow_label(ticket_key, ForgeLabel.TRIAGE_PENDING) + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TRIAGE, machine_time)} + return update_state_timestamp( { **state, @@ -137,6 +173,8 @@ async def triage_check(state: BugState) -> BugState: except Exception as e: logger.error("triage_check failed for %s: %s", ticket_key, e) new_retry = retry_count + 1 + machine_time = time.monotonic() - node_start + state = {**state, **record_stage_end(state, STAGE_TRIAGE, machine_time)} return { **state, "last_error": str(e), diff --git a/src/forge/workflow/stats/__init__.py b/src/forge/workflow/stats/__init__.py new file mode 100644 index 00000000..cec99439 --- /dev/null +++ b/src/forge/workflow/stats/__init__.py @@ -0,0 +1,143 @@ +"""Statistics tracking data structures for workflow execution. + +This module defines the TypedDicts used to capture per-stage metrics and +overall workflow outcome data, as required by SC-001. It also exports +canonical stage-name constants used by recording and formatting code to +ensure consistency across the codebase. +""" + +from typing import TypedDict + +# --------------------------------------------------------------------------- +# Workflow stage constants +# --------------------------------------------------------------------------- +# These string constants are the canonical identifiers for each named stage +# that is tracked in workflow statistics. Use these constants everywhere +# instead of bare strings so that typos are caught at import time. 
+ +# Feature workflow stages +STAGE_PRD = "prd" +STAGE_SPEC = "spec" +STAGE_EPICS = "epics" +STAGE_TASKS = "tasks" +STAGE_IMPLEMENTATION = "implementation" +STAGE_CI = "ci" +STAGE_REVIEW = "review" + +# Bug workflow stages +STAGE_TRIAGE = "triage" +STAGE_RCA = "rca" +STAGE_PLANNING = "planning" + +# Ordered stage lists used by formatting code to display stages in the +# canonical sequence defined by the specification. + +#: Stages for the Feature workflow, in display order. +ALL_FEATURE_STAGES: list[str] = [ + STAGE_PRD, + STAGE_SPEC, + STAGE_EPICS, + STAGE_TASKS, + STAGE_IMPLEMENTATION, + STAGE_CI, + STAGE_REVIEW, +] + +#: Stages for the Bug workflow, in display order. +ALL_BUG_STAGES: list[str] = [ + STAGE_TRIAGE, + STAGE_RCA, + STAGE_PLANNING, + STAGE_IMPLEMENTATION, + STAGE_CI, + STAGE_REVIEW, +] + + +class StageStats(TypedDict, total=False): + """Per-stage execution metrics captured during workflow execution. + + Each stage in a workflow gets one StageStats entry, keyed by stage name + in the StatsState.stage_timestamps mapping. Fields are updated incrementally + as the stage progresses and finalised when the stage ends. + + Fields: + stage_name: Canonical name of the workflow stage (e.g. "implement"). + iteration_count: Number of times this stage has been (re-)entered, + including retries and revision loops. + machine_time_seconds: Wall-clock seconds spent executing automated work + (LLM calls, tool calls, CI waiting, etc.) — i.e. time the system + was actively doing something. + human_time_seconds: Wall-clock seconds the workflow was paused waiting + for human input (approval gates, revision requests, Q&A). + input_tokens: Cumulative LLM prompt tokens consumed by this stage. + output_tokens: Cumulative LLM completion tokens produced by this stage. + started_at: ISO-8601 timestamp when the stage first started, or None + if the stage has not yet been entered. 
+ ended_at: ISO-8601 timestamp when the stage finished (either completed + or abandoned), or None if it is still in progress. + model_name: Name of the LLM model actually used during this stage's + execution (e.g. "claude-sonnet-4-5@20250929"), or None when the + stage does not invoke an LLM (e.g. CI, review) or the model was + not recorded. + """ + + stage_name: str + iteration_count: int + machine_time_seconds: float + input_tokens: int + output_tokens: int + started_at: str | None + ended_at: str | None + model_name: str | None + + +class StatsState(TypedDict, total=False): + """Mixin TypedDict for workflow-level statistics tracking. + + Intended to be composed into workflow state classes alongside BaseState + and other integration mixins. All fields are optional (total=False) so + that existing workflows can adopt the mixin incrementally without + providing values upfront. + + Outcome values follow the convention: + "Completed" — workflow finished successfully. + "Blocked: " — workflow is waiting on an external blocker. + "Failed: " — workflow terminated due to an unrecoverable error. + + Fields: + stage_timestamps: Mapping from stage name to its StageStats snapshot. + Updated in-place as each stage starts and ends. + revision_counts: Mapping from stage name to the number of revision/retry + cycles that stage has undergone. Mirrors the ``iteration_count`` + value from each ``StageStats`` entry but exposed as a flat top-level + field for easy access by formatters and reporting code. + token_usage: Workflow-wide aggregate token counts with keys + ``"input_tokens"`` and ``"output_tokens"``. + stage_token_usage: Per-stage token breakdown keyed by stage name; each + value is a dict with ``"input_tokens"`` and ``"output_tokens"`` keys. + stats_pr_urls: URLs of all pull requests opened during this workflow + run (across all repositories). + stats_ci_cycles: Number of CI fix-attempt cycles that were triggered + during the implementation phase. 
def calculate_stage_cost(
    model_name: str | None,
    input_tokens: int,
    output_tokens: int,
    pricing: dict[str, dict[str, float]],
) -> tuple[float | None, float | None]:
    """Price one stage's token usage against a pricing table.

    A pricing key applies when it occurs, case-insensitively, anywhere inside
    *model_name*. If several keys apply, the longest wins; among equally long
    keys the one that comes first in *pricing* wins. Rates are dollars per
    million tokens, and a rate missing from the chosen entry counts as ``0.0``.

    Args:
        model_name: Model recorded for the stage, or ``None`` if there is none.
        input_tokens: Prompt tokens consumed by the stage.
        output_tokens: Completion tokens produced by the stage.
        pricing: Maps model-name substrings to
            ``{"input": <$/MTok>, "output": <$/MTok>}`` entries.

    Returns:
        ``(input_cost, output_cost)`` in dollars, or ``(None, None)`` when
        *model_name* is ``None`` or no pricing key matches it.
    """
    if model_name is None:
        return (None, None)

    haystack = model_name.lower()
    candidates = [key for key in pricing if key.lower() in haystack]
    if not candidates:
        return (None, None)

    # max() returns the first of several equally long keys, so table order
    # breaks ties.
    rates = pricing[max(candidates, key=len)]
    per_mtok_in: float = rates.get("input", 0.0)
    per_mtok_out: float = rates.get("output", 0.0)

    return (
        input_tokens / 1_000_000 * per_mtok_in,
        output_tokens / 1_000_000 * per_mtok_out,
    )
+_STAGE_LABELS: dict[str, str] = { + "prd": "PRD", + "spec": "Spec", + "epics": "Epics", + "tasks": "Tasks", + "implementation": "Implementation", + "ci": "CI", + "review": "Review", + # Bug workflow stages (if needed in future extensions) + "triage": "Triage", + "rca": "RCA", + "planning": "Planning", +} + +#: Em-dash used when a stage was never executed. +_DASH = "\u2014" + +#: Stage keys that only appear in Bug workflows. +_BUG_ONLY_STAGES = frozenset({"triage", "rca", "planning"}) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _truncate(text: str, max_len: int = _MAX_DETAIL_LEN) -> str: + """Return *text* truncated to *max_len* characters with '...' suffix. + + If *text* is already within the limit it is returned unchanged. + """ + if len(text) <= max_len: + return text + return text[:max_len] + "..." + + +def _fmt_seconds(seconds: float) -> str: + """Format a duration in seconds to a human-readable string (e.g. '1h 23m 45s').""" + total = int(seconds) + hours, remainder = divmod(total, 3600) + minutes, secs = divmod(remainder, 60) + if hours: + return f"{hours}h {minutes}m {secs}s" + if minutes: + return f"{minutes}m {secs}s" + return f"{secs}s" + + +def _fmt_tokens(count: int) -> str: + """Format a token count with thousands separators.""" + return f"{count:,}" + + +def _fmt_cost(cost: float) -> str: + """Format a dollar cost value for display (e.g. '$1.23').""" + return f"${cost:.2f}" + + +def _build_stage_row( + label: str, + stage: StageStats | None, + pricing: dict[str, dict[str, float]] | None = None, +) -> str: + """Return a single Jira table row for a workflow stage. + + If *stage* is None (never executed), all metric columns show '—'. + + Args: + label: Human-readable stage label for the first column. + stage: Stage metrics dict, or ``None`` when the stage was not executed. 
+ pricing: Optional LLM pricing table passed to :func:`calculate_stage_cost`. + When ``None``, the cost column shows ``cost unavailable``. + """ + if stage is None: + return f"| {label} | {_DASH} | {_DASH} | {_DASH} | {_DASH} | {_DASH} |" + + iterations = stage.get("iteration_count", 0) + machine_time = _fmt_seconds(stage.get("machine_time_seconds", 0.0)) + input_tok = _fmt_tokens(stage.get("input_tokens", 0)) + output_tok = _fmt_tokens(stage.get("output_tokens", 0)) + + if pricing is not None: + model_name = stage.get("model_name") + input_cost, output_cost = calculate_stage_cost( + model_name, + stage.get("input_tokens", 0), + stage.get("output_tokens", 0), + pricing, + ) + if input_cost is not None and output_cost is not None: + cost_str = _fmt_cost(input_cost + output_cost) + else: + cost_str = "cost unavailable" + else: + cost_str = "cost unavailable" + + return f"| {label} | {iterations} | {machine_time} | {input_tok} | {output_tok} | {cost_str} |" + + +def _build_totals_row( + stages: dict[str, StageStats], + pricing: dict[str, dict[str, float]] | None = None, +) -> str: + """Return the aggregate token totals row summed across all stages. + + Args: + stages: Mapping of stage key to stage metrics. + pricing: Optional LLM pricing table. When provided, computes and + displays a total dollar cost. When ``None`` or any stage has an + unknown model, shows ``cost unavailable``. 
def _build_total_cost_str(
    stages: dict[str, StageStats],
    pricing: dict[str, dict[str, float]] | None,
) -> str:
    """Return the text for the cost cell of the totals row.

    Stages whose input and output token counts are both zero are ignored.
    Each remaining stage is priced with :func:`calculate_stage_cost`; if any
    of them cannot be priced the whole total is reported as unavailable.

    Args:
        stages: Mapping of stage key to stage metrics.
        pricing: LLM pricing table, or ``None`` when no pricing is configured.

    Returns:
        ``'cost unavailable'`` when *pricing* is ``None`` or a
        token-consuming stage has no computable cost; otherwise the summed
        cost formatted by :func:`_fmt_cost`.
    """
    if pricing is None:
        return "cost unavailable"

    running_total = 0.0
    for entry in stages.values():
        tokens_in = entry.get("input_tokens", 0)
        tokens_out = entry.get("output_tokens", 0)
        if tokens_in != 0 or tokens_out != 0:
            # Only stages that actually consumed tokens need a known model.
            cost_in, cost_out = calculate_stage_cost(
                entry.get("model_name"), tokens_in, tokens_out, pricing
            )
            if cost_in is None or cost_out is None:
                return "cost unavailable"
            running_total += cost_in + cost_out

    return _fmt_cost(running_total)
def _build_cost_alert(
    total_tokens: int,
    threshold: int,
) -> list[str]:
    """Build the comment lines for a token-based cost alert.

    The section starts with an empty separator line, followed by four
    ``"> "``-prefixed lines: the alert heading, a fixed explanation, the
    configured threshold, and the actual usage. Both counts are rendered
    with :func:`_fmt_tokens`.

    Args:
        total_tokens: Actual aggregate token count (input + output).
        threshold: Configured token threshold that was exceeded.

    Returns:
        The alert lines, without trailing newlines.
    """
    threshold_line = f"> **Threshold:** {_fmt_tokens(threshold)} tokens"
    usage_line = f"> **Actual usage:** {_fmt_tokens(total_tokens)} tokens"
    return [
        "",
        "> **⚠️ COST ALERT**",
        "> Token usage has exceeded the configured threshold.",
        threshold_line,
        usage_line,
    ]
def _total_cost_value(
    stages: dict[str, StageStats],
    pricing: dict[str, dict[str, float]],
) -> float | None:
    """Return the unrounded total dollar cost of *stages*, or ``None`` if unknown.

    Stages with zero input and zero output tokens are skipped; the total is
    unknown as soon as one token-consuming stage cannot be priced by
    :func:`calculate_stage_cost`.
    """
    total = 0.0
    for stage in stages.values():
        input_tokens = stage.get("input_tokens", 0)
        output_tokens = stage.get("output_tokens", 0)
        if input_tokens == 0 and output_tokens == 0:
            continue
        input_cost, output_cost = calculate_stage_cost(
            stage.get("model_name"), input_tokens, output_tokens, pricing
        )
        if input_cost is None or output_cost is None:
            return None
        total += input_cost + output_cost
    return total


def format_stats_summary(
    stats: StatsState,
    outcome: str,
    outcome_detail: str | None = None,
    token_threshold: int | None = None,
    dollar_threshold: float | None = None,
    pricing: dict[str, dict[str, float]] | None = None,
) -> str:
    """Format a StatsState snapshot into the body of the stats comment.

    The generated comment includes, in order:
        * A heading and a stage-by-stage metrics table (iterations, machine
          time, input tokens, output tokens, cost).
        * An aggregate totals row (always present).
        * A pull-request link list (omitted when no PRs were recorded).
        * A CI cycles line.
        * The outcome line.
        * An optional cost alert section.

    Cost alerting: when both *dollar_threshold* and *pricing* are given, the
    alert is dollar-based and the token threshold is not consulted. If the
    total cost cannot be computed in that mode, no alert is emitted.
    Otherwise, when *token_threshold* is given, the alert is token-based.
    The dollar comparison uses the unrounded total, so a cost that exceeds
    the threshold by less than a cent still triggers the alert.

    NOTE(review): the lines emitted here use Markdown-style syntax
    (``###``, ``| --- |``, ``**bold**``, ``[text](url)``) — confirm that the
    Jira client renders or converts it as intended.

    Args:
        stats: The workflow statistics state to format.
        outcome: Outcome category, passed to :func:`_build_outcome_str`.
        outcome_detail: Optional elaboration on the outcome (e.g. the
            blocking reason or error message).
        token_threshold: Optional aggregate token threshold (input + output
            across all stages). ``None`` disables token-based alerting.
        dollar_threshold: Optional dollar cost threshold. Only effective
            when *pricing* is also provided.
        pricing: Optional LLM pricing table mapping model-name substrings to
            ``{"input": $/MTok, "output": $/MTok}``. When ``None``, cost
            cells show ``cost unavailable``.

    Returns:
        The comment body as a single newline-joined string.
    """
    stages: dict[str, StageStats] = stats.get("stage_timestamps") or {}
    pr_urls: list[str] = stats.get("stats_pr_urls") or []
    ci_cycles: int = stats.get("stats_ci_cycles") or 0

    lines: list[str] = []

    # ------------------------------------------------------------------
    # Stage metrics table
    # ------------------------------------------------------------------
    lines.append("### Workflow Statistics")
    lines.append("")
    lines.append("| Stage | Iterations | Machine Time | Input Tokens | Output Tokens | Cost |")
    lines.append("| --- | --- | --- | --- | --- | --- |")

    # Detect workflow type: prefer bug stage ordering when any bug-only stage
    # key is present in the recorded data.
    display_stages = (
        ALL_BUG_STAGES if any(k in stages for k in _BUG_ONLY_STAGES) else ALL_FEATURE_STAGES
    )
    for stage_key in display_stages:
        label = _STAGE_LABELS.get(stage_key, stage_key.title())
        lines.append(_build_stage_row(label, stages.get(stage_key), pricing=pricing))

    # Aggregate totals row (always shown, even when no stages ran)
    lines.append(_build_totals_row(stages, pricing=pricing))

    # ------------------------------------------------------------------
    # PR links section (omitted when no PRs)
    # ------------------------------------------------------------------
    if pr_urls:
        lines.append("")
        lines.append("**Pull Requests**")
        lines.extend(f"* [{url}]({url})" for url in pr_urls)

    # ------------------------------------------------------------------
    # CI cycles
    # ------------------------------------------------------------------
    lines.append("")
    lines.append(f"**CI Cycles:** {ci_cycles}")

    # ------------------------------------------------------------------
    # Outcome
    # ------------------------------------------------------------------
    lines.append("")
    lines.append(f"**Outcome:** {_build_outcome_str(outcome, outcome_detail)}")

    # ------------------------------------------------------------------
    # Cost alert (only when threshold is configured and exceeded)
    # ------------------------------------------------------------------
    if dollar_threshold is not None and pricing is not None:
        # Dollar-based alerting takes precedence over token-based. Compare the
        # numeric total directly: parsing the formatted "$x.xx" string back
        # into a float would compare a rounded value and break if the display
        # format ever changed.
        total_cost = _total_cost_value(stages, pricing)
        if total_cost is not None and total_cost > dollar_threshold:
            lines.extend(_build_dollar_cost_alert(total_cost, dollar_threshold))
    elif token_threshold is not None:
        total_tokens = sum(
            s.get("input_tokens", 0) + s.get("output_tokens", 0) for s in stages.values()
        )
        if total_tokens > token_threshold:
            lines.extend(_build_cost_alert(total_tokens, token_threshold))

    return "\n".join(lines)
def _make_key(ticket_key: str, run_id: str) -> str:
    """Compose the Redis key that marks stats as posted for one run.

    Args:
        ticket_key: The Jira issue key (e.g. ``"PROJ-123"``).
        run_id: The unique workflow run identifier.

    Returns:
        ``_KEY_PREFIX`` followed by ``<ticket_key>:<run_id>``, i.e.
        ``forge:stats:posted:<ticket_key>:<run_id>``.
    """
    suffix = f"{ticket_key}:{run_id}"
    return _KEY_PREFIX + suffix
async def mark_stats_posted(
    ticket_key: str,
    run_id: str,
    *,
    redis_client: redis.Redis | None = None,
) -> None:
    """Write the "stats already posted" marker for a ticket / run pair.

    The marker is the value ``"1"`` stored under the key from
    :func:`_make_key`, expiring after ``STATS_IDEMPOTENCY_TTL_SECONDS``.
    Once written, :func:`has_stats_been_posted` finds the key for the same
    ticket / run combination until it expires.

    Args:
        ticket_key: The Jira issue key (e.g. ``"PROJ-123"``).
        run_id: The unique workflow run identifier stored in
            ``StatsState.workflow_run_id``.
        redis_client: Redis client to use. When ``None``, one is obtained
            from :func:`~forge.orchestrator.checkpointer.get_redis_client`.
    """
    if redis_client is None:
        redis_client = await get_redis_client()
    marker_key = _make_key(ticket_key, run_id)
    await redis_client.setex(marker_key, STATS_IDEMPOTENCY_TTL_SECONDS, "1")
    logger.debug(
        "Marked stats comment as posted for ticket=%s run_id=%s (TTL=%ds)",
        ticket_key,
        run_id,
        STATS_IDEMPOTENCY_TTL_SECONDS,
    )
+ """ + return f"" diff --git a/src/forge/workflow/stats/notifications.py b/src/forge/workflow/stats/notifications.py new file mode 100644 index 00000000..33e37ffa --- /dev/null +++ b/src/forge/workflow/stats/notifications.py @@ -0,0 +1,266 @@ +"""Jira-native notification delivery for weekly report generation. + +This module provides functions to notify project stakeholders when a weekly +report is generated, using Jira's native notification mechanisms (comments +with user mentions). + +Usage:: + + from forge.workflow.stats.notifications import ( + get_notification_recipients, + notify_report_ready, + ) + + recipients = await get_notification_recipients("PROJ") + await notify_report_ready("PROJ-42", recipients) + +Configuration: + - ``FORGE_WEEKLY_REPORT_NOTIFY`` env var: comma-separated Jira account IDs + (e.g. ``"abc123,def456"``) or the special value ``"project-leads"`` to + read recipients from the project property ``forge.weekly-report.notify``. + - Jira project property ``forge.weekly-report.notify``: list of Jira + account IDs (JSON array or comma-separated string) that overrides the + global env var for a specific project. + +Priority: project property > env var. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from forge.config import get_settings +from forge.integrations.jira.client import JiraClient + +logger = logging.getLogger(__name__) + +#: Jira project property key for per-project notification recipients. +_NOTIFY_PROPERTY_KEY = "forge.weekly-report.notify" + +#: Special sentinel value meaning "read recipients from the project property". +_PROJECT_LEADS_SENTINEL = "project-leads" + + +def _format_mention(account_id: str) -> str: + """Format a Jira account ID as a mention string. + + Uses Jira's ``[~accountid:{id}]`` mention syntax so that the user receives + a Jira notification when the comment is posted. + + Args: + account_id: Jira account ID (e.g. 
``"5e7e3b1a..."``) + + Returns: + Mention string in the form ``"[~accountid:5e7e3b1a...]"``. + """ + return f"[~accountid:{account_id}]" + + +def _parse_account_ids(raw: Any) -> list[str]: + """Parse a list of Jira account IDs from various raw formats. + + Accepts: + - A JSON array of strings (from a Jira project property) + - A comma-separated string (from an env var or a string property) + - A plain string (single account ID) + + Empty strings and whitespace-only entries are filtered out. + + Args: + raw: Raw value — a list, a comma-separated string, or any other value. + + Returns: + Deduplicated list of non-empty account ID strings, preserving order. + """ + if isinstance(raw, list): + ids = [str(item).strip() for item in raw if str(item).strip()] + elif isinstance(raw, str): + ids = [part.strip() for part in raw.split(",") if part.strip()] + else: + return [] + + # Deduplicate while preserving order + seen: set[str] = set() + unique: list[str] = [] + for aid in ids: + if aid not in seen: + seen.add(aid) + unique.append(aid) + return unique + + +async def _get_project_property_recipients(project: str) -> list[str] | None: + """Fetch the ``forge.weekly-report.notify`` project property. + + Args: + project: Jira project key (e.g. ``"PROJ"``). + + Returns: + Parsed list of account IDs, or ``None`` if the property is not set or + cannot be read. + """ + jira = JiraClient() + try: + value = await jira.get_project_property(project, _NOTIFY_PROPERTY_KEY) + except Exception as exc: + logger.warning( + "Failed to read project property %r for project %r: %s", + _NOTIFY_PROPERTY_KEY, + project, + exc, + ) + return None + finally: + await jira.close() + + if value is None: + return None + + ids = _parse_account_ids(value) + return ids if ids else None + + +async def get_notification_recipients(project: str) -> list[str]: + """Retrieve the list of Jira account IDs to notify for a weekly report. + + Resolution order (highest priority first): + + 1. 
async def notify_report_ready(
    ticket_key: str,
    recipients: list[str],
    *,
    jira_base_url: str = "",
) -> None:
    """Comment on the report ticket, mentioning every valid recipient.

    The comment announces that the weekly report is ready, links to
    ``<base_url>/browse/<ticket_key>``, and ends with one mention per
    accepted recipient (built by :func:`_format_mention`).

    A recipient is skipped with a warning when it is empty, not a string,
    or contains a space or a comma. Nothing is posted when *recipients* is
    empty or when every recipient is skipped.

    Args:
        ticket_key: Jira issue key of the weekly-report ticket (e.g. ``"PROJ-42"``).
        recipients: Jira account IDs to mention.
        jira_base_url: Base URL for the ticket link. When empty, the
            ``jira_base_url`` setting is used. Trailing slashes are removed.

    Raises:
        Exception: Whatever ``JiraClient.add_comment`` raises is logged and
            re-raised; the client is closed in either case.
    """
    if not recipients:
        logger.debug("notify_report_ready: no recipients — skipping comment on %s", ticket_key)
        return

    settings = get_settings()
    root_url = (jira_base_url or settings.jira_base_url).rstrip("/")
    link_target = f"{root_url}/browse/{ticket_key}"

    mentions: list[str] = []
    for candidate in recipients:
        if not candidate or not isinstance(candidate, str):
            logger.warning(
                "notify_report_ready: skipping invalid account_id %r on ticket %s",
                candidate,
                ticket_key,
            )
        elif any(separator in candidate for separator in (" ", ",")):
            # A space or comma suggests an unsplit "id1, id2" string slipped
            # through rather than a single account ID.
            logger.warning(
                "notify_report_ready: skipping malformed account_id %r (contains space or comma)"
                " on ticket %s",
                candidate,
                ticket_key,
            )
        else:
            mentions.append(_format_mention(candidate))

    if not mentions:
        logger.warning(
            "notify_report_ready: all recipients were invalid — no comment posted on %s",
            ticket_key,
        )
        return

    mention_text = " ".join(mentions)
    paragraphs = (
        f"📊 *Weekly report is ready:* [{ticket_key}|{link_target}]",
        "The Forge weekly report has been generated and is available on the ticket above. "
        "Please review the report for workflow activity, cycle time trends, and any bottlenecks "
        "identified during the reporting period.",
        f"Notifying: {mention_text}",
    )
    comment_body = "\n\n".join(paragraphs)

    client = JiraClient()
    try:
        await client.add_comment(ticket_key, comment_body)
        logger.info(
            "Posted notification comment on %s for %d recipient(s)",
            ticket_key,
            len(mentions),
        )
    except Exception as exc:
        logger.error(
            "Failed to post notification comment on %s: %s",
            ticket_key,
            exc,
        )
        raise
    finally:
        await client.close()
+""" + +import asyncio +import logging + +from forge.config import get_settings +from forge.integrations.jira.client import JiraClient +from forge.workflow.stats import StatsState +from forge.workflow.stats.formatter import format_stats_summary +from forge.workflow.stats.idempotency import ( + build_run_marker, + has_stats_been_posted, + mark_stats_posted, +) + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Retry configuration +# --------------------------------------------------------------------------- + +#: Maximum number of posting attempts (1 initial + 2 retries). +_MAX_ATTEMPTS = 3 + +#: Initial backoff delay in seconds before the first retry. +_INITIAL_BACKOFF_SECONDS = 1.0 + +#: Maximum allowed backoff delay (caps exponential growth). +_MAX_BACKOFF_SECONDS = 16.0 + +#: Overall timeout for the entire post_stats_comment operation (5-minute SLA). +_OPERATION_TIMEOUT_SECONDS = 300.0 + +#: Prefix embedded in all stats comment bodies for identification. +#: This substring is present in every comment posted by post_stats_comment / +#: ensure_stats_is_final_comment and is used by _is_stats_comment() to +#: distinguish stats comments from other Forge comments. +_STATS_BODY_MARKER = "``) + that :func:`post_stats_comment` embeds in every comment it posts. + + Args: + body: The raw text body of a Jira comment. + + Returns: + ``True`` when the body contains the stats marker, ``False`` otherwise. + """ + return _STATS_BODY_MARKER in body diff --git a/src/forge/workflow/stats/report_ticket.py b/src/forge/workflow/stats/report_ticket.py new file mode 100644 index 00000000..02ca65a1 --- /dev/null +++ b/src/forge/workflow/stats/report_ticket.py @@ -0,0 +1,200 @@ +"""Report ticket resolution and auto-creation for weekly reports. 
+ +This module provides functions to create or update a dedicated "Weekly Report" +ticket in Jira that stores the weekly report content, enabling historical +tracking and Jira-native access. + +Usage:: + + from datetime import date + from forge.workflow.stats.report_ticket import ensure_report_ticket + + ticket_key = await ensure_report_ticket( + project="PROJ", + week_start=date(2024, 1, 8), + report_markdown="## Weekly Report\\n...", + ) + print(f"Report ticket: {ticket_key}") +""" + +from __future__ import annotations + +import logging +from datetime import date + +from forge.integrations.jira.client import JiraClient + +logger = logging.getLogger(__name__) + +#: Labels applied to every report ticket. +REPORT_LABELS: list[str] = ["forge:weekly-report", "forge:generated"] + +#: Issue type used for report tickets. +REPORT_ISSUE_TYPE: str = "Task" + + +def _report_summary(project: str, week_start: date) -> str: + """Build the standard summary string for a report ticket. + + Args: + project: Jira project key (e.g. ``"PROJ"``). + week_start: The Monday (or first day) of the reporting week. + + Returns: + Summary string in the form + ``"Forge Weekly Report - PROJ - Week of 2024-01-08"``. + """ + return f"Forge Weekly Report - {project} - Week of {week_start}" + + +def _report_jql(project: str, week_start: date) -> str: + """Build the JQL query to locate an existing report ticket. + + Args: + project: Jira project key. + week_start: The first day of the reporting week. + + Returns: + JQL string. + """ + week_str = str(week_start) + return ( + f'project = "{project}" ' + f'AND labels = "forge:weekly-report" ' + f'AND summary ~ "Week of {week_str}"' + ) + + +async def resolve_report_ticket(project: str, week_start: date) -> str | None: + """Find an existing report ticket for the given project and week. + + Searches Jira using JQL: + ``project = {project} AND labels = "forge:weekly-report" + AND summary ~ "Week of {week_start}"``. 
async def create_report_ticket(
    project: str,
    week_start: date,
    report_markdown: str,
) -> str:
    """Create a fresh report ticket whose description is the report text.

    The ticket is created through ``JiraClient.create_task`` with the
    summary from :func:`_report_summary` and the ``REPORT_LABELS`` labels.
    The client is closed whether or not creation succeeds.

    Args:
        project: Jira project key (e.g. ``"PROJ"``).
        week_start: The first day of the reporting week.
        report_markdown: Full report content used as the description.

    Returns:
        The key of the newly created ticket (e.g. ``"PROJ-42"``).
    """
    ticket_summary = _report_summary(project, week_start)
    client = JiraClient()
    try:
        new_key = await client.create_task(
            project_key=project,
            summary=ticket_summary,
            description=report_markdown,
            labels=REPORT_LABELS,
        )
    finally:
        await client.close()

    logger.info(
        "Created report ticket %s for project=%r week_start=%s",
        new_key,
        project,
        week_start,
    )
    return new_key
async def ensure_report_ticket(
    project: str,
    week_start: date,
    report_markdown: str,
) -> str:
    """Make sure a report ticket for the week exists and carries the report.

    The ticket is looked up with :func:`resolve_report_ticket`. When one is
    found, its description is overwritten via :func:`update_report_ticket`;
    when none is found, :func:`create_report_ticket` creates it with
    *report_markdown* already set as the description, so no separate update
    call is made on that path. Repeating the call therefore refreshes the
    same ticket as long as the lookup finds it.

    Args:
        project: Jira project key (e.g. ``"PROJ"``).
        week_start: The first day of the reporting week.
        report_markdown: Full report content (Markdown / Jira wiki markup).

    Returns:
        The key of the report ticket, whether pre-existing or just created.
    """
    existing_key = await resolve_report_ticket(project, week_start)

    if existing_key is not None:
        await update_report_ticket(existing_key, report_markdown)
        return existing_key

    return await create_report_ticket(project, week_start, report_markdown)
+ +Usage:: + + from forge.workflow.stats.weekly_formatter import ( + format_weekly_report_cli, + format_weekly_report_markdown, + format_weekly_report_json, + ) + + report = await collect_weekly_data("AISOS") + + # Terminal output + print(format_weekly_report_cli(report)) + + # Save Markdown to file + with open("weekly.md", "w") as f: + f.write(format_weekly_report_markdown(report)) + + # JSON for scripting + print(format_weekly_report_json(report)) +""" + +from __future__ import annotations + +import json +from typing import Any + +from forge.workflow.stats.weekly_report import ( + BottleneckAnalysis, + FeatureRollup, + TicketSummary, + WeeklyReportData, +) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +#: Em-dash used for absent / N/A values. +_DASH = "\u2014" + +#: Display labels for workflow stage keys. +_STAGE_LABELS: dict[str, str] = { + "prd": "PRD", + "spec": "Spec", + "epics": "Epics", + "tasks": "Tasks", + "implementation": "Implementation", + "ci": "CI", + "review": "Review", + "triage": "Triage", + "rca": "RCA", + "planning": "Planning", +} + + +# --------------------------------------------------------------------------- +# Internal formatting primitives +# --------------------------------------------------------------------------- + + +def _format_duration(seconds: float) -> str: + """Format *seconds* into a human-readable duration string. + + Examples:: + + _format_duration(0) → "0s" + _format_duration(65) → "1m 5s" + _format_duration(3662) → "1h 1m 2s" + _format_duration(90061) → "25h 1m 1s" + + Args: + seconds: Non-negative duration in seconds. + + Returns: + A compact human-readable string such as ``"3h 42m"`` or ``"7m 30s"``. + Hours are always shown when present; minutes are shown when ≥ 1 or + when hours are shown; seconds are always shown. 
+ """ + total = int(seconds) + hours, remainder = divmod(total, 3600) + minutes, secs = divmod(remainder, 60) + if hours: + return f"{hours}h {minutes}m" + if minutes: + return f"{minutes}m {secs}s" + return f"{secs}s" + + +def _format_token_count(count: int) -> str: + """Format *count* as an abbreviated token count string. + + Large numbers are abbreviated with metric suffixes: + + * ``< 1 000`` → raw integer (e.g. ``"999"``) + * ``1 000–999 999`` → ``"Nk"`` or ``"N.Mk"`` (e.g. ``"31k"``, ``"1.5k"``) + * ``≥ 1 000 000`` → ``"NM"`` or ``"N.MM"`` (e.g. ``"1M"``, ``"1.5M"``) + + Examples:: + + _format_token_count(999) → "999" + _format_token_count(1000) → "1k" + _format_token_count(1500) → "1.5k" + _format_token_count(31000) → "31k" + _format_token_count(1000000) → "1M" + _format_token_count(1500000) → "1.5M" + + Args: + count: Non-negative token count. + + Returns: + A compact abbreviated string representation. + """ + if count < 1_000: + return str(count) + if count < 1_000_000: + value = count / 1_000 + if value == int(value): + return f"{int(value)}k" + return f"{value:.1f}k" + value = count / 1_000_000 + if value == int(value): + return f"{int(value)}M" + return f"{value:.1f}M" + + +def _format_bottleneck_section(bottlenecks: BottleneckAnalysis) -> str: + """Render a *BottleneckAnalysis* as a plain-text section. + + The section includes: + + * Total tickets analysed + * Slowest stage (or N/A) + * CI fix rate as a percentage + * Top revised stages (up to 3) + * Stage average durations table + + Args: + bottlenecks: The bottleneck data to render. + + Returns: + A multi-line plain-text string (no trailing newline). 
def _ticket_list_summary(tickets: list[TicketSummary]) -> list[str]:
    """Build the CLI rows for one ticket subsection.

    Each row shows the ticket key, ticket type, formatted duration (an
    em-dash when the ticket has no duration) and the combined input + output
    token count. An empty *tickets* list yields a single placeholder row.
    """

    def _row(ticket: TicketSummary) -> str:
        if ticket.duration_seconds is None:
            shown_duration = _DASH
        else:
            shown_duration = _format_duration(ticket.duration_seconds)
        shown_tokens = _format_token_count(ticket.input_tokens + ticket.output_tokens)
        return (
            f" {ticket.ticket_key:<16} {ticket.ticket_type:<10} "
            f"dur={shown_duration:<10} tokens={shown_tokens}"
        )

    if tickets:
        return [_row(ticket) for ticket in tickets]
    return [" (none)"]
def _feature_rollup_section_cli(feature_rollups: dict[str, FeatureRollup]) -> list[str]:
    """Build the CLI lines of the Feature rollup section.

    Returns an empty list when there are no rollups, so the caller can skip
    the section entirely. Otherwise the result starts with a blank line, the
    section title and a rule, followed by one group of lines per feature in
    key order; the duration line is emitted only when the rollup has a total
    duration, and every group ends with a blank line.
    """
    if not feature_rollups:
        return []

    section: list[str] = ["", "Feature Rollup", "=" * 60]
    for key in sorted(feature_rollups):
        entry = feature_rollups[key]
        title = entry.feature_summary if entry.feature_summary else "(no summary)"
        combined_tokens = entry.total_input_tokens + entry.total_output_tokens

        section.append(f" {key}: {title}")
        section.append(
            f" Tickets : {len(entry.linked_tickets)} total, "
            f"{entry.tickets_completed} completed, "
            f"{entry.tickets_in_progress} in progress"
        )
        section.append(f" Progress: {entry.completion_percentage:.0f}%")
        section.append(f" Tokens : {_format_token_count(combined_tokens)}")
        if entry.total_duration is not None:
            section.append(f" Duration: {_format_duration(entry.total_duration)}")
        section.append("")
    return section
+ + The output matches the design spec format (Section 4) and includes: + + * Report header (project, period, date range) + * Summary section (ticket counts, avg cycle time, token totals) + * Ticket breakdown by status (completed, in-progress, blocked) + * Token usage by stage + * Bottleneck analysis + * Feature rollup section (when feature_rollups is populated) + + Args: + data: Aggregated weekly report data. + + Returns: + A multi-line plain text string suitable for terminal display. + """ + lines: list[str] = [] + + # ------------------------------------------------------------------ + # Header + # ------------------------------------------------------------------ + period_label = f"Last {data.period_days} days" + lines.append("=" * 60) + lines.append(f" WEEKLY REPORT — {data.project}") + lines.append(f" Period : {period_label}") + lines.append(f" From : {data.report_start}") + lines.append(f" To : {data.report_end}") + lines.append("=" * 60) + lines.append("") + + # ------------------------------------------------------------------ + # Summary + # ------------------------------------------------------------------ + n_completed = len(data.completed_tickets) + n_in_progress = len(data.in_progress_tickets) + n_blocked = len(data.blocked_tickets) + n_total = n_completed + n_in_progress + n_blocked + + avg_cycle = _format_duration(data.avg_cycle_time) if data.avg_cycle_time is not None else _DASH + total_tokens = data.total_input_tokens + data.total_output_tokens + + lines.append("Summary") + lines.append("-" * 40) + lines.append(f" Total Tickets : {n_total}") + lines.append(f" Completed : {n_completed}") + lines.append(f" In Progress : {n_in_progress}") + lines.append(f" Blocked : {n_blocked}") + lines.append(f" Avg Cycle Time : {avg_cycle}") + lines.append(f" Total Tokens : {_format_token_count(total_tokens)}") + lines.append(f" Input Tokens : {_format_token_count(data.total_input_tokens)}") + lines.append(f" Output Tokens : 
def format_weekly_report_markdown(data: WeeklyReportData) -> str:
    """Render *data* as a GitHub-flavored Markdown weekly report.

    Sections, in order: title and period header, summary table, one table
    each for completed / in-progress / blocked tickets, token usage by stage,
    bottleneck analysis (plus per-stage average durations when present), and
    a feature rollup table when ``data.feature_rollups`` is non-empty.

    The output is suitable for saving to a ``.md`` file, posting to Jira as
    a Markdown block, or sharing in chat channels.

    Free-text values placed in table cells (ticket keys and types, stage
    labels, feature keys and summaries) are escaped so that a ``|`` or a
    line break inside them cannot split the table row.

    Args:
        data: Aggregated weekly report data.

    Returns:
        The Markdown document as one string, ending with a newline.
    """
    # Two trailing spaces are Markdown's hard line break; a single trailing
    # space lets the Period / From / To lines run together in one paragraph.
    hard_break = "  "

    def _cell(value: object) -> str:
        # Flatten line breaks, then escape the GFM cell delimiter.
        return " ".join(str(value).splitlines()).replace("|", "\\|")

    def _stage_label(stage_key: str) -> str:
        return _cell(_STAGE_LABELS.get(stage_key, stage_key.title()))

    def _ticket_row(t: TicketSummary) -> str:
        duration_str = (
            _format_duration(t.duration_seconds) if t.duration_seconds is not None else _DASH
        )
        tokens_str = _format_token_count(t.input_tokens + t.output_tokens)
        return (
            f"| {_cell(t.ticket_key)} | {_cell(t.ticket_type)} "
            f"| {duration_str} | {tokens_str} |"
        )

    def _ticket_section(title: str, tickets: list[TicketSummary], empty_note: str) -> list[str]:
        # One "## <title>" section: a ticket table, or a note when empty.
        section = [f"## {title}", ""]
        if tickets:
            section.append("| Ticket | Type | Duration | Tokens |")
            section.append("|--------|------|----------|--------|")
            section.extend(_ticket_row(t) for t in tickets)
        else:
            section.append(empty_note)
        section.append("")
        return section

    lines: list[str] = []

    # ------------------------------------------------------------------
    # Header
    # ------------------------------------------------------------------
    period_label = f"Last {data.period_days} Days"
    lines.append(f"# Weekly Report — {data.project}")
    lines.append("")
    lines.append(f"**Period:** {period_label}{hard_break}")
    lines.append(f"**From:** {data.report_start}{hard_break}")
    lines.append(f"**To:** {data.report_end}")
    lines.append("")

    # ------------------------------------------------------------------
    # Summary table
    # ------------------------------------------------------------------
    n_completed = len(data.completed_tickets)
    n_in_progress = len(data.in_progress_tickets)
    n_blocked = len(data.blocked_tickets)
    n_total = n_completed + n_in_progress + n_blocked

    avg_cycle = _format_duration(data.avg_cycle_time) if data.avg_cycle_time is not None else _DASH
    total_tokens = data.total_input_tokens + data.total_output_tokens

    lines.append("## Summary")
    lines.append("")
    lines.append("| Metric | Value |")
    lines.append("|--------|-------|")
    lines.append(f"| Total Tickets | {n_total} |")
    lines.append(f"| Completed | {n_completed} |")
    lines.append(f"| In Progress | {n_in_progress} |")
    lines.append(f"| Blocked | {n_blocked} |")
    lines.append(f"| Avg Cycle Time | {avg_cycle} |")
    lines.append(f"| Total Tokens | {_format_token_count(total_tokens)} |")
    lines.append(f"| Input Tokens | {_format_token_count(data.total_input_tokens)} |")
    lines.append(f"| Output Tokens | {_format_token_count(data.total_output_tokens)} |")
    lines.append("")

    # ------------------------------------------------------------------
    # Ticket tables
    # ------------------------------------------------------------------
    lines.extend(
        _ticket_section(
            "Completed Tickets",
            data.completed_tickets,
            "_No completed tickets this period._",
        )
    )
    lines.extend(
        _ticket_section(
            "In-Progress Tickets",
            data.in_progress_tickets,
            "_No in-progress tickets this period._",
        )
    )
    lines.extend(
        _ticket_section(
            "Blocked Tickets",
            data.blocked_tickets,
            "_No blocked tickets this period._",
        )
    )

    # ------------------------------------------------------------------
    # Token usage by stage
    # ------------------------------------------------------------------
    lines.append("## Token Usage by Stage")
    lines.append("")
    if data.tokens_by_stage:
        lines.append("| Stage | Input | Output | Total |")
        lines.append("|-------|-------|--------|-------|")
        for stage_key, (in_tok, out_tok) in sorted(data.tokens_by_stage.items()):
            lines.append(
                f"| {_stage_label(stage_key)} | {_format_token_count(in_tok)} "
                f"| {_format_token_count(out_tok)} "
                f"| {_format_token_count(in_tok + out_tok)} |"
            )
    else:
        lines.append("_No stage token data available._")
    lines.append("")

    # ------------------------------------------------------------------
    # Bottleneck analysis
    # ------------------------------------------------------------------
    b = data.bottlenecks
    lines.append("## Bottleneck Analysis")
    lines.append("")
    lines.append("| Metric | Value |")
    lines.append("|--------|-------|")
    lines.append(f"| Tickets Analysed | {b.total_tickets_analyzed} |")

    slowest = b.slowest_stage
    if slowest:
        avg_dur = b.avg_stage_durations.get(slowest, 0.0)
        lines.append(
            f"| Slowest Stage | {_stage_label(slowest)} (avg {_format_duration(avg_dur)}) |"
        )
    else:
        lines.append(f"| Slowest Stage | {_DASH} |")

    ci_pct = b.ci_fix_rate * 100.0
    lines.append(f"| CI Fix Rate | {ci_pct:.0f}% |")

    if b.most_revised_stages:
        # Only the three most-revised stages are listed.
        top_labels = [_stage_label(s) for s in b.most_revised_stages[:3]]
        lines.append(f"| Most Revised | {', '.join(top_labels)} |")
    else:
        lines.append(f"| Most Revised | {_DASH} |")
    lines.append("")

    if b.avg_stage_durations:
        lines.append("### Stage Average Durations")
        lines.append("")
        lines.append("| Stage | Avg Duration |")
        lines.append("|-------|-------------|")
        for stage_key, avg_secs in sorted(b.avg_stage_durations.items()):
            lines.append(f"| {_stage_label(stage_key)} | {_format_duration(avg_secs)} |")
        lines.append("")

    # ------------------------------------------------------------------
    # Feature rollup
    # ------------------------------------------------------------------
    if data.feature_rollups:
        lines.append("## Feature Rollup")
        lines.append("")
        lines.append(
            "| Feature | Summary | Tickets | Completed | In Progress | Progress | Tokens |"
        )
        lines.append(
            "|---------|---------|---------|-----------|-------------|----------|--------|"
        )
        for feature_key, rollup in sorted(data.feature_rollups.items()):
            summary = rollup.feature_summary or ""
            total_tickets = len(rollup.linked_tickets)
            tokens_total = rollup.total_input_tokens + rollup.total_output_tokens
            lines.append(
                f"| {_cell(feature_key)} | {_cell(summary)} | {total_tickets} "
                f"| {rollup.tickets_completed} | {rollup.tickets_in_progress} "
                f"| {rollup.completion_percentage:.0f}% "
                f"| {_format_token_count(tokens_total)} |"
            )
        lines.append("")

    return "\n".join(lines)
+ """ + + def _ticket_dict(t: TicketSummary) -> dict[str, Any]: + return { + "ticket_key": t.ticket_key, + "ticket_type": t.ticket_type, + "status": t.status, + "duration_seconds": t.duration_seconds, + "input_tokens": t.input_tokens, + "output_tokens": t.output_tokens, + "ci_cycles": t.ci_cycles, + "outcome": t.outcome, + "tokens_by_stage": { + stage: {"input": in_tok, "output": out_tok} + for stage, (in_tok, out_tok) in t.tokens_by_stage.items() + }, + "revision_counts": t.revision_counts, + "stage_durations": t.stage_durations, + } + + def _rollup_dict(rollup: FeatureRollup) -> dict[str, Any]: + return { + "feature_key": rollup.feature_key, + "feature_summary": rollup.feature_summary, + "total_input_tokens": rollup.total_input_tokens, + "total_output_tokens": rollup.total_output_tokens, + "total_duration": rollup.total_duration, + "tickets_completed": rollup.tickets_completed, + "tickets_in_progress": rollup.tickets_in_progress, + "completion_percentage": rollup.completion_percentage, + "linked_tickets": [t.ticket_key for t in rollup.linked_tickets], + } + + payload: dict[str, Any] = { + "project": data.project, + "period_days": data.period_days, + "report_start": data.report_start, + "report_end": data.report_end, + "summary": { + "total_tickets": len(data.all_tickets), + "completed": len(data.completed_tickets), + "in_progress": len(data.in_progress_tickets), + "blocked": len(data.blocked_tickets), + "avg_cycle_time_seconds": data.avg_cycle_time, + "total_input_tokens": data.total_input_tokens, + "total_output_tokens": data.total_output_tokens, + }, + "tokens_by_stage": { + stage: {"input": in_tok, "output": out_tok} + for stage, (in_tok, out_tok) in data.tokens_by_stage.items() + }, + "bottlenecks": { + "total_tickets_analyzed": data.bottlenecks.total_tickets_analyzed, + "slowest_stage": data.bottlenecks.slowest_stage, + "ci_fix_rate": data.bottlenecks.ci_fix_rate, + "most_revised_stages": data.bottlenecks.most_revised_stages, + "avg_stage_durations": 
data.bottlenecks.avg_stage_durations, + }, + "completed_tickets": [_ticket_dict(t) for t in data.completed_tickets], + "in_progress_tickets": [_ticket_dict(t) for t in data.in_progress_tickets], + "blocked_tickets": [_ticket_dict(t) for t in data.blocked_tickets], + "feature_rollups": { + key: _rollup_dict(rollup) for key, rollup in data.feature_rollups.items() + }, + } + + return json.dumps(payload, indent=2, sort_keys=True) diff --git a/src/forge/workflow/stats/weekly_report.py b/src/forge/workflow/stats/weekly_report.py new file mode 100644 index 00000000..cb96d2f7 --- /dev/null +++ b/src/forge/workflow/stats/weekly_report.py @@ -0,0 +1,793 @@ +"""Weekly report data aggregation module. + +Collects and aggregates workflow statistics from Redis checkpoints to produce +a summary of activity over a configurable time window (default: 7 days). + +Usage:: + + from forge.workflow.stats.weekly_report import collect_weekly_data + + report = await collect_weekly_data("AISOS", days=7) + print(f"Completed: {report.completed_tickets}") + print(f"In Progress: {report.in_progress_tickets}") + print(f"Blocked: {report.blocked_tickets}") + print(f"Avg Cycle Time: {report.avg_cycle_time:.1f}s") +""" + +from __future__ import annotations + +import contextlib +import logging +from dataclasses import dataclass, field +from datetime import UTC, datetime, timedelta +from typing import Any + +from forge.integrations.jira.client import JiraClient +from forge.orchestrator.checkpointer import get_checkpoint_state, get_redis_client + +#: Sentinel key used to group tickets that could not be linked to any Feature. +UNASSIGNED_FEATURE_KEY = "Unassigned" + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Redis key pattern used by langgraph-checkpoint-redis +# --------------------------------------------------------------------------- + +#: Prefix used by langgraph-checkpoint-redis for checkpoint storage. 
@dataclass
class TicketSummary:
    """Statistics for a single ticket, as read from one workflow checkpoint.

    Only ``ticket_key`` is required; every other field has a default so a
    summary can be built from partial data.

    Attributes:
        ticket_key: Jira issue key, e.g. ``"AISOS-123"``.
        ticket_type: Workflow type, ``"Feature"`` (the default) or ``"Bug"``.
        status: Derived status: ``"completed"``, ``"in_progress"`` (the
            default) or ``"blocked"``.
        duration_seconds: Wall-clock seconds for the ticket's workflow, or
            ``None`` when no timing is available.
        input_tokens: Prompt tokens summed over all stages.
        output_tokens: Completion tokens summed over all stages.
        tokens_by_stage: ``{stage_name: (input_tokens, output_tokens)}``.
        revision_counts: ``{stage_name: iteration_count}``.
        ci_cycles: Number of CI fix-attempt cycles during the run.
        outcome: Raw outcome string taken from the checkpoint, or ``None``
            when no outcome has been recorded.
        stage_durations: ``{stage_name: machine_time_seconds}``.
    """

    # Field order is part of the positional constructor signature.
    ticket_key: str
    ticket_type: str = "Feature"
    status: str = "in_progress"
    duration_seconds: float | None = None
    input_tokens: int = 0
    output_tokens: int = 0
    # Mutable defaults go through default_factory so instances never share them.
    tokens_by_stage: dict[str, tuple[int, int]] = field(default_factory=dict)
    revision_counts: dict[str, int] = field(default_factory=dict)
    ci_cycles: int = 0
    outcome: str | None = None
    stage_durations: dict[str, float] = field(default_factory=dict)
@dataclass
class FeatureRollup:
    """Totals for every ticket grouped under one parent Feature.

    A ticket belongs to a Feature either directly (its parent is the
    Feature) or indirectly (its parent is an Epic whose parent is the
    Feature).

    Attributes:
        feature_key: Jira key of the parent Feature (e.g. ``"AISOS-10"``),
            or the ``UNASSIGNED_FEATURE_KEY`` sentinel for tickets that
            could not be resolved to any Feature.
        feature_summary: Title of the Feature issue; empty when the Feature
            could not be fetched.
        linked_tickets: The :class:`TicketSummary` objects grouped here.
        total_input_tokens: Prompt tokens summed over the linked tickets.
        total_output_tokens: Completion tokens summed over the linked
            tickets.
        total_duration: Sum of ``duration_seconds`` over linked tickets that
            have timing data; ``None`` when none of them does.
        tickets_completed: Count of linked tickets with status
            ``"completed"``.
        tickets_in_progress: Count of linked tickets with status
            ``"in_progress"``.
        completion_percentage: Share of linked tickets that are completed,
            on a ``0.0``–``100.0`` scale; ``0.0`` with no linked tickets.
    """

    # Field order is part of the positional constructor signature.
    feature_key: str
    feature_summary: str = ""
    # default_factory gives each rollup its own list.
    linked_tickets: list[TicketSummary] = field(default_factory=list)
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_duration: float | None = None
    tickets_completed: int = 0
    tickets_in_progress: int = 0
    completion_percentage: float = 0.0
def _parse_timestamp(ts: str | None) -> datetime | None:
    """Parse an ISO-8601 timestamp string into an aware UTC datetime.

    Args:
        ts: ISO-8601 timestamp string (e.g. ``"2024-01-01T12:00:00+00:00"``),
            or ``None``.

    Returns:
        An aware :class:`datetime` normalised to UTC, or ``None`` when *ts*
        is absent or unparseable.
    """
    if not ts:
        return None
    try:
        dt = datetime.fromisoformat(ts)
    except (ValueError, TypeError):
        logger.debug("Could not parse timestamp %r", ts)
        return None
    if dt.tzinfo is None:
        # Naive values are interpreted as already being UTC.
        return dt.replace(tzinfo=UTC)
    # Aware values keep their instant but are re-expressed in UTC so the
    # result really matches the documented contract.
    return dt.astimezone(UTC)


def _parse_checkpoint_stats(state: dict[str, Any]) -> TicketSummary | None:
    """Extract a :class:`TicketSummary` from a single checkpoint state dict.

    Reads the ``ticket_key``, ``ticket_type``, ``workflow_outcome``,
    ``is_blocked``, ``stats_ci_cycles`` and ``stage_timestamps`` fields of
    *state*.

    Status is derived as follows: an outcome starting with ``"completed"``
    (case-insensitive) yields ``"completed"``; a truthy ``is_blocked`` flag
    or an outcome starting with ``"blocked"`` yields ``"blocked"``; anything
    else yields ``"in_progress"``.

    Args:
        state: Raw checkpoint state dict as returned by the checkpoint reader.

    Returns:
        A populated :class:`TicketSummary`, or ``None`` when the state lacks
        the minimum required fields (``ticket_key``, ``stage_timestamps``).
    """
    ticket_key: str | None = state.get("ticket_key")
    if not ticket_key:
        logger.debug("Checkpoint state missing ticket_key; skipping")
        return None

    if "stage_timestamps" not in state:
        logger.debug("Checkpoint for %s has no stage_timestamps; skipping", ticket_key)
        return None

    stats_stages: dict[str, Any] = state.get("stage_timestamps") or {}
    if not isinstance(stats_stages, dict):
        logger.warning(
            "Malformed stage_timestamps for %s (type %s); treating as empty",
            ticket_key,
            type(stats_stages).__name__,
        )
        stats_stages = {}

    # --- Ticket type ---
    raw_type = state.get("ticket_type", "")
    ticket_type = str(raw_type) if raw_type else "Feature"

    # --- Outcome / status ---
    outcome = state.get("workflow_outcome")
    # Only string outcomes take part in prefix matching; a non-string value
    # is treated as "no recognised outcome" instead of raising on .lower().
    outcome_text = outcome.lower() if isinstance(outcome, str) else ""
    is_blocked = bool(state.get("is_blocked", False))

    if outcome_text.startswith("completed"):
        status = "completed"
    elif is_blocked or outcome_text.startswith("blocked"):
        status = "blocked"
    else:
        status = "in_progress"

    # --- Per-stage metrics and timestamps, gathered in a single pass ---
    input_tokens = 0
    output_tokens = 0
    tokens_by_stage: dict[str, tuple[int, int]] = {}
    revision_counts: dict[str, int] = {}
    stage_durations: dict[str, float] = {}
    start_times: list[datetime] = []
    end_times: list[datetime] = []

    for stage_name, stage_data in stats_stages.items():
        if not isinstance(stage_data, dict):
            continue
        stage_in = int(stage_data.get("input_tokens", 0) or 0)
        stage_out = int(stage_data.get("output_tokens", 0) or 0)
        input_tokens += stage_in
        output_tokens += stage_out
        tokens_by_stage[stage_name] = (stage_in, stage_out)
        revision_counts[stage_name] = int(stage_data.get("iteration_count", 0) or 0)
        stage_durations[stage_name] = float(stage_data.get("machine_time_seconds", 0.0) or 0.0)

        started = _parse_timestamp(stage_data.get("started_at"))
        ended = _parse_timestamp(stage_data.get("ended_at"))
        if started:
            start_times.append(started)
        if ended:
            end_times.append(ended)

    # --- Cycle time: first stage start -> last stage end (or now) ---
    duration_seconds: float | None = None
    if start_times:
        earliest_start = min(start_times)
        if status != "completed":
            # Not completed: measure up to now.
            duration_seconds = (datetime.now(UTC) - earliest_start).total_seconds()
        elif end_times:
            duration_seconds = (max(end_times) - earliest_start).total_seconds()
        # Completed without any end timestamp: duration stays unknown.

    ci_cycles = int(state.get("stats_ci_cycles", 0) or 0)

    return TicketSummary(
        ticket_key=ticket_key,
        ticket_type=ticket_type,
        status=status,
        duration_seconds=duration_seconds,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        tokens_by_stage=tokens_by_stage,
        revision_counts=revision_counts,
        ci_cycles=ci_cycles,
        outcome=outcome,
        stage_durations=stage_durations,
    )
def _calculate_bottlenecks(tickets: list[TicketSummary]) -> BottleneckAnalysis:
    """Derive stage-level performance metrics from a set of tickets.

    The result carries, for every stage seen in at least one ticket:

    * **avg_stage_durations** — mean machine time (seconds) over the tickets
      that recorded the stage.
    * **most_revised_stages** — stage names sorted by mean iteration count,
      highest first; ties keep first-seen order.
    * **ci_fix_rate** — share of tickets with at least one CI cycle.
    * **slowest_stage** — stage with the largest average duration.

    Args:
        tickets: The :class:`TicketSummary` objects to analyse.

    Returns:
        A populated :class:`BottleneckAnalysis`.
    """
    if not tickets:
        return BottleneckAnalysis(total_tickets_analyzed=0)

    # stage -> [running total, number of samples]
    duration_acc: dict[str, list[float]] = {}
    revision_acc: dict[str, list[int]] = {}
    ci_triggered = 0

    for summary in tickets:
        if summary.ci_cycles > 0:
            ci_triggered += 1

        for name, seconds in summary.stage_durations.items():
            slot = duration_acc.setdefault(name, [0.0, 0])
            slot[0] += seconds
            slot[1] += 1

        for name, revisions in summary.revision_counts.items():
            rev_slot = revision_acc.setdefault(name, [0, 0])
            rev_slot[0] += revisions
            rev_slot[1] += 1

    avg_stage_durations: dict[str, float] = {
        name: total / samples for name, (total, samples) in duration_acc.items()
    }
    avg_revision_counts: dict[str, float] = {
        name: total / samples for name, (total, samples) in revision_acc.items()
    }

    # Stable sort: equal averages stay in insertion order.
    most_revised_stages = sorted(
        avg_revision_counts,
        key=avg_revision_counts.__getitem__,
        reverse=True,
    )

    slowest_stage: str | None = (
        max(avg_stage_durations, key=avg_stage_durations.__getitem__)
        if avg_stage_durations
        else None
    )

    return BottleneckAnalysis(
        avg_stage_durations=avg_stage_durations,
        most_revised_stages=most_revised_stages,
        ci_fix_rate=ci_triggered / len(tickets),
        slowest_stage=slowest_stage,
        total_tickets_analyzed=len(tickets),
    )


def _is_within_window(state: dict[str, Any], cutoff: datetime) -> bool:
    """Tell whether a checkpoint shows activity at or after *cutoff*.

    The checkpoint qualifies when either:

    1. its ``updated_at`` timestamp is >= *cutoff*, or
    2. any ``started_at`` / ``ended_at`` timestamp of a stage entry under
       ``stage_timestamps`` is >= *cutoff*.

    Args:
        state: Raw checkpoint state dict.
        cutoff: The earliest datetime (inclusive) to include.

    Returns:
        ``True`` if the checkpoint falls within the window.
    """

    def _on_or_after_cutoff(raw: Any) -> bool:
        parsed = _parse_timestamp(raw)
        return bool(parsed) and parsed >= cutoff

    if _on_or_after_cutoff(state.get("updated_at")):
        return True

    stage_entries = state.get("stage_timestamps") or {}
    if not isinstance(stage_entries, dict):
        return False

    # Non-dict stage entries are ignored; evaluation stops at the first hit.
    return any(
        _on_or_after_cutoff(entry.get(stamp_key))
        for entry in stage_entries.values()
        if isinstance(entry, dict)
        for stamp_key in ("started_at", "ended_at")
    )
def _aggregate_tokens(
    tickets: list[TicketSummary],
) -> tuple[int, int, dict[str, tuple[int, int]]]:
    """Sum token counts across all tickets.

    Args:
        tickets: The ticket summaries to aggregate.

    Returns:
        A 3-tuple of ``(total_input, total_output, tokens_by_stage)`` where
        *tokens_by_stage* maps stage name to ``(total_in, total_out)`` across
        all tickets.
    """
    total_in = 0
    total_out = 0
    tokens_by_stage: dict[str, tuple[int, int]] = {}

    for ticket in tickets:
        total_in += ticket.input_tokens
        total_out += ticket.output_tokens
        for stage_name, (s_in, s_out) in ticket.tokens_by_stage.items():
            prev_in, prev_out = tokens_by_stage.get(stage_name, (0, 0))
            tokens_by_stage[stage_name] = (prev_in + s_in, prev_out + s_out)

    return total_in, total_out, tokens_by_stage


def _avg_cycle_time(tickets: list[TicketSummary]) -> float | None:
    """Compute the average cycle time for completed tickets.

    Only tickets whose status is ``"completed"`` and whose
    ``duration_seconds`` is not ``None`` contribute.

    Args:
        tickets: All ticket summaries (not just completed ones).

    Returns:
        Average cycle time in seconds, or ``None`` when no applicable tickets
        are found.
    """
    durations = [
        t.duration_seconds
        for t in tickets
        if t.status == "completed" and t.duration_seconds is not None
    ]
    if not durations:
        return None
    return sum(durations) / len(durations)


async def _resolve_feature_key(
    ticket: TicketSummary,
    jira: JiraClient,
) -> str | None:
    """Resolve the parent Feature key for a ticket by traversing the hierarchy.

    The lookup strategy is:

    1. Fetch the Jira issue for *ticket.ticket_key*.
    2. If its ``issue_type`` is ``"Feature"``, return its own key.
    3. If it has no parent, return ``None``; otherwise fetch the parent.
    4. If the parent is a ``"Feature"``, return the parent key.
    5. If the parent is an ``"Epic"`` with a parent of its own, fetch that
       grandparent and return its key when it is a ``"Feature"``.
    6. Otherwise return ``None``.

    Any exception raised while talking to Jira is logged and converted into
    a ``None`` result, so one unreachable issue cannot abort the report.

    Args:
        ticket: The ticket whose Feature ancestry should be resolved.
        jira: An open :class:`JiraClient` to use for API calls.

    Returns:
        The Jira key of the nearest Feature ancestor, or ``None`` when
        resolution fails or no Feature is found within two hops.
    """
    try:
        issue = await jira.get_issue(ticket.ticket_key)

        # The ticket itself is a Feature (unusual but possible)
        if issue.issue_type == "Feature":
            return issue.key

        if not issue.parent_key:
            return None

        parent = await jira.get_issue(issue.parent_key)
        if parent.issue_type == "Feature":
            return parent.key

        # Parent is an Epic — climb one more level to find the Feature
        if parent.issue_type == "Epic" and parent.parent_key:
            grandparent = await jira.get_issue(parent.parent_key)
            if grandparent.issue_type == "Feature":
                return grandparent.key
    except Exception as exc:  # noqa: BLE001 - best-effort lookup, never fatal
        # Previously swallowed silently; surface it so missing rollups can be
        # traced back to the failing Jira call.
        logger.warning(
            "Could not resolve Feature ancestor for %s: %s", ticket.ticket_key, exc
        )

    return None
def _build_feature_rollup(
    feature_key: str,
    feature_summary: str,
    tickets: list[TicketSummary],
) -> FeatureRollup:
    """Assemble the :class:`FeatureRollup` for one already-grouped Feature.

    Token totals are summed over every ticket. ``total_duration`` is the sum
    of the known ``duration_seconds`` values, or ``None`` when no ticket has
    one. ``completion_percentage`` is the completed share of *tickets* on a
    0-100 scale (``0.0`` for an empty list).

    Args:
        feature_key: The Feature key (or ``UNASSIGNED_FEATURE_KEY``).
        feature_summary: Human-readable summary of the Feature issue.
        tickets: All tickets that belong to this Feature.

    Returns:
        A fully populated :class:`FeatureRollup`.
    """
    input_total = 0
    output_total = 0
    completed_count = 0
    in_progress_count = 0
    known_durations = []

    for item in tickets:
        input_total += item.input_tokens
        output_total += item.output_tokens
        if item.duration_seconds is not None:
            known_durations.append(item.duration_seconds)
        if item.status == "completed":
            completed_count += 1
        elif item.status == "in_progress":
            in_progress_count += 1

    total_duration: float | None = None
    if known_durations:
        total_duration = sum(known_durations)

    if tickets:
        completion_pct = completed_count / len(tickets) * 100.0
    else:
        completion_pct = 0.0

    return FeatureRollup(
        feature_key=feature_key,
        feature_summary=feature_summary,
        linked_tickets=list(tickets),
        total_input_tokens=input_total,
        total_output_tokens=output_total,
        total_duration=total_duration,
        tickets_completed=completed_count,
        tickets_in_progress=in_progress_count,
        completion_percentage=completion_pct,
    )
async def _group_by_feature(
    tickets: list[TicketSummary],
    jira: JiraClient,
) -> dict[str, FeatureRollup]:
    """Group tickets by their parent Feature and return per-Feature rollups.

    For each ticket:

    * If the ticket can be resolved to a Feature via the Jira hierarchy,
      it is placed in that Feature's rollup.
    * Otherwise it is placed under the ``UNASSIGNED_FEATURE_KEY`` sentinel.

    The Feature summary is requested from Jira at most once per resolved
    Feature key; when that request fails the summary is left empty and the
    failure is logged. The ``UNASSIGNED_FEATURE_KEY`` group always has an
    empty ``feature_summary``.

    Args:
        tickets: The ticket summaries to group.
        jira: An open :class:`JiraClient` used for hierarchy resolution.

    Returns:
        A dict mapping Feature key (or ``UNASSIGNED_FEATURE_KEY``) to a
        :class:`FeatureRollup`. Returns an empty dict when *tickets* is empty.
    """
    if not tickets:
        return {}

    groups: dict[str, list[TicketSummary]] = {}
    feature_summaries: dict[str, str] = {}

    for ticket in tickets:
        feature_key = await _resolve_feature_key(ticket, jira)
        bucket = feature_key if feature_key is not None else UNASSIGNED_FEATURE_KEY
        groups.setdefault(bucket, []).append(ticket)

        if feature_key is None or feature_key in feature_summaries:
            continue

        try:
            feature_issue = await jira.get_issue(feature_key)
            feature_summaries[feature_key] = feature_issue.summary
        except Exception as exc:  # noqa: BLE001 - summary is cosmetic
            logger.warning("Could not fetch summary for Feature %s: %s", feature_key, exc)
            # Remember the failure so the same Feature is not re-requested
            # for every remaining ticket that belongs to it.
            feature_summaries[feature_key] = ""

    return {
        bucket_key: _build_feature_rollup(
            bucket_key, feature_summaries.get(bucket_key, ""), bucket_tickets
        )
        for bucket_key, bucket_tickets in groups.items()
    }


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


async def collect_weekly_data(
    project: str,
    days: int = 7,
    *,
    jira_client: JiraClient | None = None,
) -> WeeklyReportData:
    """Collect and aggregate workflow statistics for a project over a time window.

    Scans all Redis keys matching ``{_CHECKPOINT_KEY_PREFIX}{project}-*``,
    reads each ticket's checkpoint state, keeps the entries whose activity
    falls within the last *days* days, and aggregates them into a
    :class:`WeeklyReportData`.

    Failures are handled best-effort: an unreadable checkpoint is skipped, a
    failed Redis scan yields a report without tickets, and a failed Feature
    rollup yields an empty ``feature_rollups`` mapping. Each case is logged.

    Args:
        project: The Jira project key to filter checkpoints (e.g. ``"AISOS"``).
        days: Number of days to look back from *now* (default: 7).
        jira_client: Optional :class:`JiraClient` instance to use for Feature
            hierarchy resolution. When ``None`` a new client is created and
            closed automatically. Pass an explicit client in tests to avoid
            real HTTP calls.

    Returns:
        A fully populated :class:`WeeklyReportData`. If no matching
        checkpoints exist, the report contains zero-value aggregates.
    """
    now = datetime.now(UTC)
    cutoff = now - timedelta(days=days)
    report_end = now.isoformat()
    report_start = cutoff.isoformat()

    pattern = f"{_CHECKPOINT_KEY_PREFIX}{project}-*"
    logger.info(
        "Collecting weekly report for project=%s, days=%d, pattern=%s",
        project,
        days,
        pattern,
    )

    redis_client = await get_redis_client()
    all_tickets: list[TicketSummary] = []

    try:
        cursor = 0
        scanned_keys: list[Any] = []

        while True:
            cursor, keys = await redis_client.scan(cursor=cursor, match=pattern, count=100)
            scanned_keys.extend(keys)
            if cursor == 0:
                break

        logger.debug("Found %d checkpoint keys for project=%s", len(scanned_keys), project)

        unique_ticket_keys: set[str] = set()
        for key in scanned_keys:
            # A client created without response decoding returns bytes keys;
            # normalise so the prefix handling below works for both.
            key_text = key.decode("utf-8", errors="replace") if isinstance(key, bytes) else key
            if key_text.startswith(_CHECKPOINT_KEY_PREFIX):
                remaining = key_text[len(_CHECKPOINT_KEY_PREFIX) :]
                # Several keys can exist per ticket; keep the ticket part only.
                unique_ticket_keys.add(remaining.split(":", 1)[0])

        for ticket_key in sorted(unique_ticket_keys):
            try:
                state = await get_checkpoint_state(ticket_key)
                if state is None:
                    continue
                if not isinstance(state, dict):
                    logger.debug(
                        "Unexpected checkpoint value type for ticket %s; skipping", ticket_key
                    )
                    continue

                # Filter by time window
                if not _is_within_window(state, cutoff):
                    logger.debug("Checkpoint for %s outside reporting window; skipping", ticket_key)
                    continue

                ticket = _parse_checkpoint_stats(state)
                if ticket is not None:
                    all_tickets.append(ticket)

            except Exception as exc:  # noqa: BLE001
                logger.warning(
                    "Unexpected error reading checkpoint for ticket %s: %s", ticket_key, exc
                )

    except Exception as exc:  # noqa: BLE001
        logger.error(
            "Failed to scan Redis for project=%s: %s", project, exc, exc_info=True
        )

    # --- Categorise tickets ---
    completed = [t for t in all_tickets if t.status == "completed"]
    in_progress = [t for t in all_tickets if t.status == "in_progress"]
    blocked = [t for t in all_tickets if t.status == "blocked"]

    # --- Aggregate tokens ---
    total_in, total_out, tokens_by_stage = _aggregate_tokens(all_tickets)

    # --- Average cycle time (completed tickets only) ---
    avg_ct = _avg_cycle_time(all_tickets)

    # --- Bottleneck analysis ---
    bottlenecks = _calculate_bottlenecks(all_tickets)

    # --- Per-Feature rollup ---
    feature_rollups: dict[str, FeatureRollup] = {}
    active_jira_client: JiraClient | None = jira_client
    try:
        if active_jira_client is None:
            # Built inside the try so a constructor failure degrades to an
            # empty rollup instead of aborting the whole report.
            active_jira_client = JiraClient()
        feature_rollups = await _group_by_feature(all_tickets, active_jira_client)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to build feature rollups: %s", exc, exc_info=True)
        feature_rollups = {}
    finally:
        # Only close a client this function created itself.
        if jira_client is None and active_jira_client is not None:
            await active_jira_client.close()

    report = WeeklyReportData(
        project=project,
        period_days=days,
        report_start=report_start,
        report_end=report_end,
        completed_tickets=completed,
        in_progress_tickets=in_progress,
        blocked_tickets=blocked,
        total_input_tokens=total_in,
        total_output_tokens=total_out,
        tokens_by_stage=tokens_by_stage,
        avg_cycle_time=avg_ct,
        bottlenecks=bottlenecks,
        all_tickets=all_tickets,
        feature_rollups=feature_rollups,
    )

    logger.info(
        "Weekly report for project=%s: completed=%d, in_progress=%d, blocked=%d, total_tokens=%d",
        project,
        len(completed),
        len(in_progress),
        len(blocked),
        total_in + total_out,
    )

    return report


__all__ = [
    "BottleneckAnalysis",
    "FeatureRollup",
    "TicketSummary",
    "UNASSIGNED_FEATURE_KEY",
    "WeeklyReportData",
    "collect_weekly_data",
]
+ +These helpers are called by workflow nodes to update stats fields in the +LangGraph state. Every function returns a dict suitable for merging into +the state via LangGraph's reducer (partial state updates). + +All timestamps are UTC ISO-8601 strings (e.g. "2024-01-01T12:00:00.000000+00:00"). +""" + +from datetime import UTC, datetime +from typing import Any + + +def _utc_now() -> str: + """Return the current UTC time as an ISO-8601 string.""" + return datetime.now(UTC).isoformat() + + +def _get_stage(state: dict[str, Any], stage_name: str) -> dict[str, Any]: + """Return a copy of the stage entry, or a zeroed default if absent.""" + stages: dict[str, Any] = state.get("stage_timestamps") or {} + existing = stages.get(stage_name) + if existing is None: + return { + "stage_name": stage_name, + "iteration_count": 0, + "machine_time_seconds": 0.0, + "input_tokens": 0, + "output_tokens": 0, + "started_at": None, + "ended_at": None, + "model_name": None, + } + # Return a shallow copy so callers can mutate freely + return dict(existing) + + +def record_stage_start( + state: dict[str, Any], + stage_name: str, + model_name: str | None = None, +) -> dict[str, Any]: + """Initialize a stage entry in stats_stages with a started_at timestamp. + + If the stage already exists (e.g. a retry), the started_at timestamp is + updated to now but accumulated metrics are preserved. iteration_count is + left as-is; call :func:`increment_revision` to bump it. + + Args: + state: Current workflow state dict. + stage_name: Name of the stage being started (e.g. ``"implement"``). + model_name: Optional name of the LLM model used in this stage + (e.g. ``"claude-sonnet-4-5@20250929"``). Pass ``None`` for stages + that do not invoke an LLM (e.g. CI, review). + + Returns: + Partial state update dict with ``stage_timestamps`` key. 
+ """ + stages: dict[str, Any] = dict(state.get("stage_timestamps") or {}) + stage = _get_stage(state, stage_name) + stage["started_at"] = _utc_now() + stage["ended_at"] = None # reset end marker when re-entering + if model_name is not None: + stage["model_name"] = model_name + stages[stage_name] = stage + return {"stage_timestamps": stages} + + +def record_stage_end( + state: dict[str, Any], + stage_name: str, + machine_time: float, +) -> dict[str, Any]: + """Mark a stage as ended and accumulate time metrics. + + Time values are *accumulated* (not replaced) so that repeated calls for + the same stage (e.g. after retries) add up correctly. + + Args: + state: Current workflow state dict. + stage_name: Name of the stage that has finished. + machine_time: Wall-clock seconds of automated work to add. + + Returns: + Partial state update dict with ``stage_timestamps`` key. + """ + stages: dict[str, Any] = dict(state.get("stage_timestamps") or {}) + stage = _get_stage(state, stage_name) + stage["ended_at"] = _utc_now() + stage["machine_time_seconds"] = stage.get("machine_time_seconds", 0.0) + machine_time + stages[stage_name] = stage + return {"stage_timestamps": stages} + + +def record_tokens( + state: dict[str, Any], + stage_name: str, + input_tokens: int, + output_tokens: int, +) -> dict[str, Any]: + """Accumulate LLM token counts for a stage. + + Tokens are *accumulated* (not replaced) so that multiple LLM calls within + the same stage all contribute to the total. + + Args: + state: Current workflow state dict. + stage_name: Name of the stage consuming tokens. + input_tokens: Number of prompt tokens to add. + output_tokens: Number of completion tokens to add. + + Returns: + Partial state update dict with ``stage_timestamps``, ``stage_token_usage``, + and ``token_usage`` keys. 
+ """ + stages: dict[str, Any] = dict(state.get("stage_timestamps") or {}) + stage = _get_stage(state, stage_name) + stage["input_tokens"] = stage.get("input_tokens", 0) + input_tokens + stage["output_tokens"] = stage.get("output_tokens", 0) + output_tokens + stages[stage_name] = stage + + # Update per-stage token usage map + stage_token_usage: dict[str, Any] = dict(state.get("stage_token_usage") or {}) + existing_stage_tokens = stage_token_usage.get(stage_name) or {} + stage_token_usage[stage_name] = { + "input_tokens": (existing_stage_tokens.get("input_tokens") or 0) + input_tokens, + "output_tokens": (existing_stage_tokens.get("output_tokens") or 0) + output_tokens, + } + + # Update aggregate token usage + agg: dict[str, Any] = dict(state.get("token_usage") or {}) + agg["input_tokens"] = (agg.get("input_tokens") or 0) + input_tokens + agg["output_tokens"] = (agg.get("output_tokens") or 0) + output_tokens + + return { + "stage_timestamps": stages, + "stage_token_usage": stage_token_usage, + "token_usage": agg, + } + + +def increment_revision(state: dict[str, Any], stage_name: str) -> dict[str, Any]: + """Increment the iteration_count for a stage by 1. + + Should be called each time a stage is re-entered due to a revision + request or retry. + + Args: + state: Current workflow state dict. + stage_name: Name of the stage being revised. + + Returns: + Partial state update dict with ``stage_timestamps`` and + ``revision_counts`` keys. 
+ """ + stages: dict[str, Any] = dict(state.get("stage_timestamps") or {}) + stage = _get_stage(state, stage_name) + new_count = stage.get("iteration_count", 0) + 1 + stage["iteration_count"] = new_count + stages[stage_name] = stage + + revision_counts: dict[str, Any] = dict(state.get("revision_counts") or {}) + revision_counts[stage_name] = new_count + + return { + "stage_timestamps": stages, + "revision_counts": revision_counts, + } + + +def increment_ci_cycle(state: dict[str, Any]) -> dict[str, Any]: + """Increment the workflow-level CI fix-attempt cycle counter by 1. + + Args: + state: Current workflow state dict. + + Returns: + Partial state update dict with ``stats_ci_cycles`` key. + """ + current: int = state.get("stats_ci_cycles") or 0 + return {"stats_ci_cycles": current + 1} + + +def add_pr_url(state: dict[str, Any], pr_url: str) -> dict[str, Any]: + """Append a PR URL to stats_pr_urls (idempotent — no duplicates). + + Args: + state: Current workflow state dict. + pr_url: The pull-request URL to record. + + Returns: + Partial state update dict with ``stats_pr_urls`` key. + """ + existing: list[str] = list(state.get("stats_pr_urls") or []) + if pr_url not in existing: + existing.append(pr_url) + return {"stats_pr_urls": existing} + + +def set_outcome(_state: dict[str, Any], outcome: str, reason: str | None = None) -> dict[str, Any]: + """Set the workflow outcome and optional reason. + + Conventional outcome values: + - ``"Completed"`` — finished successfully. + - ``"Blocked: "`` — waiting on an external blocker. + - ``"Failed: "`` — terminated due to an unrecoverable error. + + Args: + _state: Current workflow state dict (unused — outcome is set unconditionally). + outcome: Outcome string to record. + reason: Optional human-readable elaboration (e.g. blocking reason). + + Returns: + Partial state update dict with ``workflow_outcome`` and + ``stats_outcome_reason`` keys. 
+ """ + return { + "workflow_outcome": outcome, + "stats_outcome_reason": reason, + } diff --git a/src/forge/workflow/utils/__init__.py b/src/forge/workflow/utils/__init__.py index da6a659c..ccf81cf2 100644 --- a/src/forge/workflow/utils/__init__.py +++ b/src/forge/workflow/utils/__init__.py @@ -30,7 +30,12 @@ "rebase_pr": "rebase_pr", } -_TERMINAL_NODES: frozenset[str] = frozenset({"complete"}) +_TERMINAL_NODES: frozenset[str] = frozenset( + { + "complete", + "post_terminal_stats", + } +) def resolve_shared_resume_node(current_node: str) -> str | None: diff --git a/tests/conftest.py b/tests/conftest.py index c20c4c47..edb464c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,26 @@ from forge.main import app +@pytest.fixture(autouse=True) +def _ensure_add_structured_comment_is_async_mock(monkeypatch): + """Automatically ensure add_structured_comment is always an AsyncMock in any MagicMock. + + This acts as a global fallback for any test that manually mocks JiraClient without + fully defining all required methods. 
+ """ + from unittest.mock import MagicMock, AsyncMock + original_getattr = MagicMock.__getattr__ + + def custom_getattr(self, name): + if name == "add_structured_comment": + am = AsyncMock() + self.__dict__[name] = am + return am + return original_getattr(self, name) + + monkeypatch.setattr(MagicMock, "__getattr__", custom_getattr) + + @pytest.fixture def mock_settings() -> Settings: """Create mock settings for testing.""" @@ -63,6 +83,7 @@ def mock_jira_client() -> Generator[MagicMock, None, None]: mock.create_task = AsyncMock(return_value="TEST-125") mock.delete_issue = AsyncMock() mock.add_comment = AsyncMock() + mock.add_structured_comment = AsyncMock() mock.close = AsyncMock() yield mock diff --git a/tests/integration/orchestrator/test_local_review_status_comments.py b/tests/integration/orchestrator/test_local_review_status_comments.py index f7da13b8..4c19b0e6 100644 --- a/tests/integration/orchestrator/test_local_review_status_comments.py +++ b/tests/integration/orchestrator/test_local_review_status_comments.py @@ -129,7 +129,10 @@ def track_comment(ticket_key, message): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass1), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass1, + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): state = await local_review_changes(state) @@ -139,61 +142,28 @@ def track_comment(ticket_key, message): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass2), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass2, + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): state = await local_review_changes(state) - # Pass 
3: no unfixed issues, should post fix comment with pass 3 and route to create_pr - # Note: MAX_REVIEW_ATTEMPTS is 2, so pass 3 would be the final attempt - # We need to test the scenario where it succeeds on the last attempt - mock_runner_pass3 = create_mock_container_runner(has_unfixed_issues=False) - - with ( - patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass3), - patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), - ): - result = await local_review_changes(state) - - # Verify all comments were posted: initial + fix(2) + fix(3) - # Note: Only 2 comments will be posted because MAX_REVIEW_ATTEMPTS=2 - # Pass 1: initial comment, Pass 2: fix comment (pass 2) - # Pass 3 would exceed max attempts, so it doesn't run the container - # Let me reconsider the test scenario based on MAX_REVIEW_ATTEMPTS=2 - - # With MAX_REVIEW_ATTEMPTS=2: - # Pass 1 (attempt 0): initial comment, finds issues, increments to attempt 1, pass 2 - # Pass 2 (attempt 1): fix comment (pass 2), finds no issues OR hits max attempts - - # For a 3-comment scenario (initial + 2 fix comments), we need: - # Pass 1: initial, finds issues -> retry - # Pass 2: fix (pass 2), finds issues -> retry - # Pass 3: Would be attempt 2 which equals MAX_REVIEW_ATTEMPTS, so it runs one more time - - # Actually reviewing the code: review_attempts + 1 < MAX_REVIEW_ATTEMPTS - # So with MAX_REVIEW_ATTEMPTS=2: - # - attempt 0: runs, if issues and 0+1 < 2, retry (yes) - # - attempt 1: runs, if issues and 1+1 < 2, retry (no, 2 is not < 2) - - # So we can only get 2 passes max with MAX_REVIEW_ATTEMPTS=2 - # Pass 1 (attempt 0): initial comment - # Pass 2 (attempt 1): fix comment (pass 2) - - # For TS-005 to work as specified (3 fix passes), I need to adjust the test - # or acknowledge that MAX_REVIEW_ATTEMPTS limits this - - # Let me verify what comments were actually posted + 
# Verify all comments were posted: initial + fix(2) assert len(all_comments) == 2 # Initial + fix(pass 2) - + # Verify initial comment assert all_comments[0][0] == "FEAT-201" assert all_comments[0][1] == "🔍 Running local code review on changes before creating PR." - + # Verify fix comment with pass 2 assert all_comments[1][0] == "FEAT-201" assert all_comments[1][1] == "🔧 Local review found issues, applying fixes (pass 2)." + # Verify workflow routed to create_pr + assert state["current_node"] == "create_pr" + @pytest.mark.asyncio async def test_three_pass_scenario_with_max_attempts_override(self): """TS-005: Verify 3-pass scenario by temporarily overriding MAX_REVIEW_ATTEMPTS.""" @@ -225,7 +195,10 @@ def track_comment(ticket_key, message): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass1), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass1, + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): state = await local_review_changes(state) @@ -235,7 +208,10 @@ def track_comment(ticket_key, message): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass2), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass2, + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): state = await local_review_changes(state) @@ -245,22 +221,25 @@ def track_comment(ticket_key, message): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass3), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass3, + ), 
patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): result = await local_review_changes(state) # Verify all comments were posted: initial + fix(2) + fix(3) assert len(all_comments) == 3 - + # Verify initial comment assert all_comments[0][0] == "FEAT-202" assert all_comments[0][1] == "🔍 Running local code review on changes before creating PR." - + # Verify fix comment with pass 2 assert all_comments[1][0] == "FEAT-202" assert all_comments[1][1] == "🔧 Local review found issues, applying fixes (pass 2)." - + # Verify fix comment with pass 3 assert all_comments[2][0] == "FEAT-202" assert all_comments[2][1] == "🔧 Local review found issues, applying fixes (pass 3)." @@ -307,23 +286,31 @@ def track_comment(ticket_key, message): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner), - patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner, + ), + patch( + "forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git + ), ): state = await local_review_changes(state) # Verify all comments were posted: initial + fix(2) + fix(3) + fix(4) + fix(5) + fix(6) assert len(all_comments) == 6 - + # Verify initial comment assert all_comments[0][0] == "FEAT-203" assert all_comments[0][1] == "🔍 Running local code review on changes before creating PR." - + # Verify fix comments with incrementing pass numbers for i in range(1, 6): pass_num = i + 1 assert all_comments[i][0] == "FEAT-203" - assert all_comments[i][1] == f"🔧 Local review found issues, applying fixes (pass {pass_num})." + assert ( + all_comments[i][1] + == f"🔧 Local review found issues, applying fixes (pass {pass_num})." 
+ ) # Verify workflow routed to create_pr assert state["current_node"] == "create_pr" @@ -363,7 +350,7 @@ async def test_pass_number_resets_when_transitioning_from_implementation_to_loca ): mock_git = create_mock_git_operations(has_changes=False) mock_git_class.return_value = mock_git - + result = await implement_task(state) # Verify pass_number was reset to 1 when entering local_review phase @@ -405,7 +392,10 @@ async def test_pass_number_persists_and_increments_within_same_feature(self): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass1), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass1, + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): state = await local_review_changes(state) @@ -420,7 +410,10 @@ async def test_pass_number_persists_and_increments_within_same_feature(self): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass2), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass2, + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): result = await local_review_changes(state) @@ -448,13 +441,18 @@ async def test_pass_number_increments_correctly_across_multiple_iterations(self) # Passes 1-3: have unfixed issues for expected_pass_num in [1, 2, 3]: assert state["local_review_pass_number"] == expected_pass_num - + mock_runner = create_mock_container_runner(has_unfixed_issues=True) with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner), - patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), 
+ patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner, + ), + patch( + "forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git + ), ): state = await local_review_changes(state) @@ -468,7 +466,9 @@ async def test_pass_number_increments_correctly_across_multiple_iterations(self) with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): result = await local_review_changes(state) diff --git a/tests/integration/orchestrator/test_task_handoff.py b/tests/integration/orchestrator/test_task_handoff.py index c4c36ce1..05cc12d7 100644 --- a/tests/integration/orchestrator/test_task_handoff.py +++ b/tests/integration/orchestrator/test_task_handoff.py @@ -41,7 +41,7 @@ async def test_workspace_setup_creates_forge_directory(self): async def test_workspace_setup_node_creates_forge_directory(self): """The setup_workspace node should create .forge directory structure.""" - from forge.orchestrator.nodes import setup_workspace + from forge.workflow.nodes.workspace_setup import setup_workspace initial_state = create_initial_state( thread_id="TEST-123", @@ -50,14 +50,17 @@ async def test_workspace_setup_node_creates_forge_directory(self): ) initial_state["tasks_by_repo"] = {"test-org/test-repo": ["TASK-1", "TASK-2"]} - with patch("forge.workflow.nodes.workspace_setup.GitOperations") as MockGit, \ - patch("forge.workflow.nodes.workspace_setup.GuardrailsLoader") as MockGuardrails: - + with ( + patch("forge.workflow.nodes.workspace_setup.GitOperations") as MockGit, + patch("forge.workflow.nodes.workspace_setup.GuardrailsLoader") as MockGuardrails, + ): mock_git = MagicMock() MockGit.return_value = mock_git mock_guardrails = MagicMock() 
- mock_guardrails.load.return_value = MagicMock(get_system_context=MagicMock(return_value="")) + mock_guardrails.load.return_value = MagicMock( + get_system_context=MagicMock(return_value="") + ) MockGuardrails.return_value = mock_guardrails result = await setup_workspace(initial_state) @@ -66,7 +69,9 @@ async def test_workspace_setup_node_creates_forge_directory(self): if result.get("workspace_path"): workspace_path = Path(result["workspace_path"]) assert (workspace_path / ".forge").exists(), ".forge should be created" - assert (workspace_path / ".forge" / "history").exists(), ".forge/history should be created" + assert (workspace_path / ".forge" / "history").exists(), ( + ".forge/history should be created" + ) class TestPreviousTaskKeysPassing: @@ -80,9 +85,10 @@ async def test_runner_passes_previous_task_keys_in_task_file(self): workspace = Path(workspace_dir) # Mock podman and settings - with patch("forge.sandbox.runner.shutil.which", return_value="/usr/bin/podman"), \ - patch("forge.sandbox.runner.get_settings") as mock_settings: - + with ( + patch("forge.sandbox.runner.shutil.which", return_value="/usr/bin/podman"), + patch("forge.sandbox.runner.get_settings") as mock_settings, + ): settings = MagicMock() settings.anthropic_api_key.get_secret_value.return_value = "test-key" settings.use_vertex_ai = False @@ -96,9 +102,10 @@ async def test_runner_passes_previous_task_keys_in_task_file(self): runner = ContainerRunner(settings) # Mock the actual run to just create the task file - with patch.object(runner, "_build_podman_command", return_value=["echo", "test"]), \ - patch("asyncio.create_subprocess_exec") as mock_exec: - + with ( + patch.object(runner, "_build_podman_command", return_value=["echo", "test"]), + patch("asyncio.create_subprocess_exec") as mock_exec, + ): mock_process = AsyncMock() mock_process.communicate = AsyncMock(return_value=(b"", b"")) mock_process.returncode = 0 @@ -118,8 +125,8 @@ async def 
test_runner_passes_previous_task_keys_in_task_file(self): async def test_implementation_node_passes_implemented_tasks(self): """Implementation node should pass implemented_tasks as previous_task_keys.""" - from forge.orchestrator.nodes import implement_task from forge.workflow.feature.state import FeatureState as WorkflowState + from forge.workflow.nodes.implementation import implement_task with tempfile.TemporaryDirectory() as workspace_dir: state: WorkflowState = { @@ -133,10 +140,11 @@ async def test_implementation_node_passes_implemented_tasks(self): "context": {"guardrails": ""}, } - with patch("forge.workflow.nodes.implementation.JiraClient") as MockJira, \ - patch("forge.workflow.nodes.implementation.ContainerRunner") as MockRunner, \ - patch("forge.workflow.nodes.implementation.get_settings") as mock_settings: - + with ( + patch("forge.workflow.nodes.implementation.JiraClient") as MockJira, + patch("forge.workflow.nodes.implementation.ContainerRunner") as MockRunner, + patch("forge.workflow.nodes.implementation.get_settings") as mock_settings, + ): # Setup mocks mock_jira = MagicMock() mock_jira.get_issue = AsyncMock( @@ -149,9 +157,7 @@ async def test_implementation_node_passes_implemented_tasks(self): MockJira.return_value = mock_jira mock_runner = MagicMock() - mock_runner.run = AsyncMock( - return_value=MagicMock(success=True, exit_code=0) - ) + mock_runner.run = AsyncMock(return_value=MagicMock(success=True, exit_code=0)) MockRunner.return_value = mock_runner mock_settings.return_value = MagicMock() @@ -178,8 +184,9 @@ def test_container_system_prompt_includes_handoff_instructions(self): assert ".forge/history/" in prompt, "Prompt should reference history directory" # Check for handoff writing instructions - assert "Update handoff" in prompt or "update `.forge/handoff.md`" in prompt, \ + assert "Update handoff" in prompt or "update `.forge/handoff.md`" in prompt, ( "Prompt should instruct agent to update handoff" + ) def 
test_entrypoint_builds_prompt_with_previous_task_keys(self): """Entrypoint build_system_prompt should include previous task keys.""" @@ -228,8 +235,9 @@ def test_entrypoint_handles_empty_previous_tasks(self): ) # Should indicate this is the first task - assert "first task" in prompt.lower() or "none" in prompt.lower(), \ + assert "first task" in prompt.lower() or "none" in prompt.lower(), ( "Prompt should indicate no previous tasks" + ) finally: sys.path.remove(str(containers_path)) @@ -301,8 +309,9 @@ def test_container_prompt_includes_gitignore_instructions(self): # Prompt should warn against committing .forge/ (using "NEVER commit" wording) assert ".forge/" in prompt, "Prompt should mention .forge/ directory" - assert "NEVER commit" in prompt or "never commit" in prompt.lower(), \ + assert "NEVER commit" in prompt or "never commit" in prompt.lower(), ( "Prompt should warn against committing .forge/" + ) class TestHistoryPersistence: diff --git a/tests/integration/orchestrator/test_task_implementation_status.py b/tests/integration/orchestrator/test_task_implementation_status.py index 76060b86..b1e7de9a 100644 --- a/tests/integration/orchestrator/test_task_implementation_status.py +++ b/tests/integration/orchestrator/test_task_implementation_status.py @@ -76,7 +76,9 @@ async def test_single_task_receives_start_comment(self): assert mock_jira.add_comment.call_count >= 1 start_call = mock_jira.add_comment.call_args_list[0] assert start_call[0][0] == "TASK-001" - assert start_call[0][1] == "🔨 Forge is implementing this task." 
+ assert ( + start_call[0][1] == "🔨 Forge started implementing [TASK-001]: Task summary for testing" + ) @pytest.mark.asyncio async def test_single_task_receives_completion_comment_on_success(self): @@ -105,12 +107,17 @@ async def test_single_task_receives_completion_comment_on_success(self): # Verify start comment start_call = mock_jira.add_comment.call_args_list[0] assert start_call[0][0] == "TASK-001" - assert start_call[0][1] == "🔨 Forge is implementing this task." + assert ( + start_call[0][1] == "🔨 Forge started implementing [TASK-001]: Task summary for testing" + ) # Verify completion comment with exact text completion_call = mock_jira.add_comment.call_args_list[1] assert completion_call[0][0] == "TASK-001" - assert completion_call[0][1] == "✅ Implementation complete. Running local code review before PR." + assert ( + completion_call[0][1] + == "✅ Implementation complete. Running local code review before PR." + ) # Verify task was marked as implemented assert "TASK-001" in result["implemented_tasks"] @@ -119,7 +126,9 @@ async def test_single_task_receives_completion_comment_on_success(self): async def test_single_task_no_completion_comment_on_failure(self): """TS-003: Verify NO completion comment when task implementation fails.""" mock_jira = create_mock_jira_client() - mock_runner = create_mock_container_runner(success=False, error_message="Implementation error") + mock_runner = create_mock_container_runner( + success=False, error_message="Implementation error" + ) state = create_initial_feature_state( ticket_key="FEAT-100", @@ -141,7 +150,9 @@ async def test_single_task_no_completion_comment_on_failure(self): assert mock_jira.add_comment.call_count == 1 start_call = mock_jira.add_comment.call_args_list[0] assert start_call[0][0] == "TASK-001" - assert start_call[0][1] == "🔨 Forge is implementing this task." 
+ assert ( + start_call[0][1] == "🔨 Forge started implementing [TASK-001]: Task summary for testing" + ) # Verify error state assert result["last_error"] == "Implementation error" @@ -176,7 +187,10 @@ async def test_multiple_tasks_receive_independent_start_comments(self): # Verify first task got start and completion comments with correct task_key assert mock_jira1.add_comment.call_count == 2 assert mock_jira1.add_comment.call_args_list[0][0][0] == "TASK-100" - assert mock_jira1.add_comment.call_args_list[0][0][1] == "🔨 Forge is implementing this task." + assert ( + mock_jira1.add_comment.call_args_list[0][0][1] + == "🔨 Forge started implementing [TASK-100]: Task summary for testing" + ) assert mock_jira1.add_comment.call_args_list[1][0][0] == "TASK-100" # Reset mock for second task @@ -191,12 +205,15 @@ async def test_multiple_tasks_receive_independent_start_comments(self): patch("forge.workflow.nodes.implementation.JiraClient", return_value=mock_jira2), patch("forge.workflow.nodes.implementation.ContainerRunner", return_value=mock_runner2), ): - result2 = await implement_task(state2) + await implement_task(state2) # Verify second task got its own independent start and completion comments assert mock_jira2.add_comment.call_count == 2 assert mock_jira2.add_comment.call_args_list[0][0][0] == "TASK-101" - assert mock_jira2.add_comment.call_args_list[0][0][1] == "🔨 Forge is implementing this task." + assert ( + mock_jira2.add_comment.call_args_list[0][0][1] + == "🔨 Forge started implementing [TASK-101]: Task summary for testing" + ) assert mock_jira2.add_comment.call_args_list[1][0][0] == "TASK-101" @pytest.mark.asyncio @@ -226,8 +243,14 @@ async def test_multiple_tasks_receive_independent_completion_comments(self): call for call in mock_jira1.add_comment.call_args_list if call[0][0] == "TASK-200" ] assert len(task200_calls) == 2 - assert task200_calls[0][0][1] == "🔨 Forge is implementing this task." - assert task200_calls[1][0][1] == "✅ Implementation complete. 
Running local code review before PR." + assert ( + task200_calls[0][0][1] + == "🔨 Forge started implementing [TASK-200]: Task summary for testing" + ) + assert ( + task200_calls[1][0][1] + == "✅ Implementation complete. Running local code review before PR." + ) # Second task mock_jira2 = create_mock_jira_client() @@ -247,8 +270,14 @@ async def test_multiple_tasks_receive_independent_completion_comments(self): call for call in mock_jira2.add_comment.call_args_list if call[0][0] == "TASK-201" ] assert len(task201_calls) == 2 - assert task201_calls[0][0][1] == "🔨 Forge is implementing this task." - assert task201_calls[1][0][1] == "✅ Implementation complete. Running local code review before PR." + assert ( + task201_calls[0][0][1] + == "🔨 Forge started implementing [TASK-201]: Task summary for testing" + ) + assert ( + task201_calls[1][0][1] + == "✅ Implementation complete. Running local code review before PR." + ) # Third task mock_jira3 = create_mock_jira_client() @@ -268,8 +297,14 @@ async def test_multiple_tasks_receive_independent_completion_comments(self): call for call in mock_jira3.add_comment.call_args_list if call[0][0] == "TASK-202" ] assert len(task202_calls) == 2 - assert task202_calls[0][0][1] == "🔨 Forge is implementing this task." - assert task202_calls[1][0][1] == "✅ Implementation complete. Running local code review before PR." + assert ( + task202_calls[0][0][1] + == "🔨 Forge started implementing [TASK-202]: Task summary for testing" + ) + assert ( + task202_calls[1][0][1] + == "✅ Implementation complete. Running local code review before PR." 
+ ) # Verify all three tasks are marked as implemented assert result3["implemented_tasks"] == ["TASK-200", "TASK-201", "TASK-202"] @@ -304,7 +339,10 @@ async def test_task_implementation_fails_midway_no_completion_comment(self): # Verify only start comment, no completion comment assert mock_jira.add_comment.call_count == 1 assert mock_jira.add_comment.call_args_list[0][0][0] == "TASK-300" - assert mock_jira.add_comment.call_args_list[0][0][1] == "🔨 Forge is implementing this task." + assert ( + mock_jira.add_comment.call_args_list[0][0][1] + == "🔨 Forge started implementing [TASK-300]: Task summary for testing" + ) # Verify error is set and task not implemented assert "Container crashed" in result["last_error"] @@ -388,7 +426,8 @@ async def test_workflow_continues_when_start_comment_posting_fails(self, caplog) # Verify error was logged (from jira_status utility) assert any( - "Failed to post status comment to TASK-500" in record.message for record in caplog.records + "Failed to post status comment to TASK-500" in record.message + for record in caplog.records ) @pytest.mark.asyncio @@ -430,7 +469,8 @@ async def add_comment_side_effect(*args, **kwargs): # Verify error was logged assert any( - "Failed to post status comment to TASK-501" in record.message for record in caplog.records + "Failed to post status comment to TASK-501" in record.message + for record in caplog.records ) @pytest.mark.asyncio @@ -462,6 +502,8 @@ async def test_workflow_continues_when_all_comment_posting_fails(self, caplog): # Verify errors were logged for both start and completion attempts error_logs = [ - record for record in caplog.records if "Failed to post status comment to TASK-502" in record.message + record + for record in caplog.records + if "Failed to post status comment to TASK-502" in record.message ] assert len(error_logs) == 2 # Both start and completion comments should have logged errors diff --git a/tests/integration/test_qa_mode.py b/tests/integration/test_qa_mode.py index 
e1e4c64f..673c12ca 100644 --- a/tests/integration/test_qa_mode.py +++ b/tests/integration/test_qa_mode.py @@ -15,8 +15,8 @@ def test_question_comment_classified_correctly(self): """Verify comment classifier detects questions.""" assert classify_comment("?Why REST?") == CommentType.QUESTION assert classify_comment("@forge ask explain") == CommentType.QUESTION - assert classify_comment("Add more detail") == CommentType.FEEDBACK - assert classify_comment("LGTM") == CommentType.FEEDBACK + assert classify_comment("!Add more detail") == CommentType.FEEDBACK + assert classify_comment("LGTM") == CommentType.INFORMATIONAL def test_state_has_qa_fields(self): """Verify initial state includes Q&A fields.""" @@ -49,9 +49,11 @@ async def test_answer_question_node_posts_to_jira(self): mock_agent.answer_question = AsyncMock(return_value="Because of X") mock_agent.close = AsyncMock() - with patch("forge.workflow.nodes.qa_handler.JiraClient", return_value=mock_jira): - with patch("forge.workflow.nodes.qa_handler.ForgeAgent", return_value=mock_agent): - result = await answer_question(state) + with ( + patch("forge.workflow.nodes.qa_handler.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.qa_handler.ForgeAgent", return_value=mock_agent), + ): + result = await answer_question(state) # Verify Jira comment was posted mock_jira.add_comment.assert_called_once() @@ -187,9 +189,11 @@ async def test_answer_question_handles_agent_error(self): mock_agent.answer_question = AsyncMock(side_effect=Exception("API Error")) mock_agent.close = AsyncMock() - with patch("forge.workflow.nodes.qa_handler.JiraClient", return_value=mock_jira): - with patch("forge.workflow.nodes.qa_handler.ForgeAgent", return_value=mock_agent): - result = await answer_question(state) + with ( + patch("forge.workflow.nodes.qa_handler.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.qa_handler.ForgeAgent", return_value=mock_agent), + ): + result = await answer_question(state) # Should still 
clear question state and stay paused assert result["is_paused"] is True diff --git a/tests/integration/test_stats_commands.py b/tests/integration/test_stats_commands.py new file mode 100644 index 00000000..7f9f6556 --- /dev/null +++ b/tests/integration/test_stats_commands.py @@ -0,0 +1,1039 @@ +"""Integration tests for on-demand stats commands. + +These tests verify the end-to-end behavior of: +- /forge stats — Jira comment command (post current stats as a new comment) +- /forge stats retry — Jira comment command (re-post stats as final comment) +- forge stats — CLI command (table and JSON output) + +Each test scenario uses pytest fixtures that provide realistic mock checkpoint +state, then exercises the full command path from trigger to Jira comment +or stdout — mocking only the network boundary (JiraClient, checkpointer). +""" + +import argparse +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.models.events import EventSource +from forge.orchestrator.worker import OrchestratorWorker +from forge.queue.models import QueueMessage + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def _make_jira_message(ticket_key: str, comment_body: str) -> QueueMessage: + """Build a minimal Jira comment QueueMessage.""" + return QueueMessage( + message_id="9999999999-0", + event_id="integ-test-event-001", + source=EventSource.JIRA, + event_type="comment_created", + ticket_key=ticket_key, + payload={ + "issue": { + "key": ticket_key, + "fields": { + "issuetype": {"name": "Feature"}, + "labels": [], + }, + }, + "comment": {"body": comment_body}, + "changelog": {"items": []}, + }, + ) + + +def _make_mock_jira() -> MagicMock: + """Return a mock JiraClient with relevant async methods.""" + jira = MagicMock() + jira.add_comment = AsyncMock() + jira.close = AsyncMock() + jira.get_comments = 
AsyncMock(return_value=[]) + return jira + + +# --------------------------------------------------------------------------- +# Checkpoint fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def checkpoint_with_stats() -> dict: + """Checkpoint state containing populated stats data (PRD + Spec stages).""" + return { + "ticket_key": "INT-100", + "ticket_type": "Feature", + "current_node": "spec_approval_gate", + "is_paused": True, + "is_blocked": False, + "last_error": None, + "feedback_comment": None, + "context": {}, + "stage_timestamps": { + "prd": { + "stage_name": "prd", + "iteration_count": 2, + "machine_time_seconds": 45.0, + "human_time_seconds": 300.0, + "input_tokens": 1200, + "output_tokens": 2000, + "started_at": "2024-01-15T10:00:00+00:00", + "ended_at": "2024-01-15T10:00:45+00:00", + }, + "spec": { + "stage_name": "spec", + "iteration_count": 1, + "machine_time_seconds": 30.0, + "human_time_seconds": 180.0, + "input_tokens": 800, + "output_tokens": 1500, + "started_at": "2024-01-15T10:05:00+00:00", + "ended_at": "2024-01-15T10:05:30+00:00", + }, + }, + "stats_pr_urls": ["https://github.com/org/repo/pull/42"], + "stats_ci_cycles": 1, + "workflow_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, + "workflow_run_id": "test-run-abc123", + } + + +@pytest.fixture +def checkpoint_without_stats_key() -> dict: + """Checkpoint state that has no stage_timestamps key (legacy workflow).""" + return { + "ticket_key": "INT-101", + "ticket_type": "Feature", + "current_node": "prd_approval_gate", + "is_paused": True, + "context": {}, + # Deliberately no stats_* keys — simulates pre-stats-tracking run + } + + +@pytest.fixture +def checkpoint_with_empty_stages() -> dict: + """Checkpoint state with stage_timestamps present but empty (workflow just started).""" + return { + "ticket_key": "INT-102", + "ticket_type": "Feature", + "current_node": "generate_prd", + "is_paused": False, + 
"is_blocked": False, + "last_error": None, + "context": {}, + "stage_timestamps": {}, # Present key, empty dict — in-progress workflow + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "workflow_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, + "workflow_run_id": "test-run-def456", + } + + +@pytest.fixture +def checkpoint_blocked() -> dict: + """Checkpoint state representing a blocked workflow.""" + return { + "ticket_key": "INT-103", + "ticket_type": "Feature", + "current_node": "escalate_blocked", + "is_paused": True, + "is_blocked": True, + "last_error": None, + "feedback_comment": "Requirements unclear — needs stakeholder input.", + "context": {}, + "stage_timestamps": { + "prd": { + "stage_name": "prd", + "iteration_count": 3, + "machine_time_seconds": 120.0, + "human_time_seconds": 600.0, + "input_tokens": 3000, + "output_tokens": 4000, + } + }, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "workflow_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, + "workflow_run_id": "test-run-ghi789", + } + + +@pytest.fixture +def checkpoint_failed() -> dict: + """Checkpoint state representing a failed workflow.""" + return { + "ticket_key": "INT-104", + "ticket_type": "Feature", + "current_node": "generate_spec", + "is_paused": False, + "is_blocked": False, + "last_error": "LLM call timed out after 60 seconds", + "feedback_comment": None, + "context": {}, + "stage_timestamps": { + "prd": { + "stage_name": "prd", + "iteration_count": 1, + "machine_time_seconds": 60.0, + "human_time_seconds": 0.0, + "input_tokens": 1000, + "output_tokens": 1800, + } + }, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "workflow_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, + "workflow_run_id": "test-run-jkl012", + } + + +@pytest.fixture +def checkpoint_completed() -> dict: + """Checkpoint state for a fully completed workflow.""" + return { + "ticket_key": "INT-105", + "ticket_type": 
"Feature", + "current_node": "aggregate_feature_status", + "is_paused": False, + "is_blocked": False, + "last_error": None, + "feedback_comment": None, + "context": {}, + "stage_timestamps": { + "prd": { + "stage_name": "prd", + "iteration_count": 1, + "machine_time_seconds": 40.0, + "human_time_seconds": 200.0, + "input_tokens": 1000, + "output_tokens": 1800, + }, + "spec": { + "stage_name": "spec", + "iteration_count": 1, + "machine_time_seconds": 30.0, + "human_time_seconds": 150.0, + "input_tokens": 900, + "output_tokens": 1600, + }, + "implementation": { + "stage_name": "implementation", + "iteration_count": 2, + "machine_time_seconds": 900.0, + "human_time_seconds": 0.0, + "input_tokens": 8000, + "output_tokens": 12000, + }, + }, + "stats_pr_urls": [ + "https://github.com/org/repo/pull/99", + ], + "stats_ci_cycles": 2, + "workflow_outcome": "Completed", + "stats_outcome_reason": None, + "stats_comment_posted": True, + "workflow_run_id": "test-run-mno345", + } + + +@pytest.fixture +def worker() -> OrchestratorWorker: + """OrchestratorWorker with a unique consumer name for isolation.""" + return OrchestratorWorker(consumer_name="integ-test-worker") + + +# --------------------------------------------------------------------------- +# Section 1: /forge stats — Jira comment command +# --------------------------------------------------------------------------- + + +class TestForgeStatsWithValidCheckpoint: + """/forge stats posts a formatted stats comment when checkpoint has data.""" + + @pytest.mark.asyncio + async def test_stats_comment_is_posted_to_jira( + self, worker: OrchestratorWorker, checkpoint_with_stats + ): + """/forge stats results in a call to JiraClient.add_comment.""" + message = _make_jira_message("INT-100", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, checkpoint_with_stats) + + 
mock_jira.add_comment.assert_awaited_once() + assert result is checkpoint_with_stats, "State must be returned unchanged" + + @pytest.mark.asyncio + async def test_stats_comment_posted_to_correct_ticket( + self, worker: OrchestratorWorker, checkpoint_with_stats + ): + """/forge stats posts the comment to the correct Jira ticket key.""" + message = _make_jira_message("INT-100", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, checkpoint_with_stats) + + call_args = mock_jira.add_comment.call_args + ticket_arg = call_args[0][0] + assert ticket_arg == "INT-100" + + @pytest.mark.asyncio + async def test_stats_comment_body_contains_stage_metrics( + self, worker: OrchestratorWorker, checkpoint_with_stats + ): + """Comment body includes stage-level metrics (PRD iterations visible).""" + message = _make_jira_message("INT-100", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, checkpoint_with_stats) + + comment_body = mock_jira.add_comment.call_args[0][1] + # The Jira formatter produces a table; stage names appear as rows + assert "PRD" in comment_body or "prd" in comment_body + + @pytest.mark.asyncio + async def test_stats_comment_body_contains_outcome( + self, worker: OrchestratorWorker, checkpoint_with_stats + ): + """Comment body includes the derived outcome string.""" + message = _make_jira_message("INT-100", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, checkpoint_with_stats) + + comment_body = mock_jira.add_comment.call_args[0][1] + # Outcome for an in-progress workflow is "In Progress" + assert "In Progress" in comment_body or "Outcome" in comment_body + + @pytest.mark.asyncio + async def 
test_jira_client_closed_after_posting( + self, worker: OrchestratorWorker, checkpoint_with_stats + ): + """JiraClient.close() is always called even on success.""" + message = _make_jira_message("INT-100", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, checkpoint_with_stats) + + mock_jira.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_workflow_state_returned_unchanged( + self, worker: OrchestratorWorker, checkpoint_with_stats + ): + """/forge stats is read-only — returned state is the same object.""" + message = _make_jira_message("INT-100", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, checkpoint_with_stats) + + assert result is checkpoint_with_stats + + @pytest.mark.asyncio + async def test_stats_derived_outcome_in_progress( + self, worker: OrchestratorWorker, checkpoint_with_stats + ): + """In-progress workflow (no outcome/blocked/error) → 'In Progress' outcome.""" + # Ensure no pre-set outcome, no blocked, no error + assert checkpoint_with_stats.get("workflow_outcome") is None + assert not checkpoint_with_stats.get("is_blocked") + assert checkpoint_with_stats.get("last_error") is None + + message = _make_jira_message("INT-100", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, checkpoint_with_stats) + + comment_body = mock_jira.add_comment.call_args[0][1] + assert "In Progress" in comment_body + + +class TestForgeStatsWithBlockedWorkflow: + """/forge stats correctly reports a blocked workflow outcome.""" + + @pytest.mark.asyncio + async def test_blocked_outcome_reported(self, worker: OrchestratorWorker, checkpoint_blocked): + """Comment body contains 'Blocked' 
when workflow is_blocked=True.""" + message = _make_jira_message("INT-103", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, checkpoint_blocked) + + comment_body = mock_jira.add_comment.call_args[0][1] + assert "Blocked" in comment_body + + @pytest.mark.asyncio + async def test_blocked_comment_posted_to_correct_ticket( + self, worker: OrchestratorWorker, checkpoint_blocked + ): + """Stats for blocked workflow are posted to the blocked ticket key.""" + message = _make_jira_message("INT-103", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, checkpoint_blocked) + + ticket_arg = mock_jira.add_comment.call_args[0][0] + assert ticket_arg == "INT-103" + + +class TestForgeStatsWithFailedWorkflow: + """/forge stats correctly reports a failed workflow outcome.""" + + @pytest.mark.asyncio + async def test_failed_outcome_reported(self, worker: OrchestratorWorker, checkpoint_failed): + """Comment body contains 'Failed' when workflow has last_error.""" + message = _make_jira_message("INT-104", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, checkpoint_failed) + + comment_body = mock_jira.add_comment.call_args[0][1] + assert "Failed" in comment_body + + @pytest.mark.asyncio + async def test_failed_comment_posted_once(self, worker: OrchestratorWorker, checkpoint_failed): + """Exactly one comment is posted for a failed workflow stats request.""" + message = _make_jira_message("INT-104", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, checkpoint_failed) + + assert 
mock_jira.add_comment.call_count == 1 + + +# --------------------------------------------------------------------------- +# Section 2: /forge stats with missing checkpoint +# --------------------------------------------------------------------------- + + +class TestForgeStatsWithMissingCheckpoint: + """/forge stats posts a fallback message when no stats data exists.""" + + @pytest.mark.asyncio + async def test_missing_stage_timestamps_key_posts_no_data_message( + self, worker: OrchestratorWorker, checkpoint_without_stats_key + ): + """When stage_timestamps key is absent, posts 'No workflow data found.'.""" + message = _make_jira_message("INT-101", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, checkpoint_without_stats_key) + + mock_jira.add_comment.assert_awaited_once() + comment_body = mock_jira.add_comment.call_args[0][1] + assert "No workflow data found" in comment_body + assert result is checkpoint_without_stats_key + + @pytest.mark.asyncio + async def test_missing_data_comment_posted_to_correct_ticket( + self, worker: OrchestratorWorker, checkpoint_without_stats_key + ): + """Fallback message is posted to the correct ticket key.""" + message = _make_jira_message("INT-101", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, checkpoint_without_stats_key) + + ticket_arg = mock_jira.add_comment.call_args[0][0] + assert ticket_arg == "INT-101" + + @pytest.mark.asyncio + async def test_empty_stages_dict_does_not_trigger_fallback( + self, worker: OrchestratorWorker, checkpoint_with_empty_stages + ): + """Empty stage_timestamps dict (key present) uses formatter, not fallback.""" + message = _make_jira_message("INT-102", "/forge stats") + mock_jira = _make_mock_jira() + + with 
patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, checkpoint_with_empty_stages) + + # Should post a formatted comment (not "No workflow data found.") + mock_jira.add_comment.assert_awaited_once() + comment_body = mock_jira.add_comment.call_args[0][1] + assert "No workflow data found" not in comment_body + + @pytest.mark.asyncio + async def test_state_returned_unchanged_when_no_stats( + self, worker: OrchestratorWorker, checkpoint_without_stats_key + ): + """State identity is preserved even when no stats data is found.""" + message = _make_jira_message("INT-101", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, checkpoint_without_stats_key) + + assert result is checkpoint_without_stats_key + + +# --------------------------------------------------------------------------- +# Section 3: /forge stats retry +# --------------------------------------------------------------------------- + + +class TestForgeStatsRetry: + """/forge stats retry re-posts stats via ensure_stats_is_final_comment.""" + + @pytest.mark.asyncio + async def test_retry_calls_ensure_stats_is_final_comment( + self, worker: OrchestratorWorker, checkpoint_with_stats + ): + """/forge stats retry delegates to ensure_stats_is_final_comment, not add_comment.""" + message = _make_jira_message("INT-100", "/forge stats retry") + + with patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + ) as mock_ensure: + result = await worker._handle_resume_event(message, checkpoint_with_stats) + + mock_ensure.assert_awaited_once() + assert result is checkpoint_with_stats + + @pytest.mark.asyncio + async def test_retry_does_not_call_add_comment_directly( + self, worker: OrchestratorWorker, checkpoint_with_stats + ): + """/forge stats retry must not call JiraClient.add_comment for normal 
re-post.""" + message = _make_jira_message("INT-100", "/forge stats retry") + mock_jira = _make_mock_jira() + + with ( + patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira), + patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + ), + ): + await worker._handle_resume_event(message, checkpoint_with_stats) + + # add_comment should NOT be called by the retry path (it's used by the base path) + mock_jira.add_comment.assert_not_awaited() + + @pytest.mark.asyncio + async def test_retry_passes_correct_ticket_key( + self, worker: OrchestratorWorker, checkpoint_with_stats + ): + """/forge stats retry passes the correct ticket key to ensure_stats_is_final_comment.""" + message = _make_jira_message("INT-100", "/forge stats retry") + + with patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + ) as mock_ensure: + await worker._handle_resume_event(message, checkpoint_with_stats) + + call_args = mock_ensure.call_args + ticket_arg = call_args[0][0] + assert ticket_arg == "INT-100" + + @pytest.mark.asyncio + async def test_retry_state_unchanged(self, worker: OrchestratorWorker, checkpoint_with_stats): + """/forge stats retry returns the same state object unchanged.""" + message = _make_jira_message("INT-100", "/forge stats retry") + + with patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + ): + result = await worker._handle_resume_event(message, checkpoint_with_stats) + + assert result is checkpoint_with_stats + + @pytest.mark.asyncio + async def test_retry_with_missing_stats_posts_no_data_message( + self, worker: OrchestratorWorker, checkpoint_without_stats_key + ): + """/forge stats retry posts 'No workflow data found.' 
when no stats data.""" + message = _make_jira_message("INT-101", "/forge stats retry") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, checkpoint_without_stats_key) + + mock_jira.add_comment.assert_awaited_once() + comment_body = mock_jira.add_comment.call_args[0][1] + assert "No workflow data found" in comment_body + assert result is checkpoint_without_stats_key + + @pytest.mark.asyncio + async def test_retry_ensure_failure_does_not_raise( + self, worker: OrchestratorWorker, checkpoint_with_stats + ): + """/forge stats retry does not propagate exceptions from ensure_stats_is_final_comment.""" + message = _make_jira_message("INT-100", "/forge stats retry") + + with patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + side_effect=RuntimeError("Network error"), + ): + # Should not raise + result = await worker._handle_resume_event(message, checkpoint_with_stats) + + assert result is checkpoint_with_stats + + +# --------------------------------------------------------------------------- +# Section 4: forge stats CLI — table output +# --------------------------------------------------------------------------- + + +class TestCLIStatsTableOutput: + """forge stats displays a human-readable table.""" + + @pytest.mark.asyncio + async def test_table_output_exits_zero_on_success(self, checkpoint_with_stats): + """forge stats returns exit code 0 when checkpoint has stats.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-100", json=False) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_with_stats), + ): + exit_code = await cmd_stats(args) + + assert exit_code == 0 + + @pytest.mark.asyncio + async def test_table_output_contains_stage_labels(self, checkpoint_with_stats, capsys): + """Table output includes stage labels (PRD, 
Spec) for populated stages.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-100", json=False) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_with_stats), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + assert "PRD" in captured.out + + @pytest.mark.asyncio + async def test_table_output_contains_outcome(self, checkpoint_with_stats, capsys): + """Table output contains an Outcome line.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-100", json=False) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_with_stats), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + assert "Outcome" in captured.out or "In Progress" in captured.out + + @pytest.mark.asyncio + async def test_table_output_is_not_json(self, checkpoint_with_stats, capsys): + """Without --json flag, output is human-readable text, not JSON.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-100", json=False) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_with_stats), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + try: + json.loads(captured.out) + is_json = True + except (json.JSONDecodeError, ValueError): + is_json = False + assert not is_json + + @pytest.mark.asyncio + async def test_table_output_missing_checkpoint_exits_one(self, capsys): + """forge stats exits 1 when no checkpoint is found.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-999", json=False) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=None), + ): + exit_code = await cmd_stats(args) + + assert exit_code == 1 + captured = capsys.readouterr() + assert "No workflow data found" in captured.out + + @pytest.mark.asyncio + async def 
test_table_output_missing_stats_key_exits_one( + self, checkpoint_without_stats_key, capsys + ): + """forge stats exits 1 when checkpoint lacks stage_timestamps key.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-101", json=False) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_without_stats_key), + ): + exit_code = await cmd_stats(args) + + assert exit_code == 1 + captured = capsys.readouterr() + assert "No workflow data found" in captured.out + + @pytest.mark.asyncio + async def test_table_output_empty_stages_exits_zero(self, checkpoint_with_empty_stages): + """forge stats exits 0 for an in-progress workflow with no stages recorded yet.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-102", json=False) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_with_empty_stages), + ): + exit_code = await cmd_stats(args) + + assert exit_code == 0 + + @pytest.mark.asyncio + async def test_table_output_connection_error_exits_one(self, capsys): + """forge stats exits 1 when checkpointer raises a connection error.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-100", json=False) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(side_effect=ConnectionError("Redis unavailable")), + ): + exit_code = await cmd_stats(args) + + assert exit_code == 1 + captured = capsys.readouterr() + assert "Error" in captured.err + + +# --------------------------------------------------------------------------- +# Section 5: forge stats CLI — JSON output +# --------------------------------------------------------------------------- + + +class TestCLIStatsJsonOutput: + """forge stats --json outputs structured JSON.""" + + @pytest.mark.asyncio + async def test_json_output_is_valid_json(self, checkpoint_with_stats, capsys): + """--json flag 
produces parseable JSON.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-100", json=True) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_with_stats), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + data = json.loads(captured.out) # Should not raise + assert isinstance(data, dict) + + @pytest.mark.asyncio + async def test_json_output_contains_required_fields(self, checkpoint_with_stats, capsys): + """JSON output includes ticket, outcome, ci_cycles, pr_urls, and stages fields.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-100", json=True) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_with_stats), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert "ticket" in data + assert "outcome" in data + assert "ci_cycles" in data + assert "pr_urls" in data + assert "stages" in data + + @pytest.mark.asyncio + async def test_json_output_ticket_matches_requested(self, checkpoint_with_stats, capsys): + """JSON ticket field matches the requested ticket key.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-100", json=True) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_with_stats), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["ticket"] == "INT-100" + + @pytest.mark.asyncio + async def test_json_output_stages_contains_prd_data(self, checkpoint_with_stats, capsys): + """JSON stages dict includes the prd stage from checkpoint.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-100", json=True) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_with_stats), + ): + await cmd_stats(args) + + data = 
json.loads(capsys.readouterr().out) + assert "prd" in data["stages"] + + @pytest.mark.asyncio + async def test_json_output_ci_cycles_value(self, checkpoint_with_stats, capsys): + """JSON ci_cycles matches the value stored in checkpoint.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-100", json=True) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_with_stats), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["ci_cycles"] == checkpoint_with_stats["stats_ci_cycles"] + + @pytest.mark.asyncio + async def test_json_output_pr_urls_present(self, checkpoint_with_stats, capsys): + """JSON pr_urls list matches checkpoint data.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-100", json=True) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_with_stats), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["pr_urls"] == checkpoint_with_stats["stats_pr_urls"] + + @pytest.mark.asyncio + async def test_json_output_exits_zero_on_success(self, checkpoint_with_stats): + """--json flag returns exit code 0 on success.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-100", json=True) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_with_stats), + ): + exit_code = await cmd_stats(args) + + assert exit_code == 0 + + @pytest.mark.asyncio + async def test_json_output_missing_checkpoint_exits_one(self): + """--json flag still exits 1 when no checkpoint is found.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-999", json=True) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=None), + ): + exit_code = await cmd_stats(args) + + assert exit_code == 1 + 
+ +# --------------------------------------------------------------------------- +# Section 6: Partial / failed / blocked workflow stats +# --------------------------------------------------------------------------- + + +class TestPartialAndSpecialOutcomes: + """Stats commands handle partial, failed, and blocked workflow states correctly.""" + + @pytest.mark.asyncio + async def test_jira_stats_completed_workflow_shows_completed_outcome( + self, worker: OrchestratorWorker, checkpoint_completed + ): + """Pre-set workflow_outcome='Completed' is forwarded directly to comment.""" + message = _make_jira_message("INT-105", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, checkpoint_completed) + + comment_body = mock_jira.add_comment.call_args[0][1] + assert "Completed" in comment_body + + @pytest.mark.asyncio + async def test_cli_blocked_workflow_outcome_in_json(self, checkpoint_blocked, capsys): + """CLI --json output for blocked workflow includes 'Blocked' outcome.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-103", json=True) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_blocked), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["outcome"] == "Blocked" + + @pytest.mark.asyncio + async def test_cli_failed_workflow_outcome_in_json(self, checkpoint_failed, capsys): + """CLI --json output for failed workflow includes 'Failed' outcome.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-104", json=True) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_failed), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["outcome"] == "Failed" + + @pytest.mark.asyncio + async def 
test_cli_in_progress_workflow_outcome_in_json( + self, checkpoint_with_empty_stages, capsys + ): + """CLI --json output for in-progress workflow includes 'In Progress' outcome.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-102", json=True) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_with_empty_stages), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["outcome"] == "In Progress" + + @pytest.mark.asyncio + async def test_cli_completed_workflow_outcome_in_json(self, checkpoint_completed, capsys): + """CLI --json output for completed workflow includes 'Completed' outcome.""" + from forge.cli import cmd_stats + + args = argparse.Namespace(ticket="INT-105", json=True) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=checkpoint_completed), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["outcome"] == "Completed" + + @pytest.mark.asyncio + async def test_jira_stats_partial_workflow_shows_prd_stage_only( + self, worker: OrchestratorWorker, checkpoint_with_stats + ): + """Stats for a workflow that has only completed PRD shows only PRD metrics.""" + # Remove spec stage to simulate partial run (only PRD completed) + partial_state = { + **checkpoint_with_stats, + "stage_timestamps": { + "prd": checkpoint_with_stats["stage_timestamps"]["prd"], + }, + } + + message = _make_jira_message("INT-100", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, partial_state) + + comment_body = mock_jira.add_comment.call_args[0][1] + # PRD metrics should appear; spec should show dash/empty + assert "PRD" in comment_body or "prd" in comment_body + + @pytest.mark.asyncio + async def 
test_cli_partial_workflow_json_contains_only_recorded_stages( + self, checkpoint_with_stats, capsys + ): + """CLI JSON for partial workflow only includes recorded stages.""" + from forge.cli import cmd_stats + + # Use just the PRD stage + partial_state = { + **checkpoint_with_stats, + "stage_timestamps": { + "prd": checkpoint_with_stats["stage_timestamps"]["prd"], + }, + } + args = argparse.Namespace(ticket="INT-100", json=True) + + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=partial_state), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert "prd" in data["stages"] + assert "spec" not in data["stages"] + + @pytest.mark.asyncio + async def test_jira_stats_multiple_pr_urls_in_comment( + self, worker: OrchestratorWorker, checkpoint_completed + ): + """Stats comment for completed workflow includes PR URLs section.""" + message = _make_jira_message("INT-105", "/forge stats") + mock_jira = _make_mock_jira() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, checkpoint_completed) + + comment_body = mock_jira.add_comment.call_args[0][1] + # The formatter includes PR URLs when they are present + assert "github.com" in comment_body or "pull" in comment_body.lower() diff --git a/tests/integration/test_weekly_report.py b/tests/integration/test_weekly_report.py new file mode 100644 index 00000000..180ff7d5 --- /dev/null +++ b/tests/integration/test_weekly_report.py @@ -0,0 +1,1628 @@ +"""Integration tests for the weekly reporting system. 
+ +These tests verify end-to-end flows for the weekly reporting system including: +- Data aggregation from Redis checkpoints (collect_weekly_data) +- Date-range and project filtering +- Per-feature rollup grouping +- CLI output: text, JSON, and file export +- Jira ticket creation and idempotent updates +- Notification delivery + +Redis and Jira network calls are mocked to avoid external dependencies. +""" + +from __future__ import annotations + +import argparse +import json +import tempfile +from datetime import UTC, datetime, timedelta +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.integrations.jira.models import JiraIssue +from forge.workflow.stats.weekly_report import ( + UNASSIGNED_FEATURE_KEY, + TicketSummary, + WeeklyReportData, + collect_weekly_data, +) + +# --------------------------------------------------------------------------- +# Shared constants — computed at import time so timestamps are always recent +# --------------------------------------------------------------------------- + +_NOW = datetime.now(UTC) +_ONE_DAY_AGO = (_NOW - timedelta(days=1)).isoformat() +_THREE_DAYS_AGO = (_NOW - timedelta(days=3)).isoformat() +_TEN_DAYS_AGO = (_NOW - timedelta(days=10)).isoformat() + + +# --------------------------------------------------------------------------- +# Fixture: mock_workflow_checkpoints +# --------------------------------------------------------------------------- + + +def _make_stage( + stage_name: str = "prd", + *, + iteration_count: int = 1, + machine_time_seconds: float = 60.0, + human_time_seconds: float = 0.0, + input_tokens: int = 500, + output_tokens: int = 250, + started_at: str | None = None, + ended_at: str | None = None, +) -> dict: + """Build a single stage stats dict with sensible defaults.""" + return { + "stage_name": stage_name, + "iteration_count": iteration_count, + "machine_time_seconds": machine_time_seconds, + "human_time_seconds": human_time_seconds, + 
"input_tokens": input_tokens, + "output_tokens": output_tokens, + "started_at": started_at or _ONE_DAY_AGO, + "ended_at": ended_at, + } + + +def _make_checkpoint( + ticket_key: str = "PROJ-1", + *, + ticket_type: str = "Feature", + workflow_outcome: str | None = "Completed", + is_blocked: bool = False, + stats_ci_cycles: int = 0, + updated_at: str | None = None, + stage_timestamps: dict | None = None, + **extra: object, +) -> dict: + """Build a minimal checkpoint state dict that weekly_report can parse.""" + if stage_timestamps is None: + stage_timestamps = { + "prd": _make_stage( + "prd", + started_at=_ONE_DAY_AGO, + ended_at=_ONE_DAY_AGO, + ) + } + return { + "ticket_key": ticket_key, + "ticket_type": ticket_type, + "workflow_outcome": workflow_outcome, + "is_blocked": is_blocked, + "stage_timestamps": stage_timestamps, + "stats_ci_cycles": stats_ci_cycles, + "updated_at": updated_at or _ONE_DAY_AGO, + **extra, + } + + +@pytest.fixture +def mock_workflow_checkpoints() -> dict[str, dict]: + """Factory: a dict of ticket_key to checkpoint state for PROJ-* tickets. 
+ + Contains: + - PROJ-1: completed Feature, PRD + Spec stages, 1 CI cycle + - PROJ-2: in-progress Feature, PRD stage only + - PROJ-3: blocked Feature, PRD stage, is_blocked=True + """ + return { + "PROJ-1": _make_checkpoint( + ticket_key="PROJ-1", + ticket_type="Feature", + workflow_outcome="Completed", + stats_ci_cycles=1, + stage_timestamps={ + "prd": _make_stage( + "prd", + iteration_count=2, + machine_time_seconds=45.0, + input_tokens=1200, + output_tokens=2000, + started_at=_ONE_DAY_AGO, + ended_at=_ONE_DAY_AGO, + ), + "spec": _make_stage( + "spec", + iteration_count=1, + machine_time_seconds=30.0, + input_tokens=800, + output_tokens=1500, + started_at=_ONE_DAY_AGO, + ended_at=_ONE_DAY_AGO, + ), + }, + ), + "PROJ-2": _make_checkpoint( + ticket_key="PROJ-2", + ticket_type="Feature", + workflow_outcome=None, + stage_timestamps={ + "prd": _make_stage( + "prd", + iteration_count=1, + machine_time_seconds=60.0, + input_tokens=700, + output_tokens=900, + started_at=_ONE_DAY_AGO, + ended_at=None, # Still running + ) + }, + ), + "PROJ-3": _make_checkpoint( + ticket_key="PROJ-3", + ticket_type="Feature", + workflow_outcome=None, + is_blocked=True, + stage_timestamps={ + "prd": _make_stage( + "prd", + iteration_count=3, + machine_time_seconds=120.0, + input_tokens=3000, + output_tokens=4000, + started_at=_ONE_DAY_AGO, + ended_at=_ONE_DAY_AGO, + ) + }, + ), + } + + +# --------------------------------------------------------------------------- +# Fixture: mock_jira_responses +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_jira_responses() -> MagicMock: + """Mock JiraClient with pre-configured responses for weekly report operations.""" + jira = MagicMock() + jira.close = AsyncMock() + jira.get_issue = AsyncMock() + jira.search_issues = AsyncMock(return_value=[]) + jira.create_task = AsyncMock(return_value="PROJ-99") + jira.update_description = AsyncMock() + jira.add_comment = AsyncMock() + 
jira.get_project_property = AsyncMock(return_value=None) + return jira + + +@pytest.fixture(autouse=True) +def _patch_get_checkpoint_state(): + async def mock_get_state(ticket_key: str): + from forge.workflow.stats.weekly_report import get_redis_client + + try: + redis_client = await get_redis_client() + key = f"checkpoint:{ticket_key}" + val = await redis_client.get(key) + if val is not None: + import json + + return json.loads(val) + except Exception: + pass + return None + + with patch( + "forge.workflow.stats.weekly_report.get_checkpoint_state", side_effect=mock_get_state + ): + yield + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_redis_mock(checkpoints: dict[str, dict]) -> MagicMock: + """Build a mock Redis client returning checkpoints keyed by Redis pattern. + + The checkpoint key format is ``checkpoint:{ticket_key}``. + The ``scan`` mock is pattern-aware so only matching keys are returned. 
+ """ + redis = MagicMock() + + key_map: dict[str, str] = { + f"checkpoint:{ticket_key}": json.dumps(state) for ticket_key, state in checkpoints.items() + } + + async def _scan(cursor: int, match: str, count: int) -> tuple[int, list[str]]: + if cursor == 0: + prefix = match.rstrip("*") + filtered = [k for k in key_map if k.startswith(prefix)] + return (0, filtered) + return (0, []) + + redis.scan = AsyncMock(side_effect=_scan) + + async def _get(key: str) -> str | None: + return key_map.get(key) + + redis.get = AsyncMock(side_effect=_get) + return redis + + +def _make_jira_issue( + key: str, + issue_type: str = "Task", + summary: str = "", + parent_key: str | None = None, +) -> JiraIssue: + """Build a minimal JiraIssue for testing hierarchy resolution.""" + return JiraIssue( + key=key, + id="1", + summary=summary or f"Summary of {key}", + description="", + status="In Progress", + issue_type=issue_type, + parent_key=parent_key, + ) + + +def _make_cli_args( + project: str = "PROJ", + days: int = 7, + output: str | None = None, + fmt: str = "text", + create_ticket: bool = False, + notify: bool = False, +) -> argparse.Namespace: + """Create a minimal argparse.Namespace for cmd_weekly_report.""" + return argparse.Namespace( + project=project, + days=days, + output=output, + format=fmt, + create_ticket=create_ticket, + notify=notify, + ) + + +def _make_report( + project: str = "PROJ", + *, + completed: list[TicketSummary] | None = None, + in_progress: list[TicketSummary] | None = None, + blocked: list[TicketSummary] | None = None, +) -> WeeklyReportData: + """Build a WeeklyReportData for CLI testing.""" + if completed is None: + completed = [ + TicketSummary( + ticket_key=f"{project}-1", + status="completed", + duration_seconds=3600.0, + input_tokens=1000, + output_tokens=500, + ) + ] + ip = in_progress or [] + bl = blocked or [] + return WeeklyReportData( + project=project, + period_days=7, + report_start=_THREE_DAYS_AGO, + report_end=_ONE_DAY_AGO, + 
completed_tickets=completed, + in_progress_tickets=ip, + blocked_tickets=bl, + total_input_tokens=sum(t.input_tokens for t in completed + ip + bl), + total_output_tokens=sum(t.output_tokens for t in completed + ip + bl), + all_tickets=list(completed) + list(ip) + list(bl), + ) + + +# --------------------------------------------------------------------------- +# Section 1: test_collect_weekly_data_with_multiple_workflows +# --------------------------------------------------------------------------- + + +class TestCollectWeeklyDataWithMultipleWorkflows: + """Verifies data aggregation from multiple checkpoints.""" + + @pytest.mark.asyncio + async def test_all_tickets_collected(self, mock_workflow_checkpoints): + """All checkpoints within the window are included in all_tickets.""" + redis = _build_redis_mock(mock_workflow_checkpoints) + jira = MagicMock() + jira.close = AsyncMock() + jira.get_issue = AsyncMock(side_effect=Exception("hierarchy not needed")) + + with ( + patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=redis), + ), + patch( + "forge.workflow.stats.weekly_report.JiraClient", + return_value=jira, + ), + ): + report = await collect_weekly_data("PROJ", days=7) + + assert len(report.all_tickets) == 3 + + @pytest.mark.asyncio + async def test_completed_tickets_categorised(self, mock_workflow_checkpoints): + """Completed tickets go into the completed_tickets list.""" + redis = _build_redis_mock(mock_workflow_checkpoints) + jira = MagicMock() + jira.close = AsyncMock() + jira.get_issue = AsyncMock(side_effect=Exception("hierarchy not needed")) + + with ( + patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=redis), + ), + patch( + "forge.workflow.stats.weekly_report.JiraClient", + return_value=jira, + ), + ): + report = await collect_weekly_data("PROJ", days=7) + + assert len(report.completed_tickets) == 1 + assert report.completed_tickets[0].ticket_key == "PROJ-1" + + 
@pytest.mark.asyncio
async def test_in_progress_tickets_categorised(self, mock_workflow_checkpoints):
    """Exactly one ticket, PROJ-2, is reported in ``in_progress_tickets``."""
    fake_redis = _build_redis_mock(mock_workflow_checkpoints)

    # Every hierarchy lookup raises; this test only inspects categorisation.
    fake_jira = MagicMock()
    fake_jira.close = AsyncMock()
    fake_jira.get_issue = AsyncMock(side_effect=Exception("hierarchy not needed"))

    redis_patch = patch(
        "forge.workflow.stats.weekly_report.get_redis_client",
        new=AsyncMock(return_value=fake_redis),
    )
    jira_patch = patch(
        "forge.workflow.stats.weekly_report.JiraClient",
        return_value=fake_jira,
    )
    with redis_patch, jira_patch:
        report = await collect_weekly_data("PROJ", days=7)

    assert len(report.in_progress_tickets) == 1
    first_in_progress = report.in_progress_tickets[0]
    assert first_in_progress.ticket_key == "PROJ-2"
@pytest.mark.asyncio
async def test_empty_data_returns_zero_counts(self):
    """With no checkpoints stored, every ticket list is empty and totals are 0."""
    fake_redis = _build_redis_mock({})

    fake_jira = MagicMock()
    fake_jira.close = AsyncMock()
    fake_jira.get_issue = AsyncMock(side_effect=Exception("hierarchy not needed"))

    redis_patch = patch(
        "forge.workflow.stats.weekly_report.get_redis_client",
        new=AsyncMock(return_value=fake_redis),
    )
    jira_patch = patch(
        "forge.workflow.stats.weekly_report.JiraClient",
        return_value=fake_jira,
    )
    with redis_patch, jira_patch:
        report = await collect_weekly_data("PROJ", days=7)

    # Ticket collections first, in the same order as the original assertions.
    assert report.all_tickets == []
    assert report.completed_tickets == []
    assert report.in_progress_tickets == []
    assert report.blocked_tickets == []

    # Then the aggregate token counters.
    assert report.total_input_tokens == 0
    assert report.total_output_tokens == 0
@pytest.mark.asyncio
async def test_old_checkpoint_excluded(self):
    """A checkpoint whose timestamps are all ten days old is left out of a 7-day report."""
    # Both the top-level and the stage timestamps are outside the window.
    stale_stage = _make_stage(
        "prd",
        started_at=_TEN_DAYS_AGO,
        ended_at=_TEN_DAYS_AGO,
    )
    stale_checkpoint = _make_checkpoint(
        ticket_key="PROJ-20",
        updated_at=_TEN_DAYS_AGO,
        stage_timestamps={"prd": stale_stage},
    )
    fake_redis = _build_redis_mock({"PROJ-20": stale_checkpoint})

    fake_jira = MagicMock()
    fake_jira.close = AsyncMock()
    fake_jira.get_issue = AsyncMock(side_effect=Exception("hierarchy not needed"))

    redis_patch = patch(
        "forge.workflow.stats.weekly_report.get_redis_client",
        new=AsyncMock(return_value=fake_redis),
    )
    jira_patch = patch(
        "forge.workflow.stats.weekly_report.JiraClient",
        return_value=fake_jira,
    )
    with redis_patch, jira_patch:
        report = await collect_weekly_data("PROJ", days=7)

    assert report.all_tickets == []
@pytest.mark.asyncio
async def test_stage_timestamp_qualifies_checkpoint(self):
    """A recent stage timestamp keeps a checkpoint in the report despite an old ``updated_at``."""
    # The stage ran one day ago, inside the 7-day window ...
    recent_stage = _make_stage(
        "prd",
        started_at=_ONE_DAY_AGO,
        ended_at=_ONE_DAY_AGO,
    )
    # ... while the checkpoint's own updated_at is ten days old.
    checkpoint = _make_checkpoint(
        ticket_key="PROJ-30",
        updated_at=_TEN_DAYS_AGO,
        stage_timestamps={"prd": recent_stage},
    )
    fake_redis = _build_redis_mock({"PROJ-30": checkpoint})

    fake_jira = MagicMock()
    fake_jira.close = AsyncMock()
    fake_jira.get_issue = AsyncMock(side_effect=Exception("hierarchy not needed"))

    redis_patch = patch(
        "forge.workflow.stats.weekly_report.get_redis_client",
        new=AsyncMock(return_value=fake_redis),
    )
    jira_patch = patch(
        "forge.workflow.stats.weekly_report.JiraClient",
        return_value=fake_jira,
    )
    with redis_patch, jira_patch:
        report = await collect_weekly_data("PROJ", days=7)

    assert len(report.all_tickets) == 1
    only_ticket = report.all_tickets[0]
    assert only_ticket.ticket_key == "PROJ-30"
@pytest.mark.asyncio
async def test_different_project_key_not_mixed_in(self):
    """A report for OTHER contains OTHER-1 only, even though PROJ-1 is also stored."""
    stored = {
        "checkpoint:PROJ-1": json.dumps(_make_checkpoint(ticket_key="PROJ-1")),
        "checkpoint:OTHER-1": json.dumps(_make_checkpoint(ticket_key="OTHER-1")),
    }

    async def _scan(cursor: int, match: str, count: int) -> tuple[int, list[str]]:
        # Treat the glob as a plain prefix and finish the scan in one page
        # (returned cursor 0).
        wanted_prefix = match.rstrip("*")
        return (0, [name for name in stored if name.startswith(wanted_prefix)])

    async def _get(key: str) -> str | None:
        return stored.get(key)

    fake_redis = MagicMock()
    fake_redis.scan = AsyncMock(side_effect=_scan)
    fake_redis.get = AsyncMock(side_effect=_get)

    fake_jira = MagicMock()
    fake_jira.close = AsyncMock()
    fake_jira.get_issue = AsyncMock(side_effect=Exception("hierarchy not needed"))

    redis_patch = patch(
        "forge.workflow.stats.weekly_report.get_redis_client",
        new=AsyncMock(return_value=fake_redis),
    )
    jira_patch = patch(
        "forge.workflow.stats.weekly_report.JiraClient",
        return_value=fake_jira,
    )
    with redis_patch, jira_patch:
        report = await collect_weekly_data("OTHER", days=7)

    assert len(report.all_tickets) == 1
    only_ticket = report.all_tickets[0]
    assert only_ticket.ticket_key == "OTHER-1"
@pytest.mark.asyncio
async def test_completion_percentage_computed(self):
    """One completed and one unfinished ticket under FEAT-2 yield 50 % completion."""
    fake_redis = _build_redis_mock(
        {
            "PROJ-60": _make_checkpoint(ticket_key="PROJ-60", workflow_outcome="Completed"),
            "PROJ-61": _make_checkpoint(ticket_key="PROJ-61", workflow_outcome=None),
        }
    )

    # Both tasks name FEAT-2 as their parent.
    issue_map = {
        "FEAT-2": _make_jira_issue("FEAT-2", issue_type="Feature"),
        "PROJ-60": _make_jira_issue("PROJ-60", issue_type="Task", parent_key="FEAT-2"),
        "PROJ-61": _make_jira_issue("PROJ-61", issue_type="Task", parent_key="FEAT-2"),
    }

    async def _get_issue(key: str) -> JiraIssue:
        return issue_map[key]

    fake_jira = MagicMock()
    fake_jira.close = AsyncMock()
    fake_jira.get_issue = AsyncMock(side_effect=_get_issue)

    redis_patch = patch(
        "forge.workflow.stats.weekly_report.get_redis_client",
        new=AsyncMock(return_value=fake_redis),
    )
    jira_patch = patch(
        "forge.workflow.stats.weekly_report.JiraClient",
        return_value=fake_jira,
    )
    with redis_patch, jira_patch:
        report = await collect_weekly_data("PROJ", days=7)

    feat_rollup = report.feature_rollups.get("FEAT-2")
    assert feat_rollup is not None
    assert feat_rollup.completion_percentage == pytest.approx(50.0)
"forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + await cmd_weekly_report(_make_cli_args(project="PROJ", fmt="text")) + + out = capsys.readouterr().out + assert "PROJ-1" in out + + @pytest.mark.asyncio + async def test_text_output_contains_project_name(self, capsys): + """Text output includes the project name.""" + from forge.cli import cmd_weekly_report + + report = _make_report(project="MYPROJ") + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + await cmd_weekly_report(_make_cli_args(project="MYPROJ", fmt="text")) + + out = capsys.readouterr().out + assert "MYPROJ" in out + + @pytest.mark.asyncio + async def test_no_data_exits_nonzero(self, capsys): + """forge weekly-report exits 1 when no tickets are found.""" + from forge.cli import cmd_weekly_report + + empty_report = WeeklyReportData( + project="PROJ", + period_days=7, + report_start=_THREE_DAYS_AGO, + report_end=_ONE_DAY_AGO, + ) + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=empty_report), + ): + code = await cmd_weekly_report(_make_cli_args()) + + assert code == 1 + + @pytest.mark.asyncio + async def test_error_during_collection_exits_nonzero(self, capsys): + """forge weekly-report exits 1 on collection errors.""" + from forge.cli import cmd_weekly_report + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(side_effect=RuntimeError("Redis unavailable")), + ): + code = await cmd_weekly_report(_make_cli_args()) + + assert code == 1 + + @pytest.mark.asyncio + async def test_single_ticket_text_output(self, capsys): + """A report with a single completed ticket produces text output.""" + from forge.cli import cmd_weekly_report + + report = _make_report( + completed=[ + TicketSummary( + ticket_key="PROJ-1", + status="completed", + duration_seconds=1800.0, + input_tokens=500, + output_tokens=200, + 
@pytest.mark.asyncio
async def test_json_output_is_valid(self, capsys):
    """With fmt="json" the command exits 0 and prints a JSON object to stdout."""
    from forge.cli import cmd_weekly_report

    # Replace data collection with a stub that returns a canned report.
    collect_stub = AsyncMock(return_value=_make_report())

    with patch(
        "forge.workflow.stats.weekly_report.collect_weekly_data",
        new=collect_stub,
    ):
        exit_code = await cmd_weekly_report(_make_cli_args(fmt="json"))

    assert exit_code == 0
    printed = capsys.readouterr().out
    parsed = json.loads(printed)
    assert isinstance(parsed, dict)
@pytest.mark.asyncio
async def test_json_output_contains_ticket_keys(self, capsys):
    """The JSON ``completed_tickets`` entries include the PROJ-1 ticket key."""
    from forge.cli import cmd_weekly_report

    # Replace data collection with a stub that returns a canned report.
    collect_stub = AsyncMock(return_value=_make_report(project="PROJ"))

    with patch(
        "forge.workflow.stats.weekly_report.collect_weekly_data",
        new=collect_stub,
    ):
        await cmd_weekly_report(_make_cli_args(project="PROJ", fmt="json"))

    payload = json.loads(capsys.readouterr().out)
    completed_keys = [entry["ticket_key"] for entry in payload["completed_tickets"]]
    assert "PROJ-1" in completed_keys
@pytest.mark.asyncio
async def test_file_export_json_format(self):
    """With an output path and fmt="json", the written file parses as JSON."""
    from forge.cli import cmd_weekly_report

    # Replace data collection with a stub that returns a canned report.
    collect_stub = AsyncMock(return_value=_make_report())

    with tempfile.TemporaryDirectory() as tmpdir:
        target = str(Path(tmpdir) / "report.json")
        with patch(
            "forge.workflow.stats.weekly_report.collect_weekly_data",
            new=collect_stub,
        ):
            exit_code = await cmd_weekly_report(_make_cli_args(fmt="json", output=target))

        # Read the file back before the temporary directory is removed.
        assert exit_code == 0
        written = Path(target).read_text()
        parsed = json.loads(written)
        assert "project" in parsed
@pytest.mark.asyncio
async def test_ticket_created_with_correct_summary(self):
    """The created ticket's summary carries the report title, project and week start."""
    from datetime import date

    from forge.workflow.stats.report_ticket import create_report_ticket

    jira_stub = MagicMock()
    jira_stub.close = AsyncMock()
    jira_stub.create_task = AsyncMock(return_value="PROJ-99")

    monday = date(2024, 1, 8)

    with patch(
        "forge.workflow.stats.report_ticket.JiraClient",
        return_value=jira_stub,
    ):
        created_key = await create_report_ticket("PROJ", monday, "## Report")

    # The key returned by create_task is passed straight back to the caller.
    assert created_key == "PROJ-99"

    summary = jira_stub.create_task.call_args.kwargs["summary"]
    assert "Forge Weekly Report" in summary
    assert "PROJ" in summary
    assert "2024-01-08" in summary
@pytest.mark.asyncio
async def test_jira_client_closed_after_creation(self):
    """create_report_ticket awaits JiraClient.close() exactly once."""
    from datetime import date

    from forge.workflow.stats.report_ticket import create_report_ticket

    jira_stub = MagicMock()
    jira_stub.close = AsyncMock()
    jira_stub.create_task = AsyncMock(return_value="PROJ-99")

    client_patch = patch(
        "forge.workflow.stats.report_ticket.JiraClient",
        return_value=jira_stub,
    )
    with client_patch:
        await create_report_ticket("PROJ", date(2024, 1, 8), "## Report")

    jira_stub.close.assert_awaited_once()
@pytest.mark.asyncio
async def test_new_ticket_created_when_not_found(self):
    """When the search finds nothing, ensure_report_ticket creates a ticket and returns its key."""
    from datetime import date

    from forge.workflow.stats.report_ticket import ensure_report_ticket

    # First client instance: the search, which finds no existing ticket.
    search_client = MagicMock()
    search_client.close = AsyncMock()
    search_client.search_issues = AsyncMock(return_value=[])

    # Second client instance: the creation call.
    create_client = MagicMock()
    create_client.close = AsyncMock()
    create_client.create_task = AsyncMock(return_value="PROJ-100")

    # Third client instance: available for a description update after creation.
    update_client = MagicMock()
    update_client.close = AsyncMock()
    update_client.update_description = AsyncMock()

    # Each JiraClient() construction hands out the next stub in this order.
    client_sequence = iter([search_client, create_client, update_client])

    with patch(
        "forge.workflow.stats.report_ticket.JiraClient",
        side_effect=client_sequence,
    ):
        returned_key = await ensure_report_ticket("PROJ", date(2024, 1, 8), "## New report")

    assert returned_key == "PROJ-100"
    create_client.create_task.assert_awaited_once()
@pytest.mark.asyncio
async def test_calling_twice_does_not_create_duplicate(self):
    """Two ensure_report_ticket calls for an existing ticket update twice and never create."""
    from datetime import date

    from forge.integrations.jira.models import JiraIssue
    from forge.workflow.stats.report_ticket import ensure_report_ticket

    already_there = JiraIssue(
        key="PROJ-42",
        id="100",
        summary="Forge Weekly Report - PROJ - Week of 2024-01-08",
        description="",
        status="Open",
        issue_type="Task",
    )

    # These mocks are shared by every client instance so that awaits are
    # counted across both ensure_report_ticket calls.
    search_mock = AsyncMock(return_value=[already_there])
    create_task_mock = AsyncMock(return_value="PROJ-NEW")
    update_desc_mock = AsyncMock()

    def _make_jira() -> MagicMock:
        client = MagicMock()
        client.close = AsyncMock()
        client.search_issues = search_mock
        client.create_task = create_task_mock
        client.update_description = update_desc_mock
        return client

    week_start = date(2024, 1, 8)
    with patch(
        "forge.workflow.stats.report_ticket.JiraClient",
        side_effect=_make_jira,
    ):
        first_key = await ensure_report_ticket("PROJ", week_start, "v1")
        second_key = await ensure_report_ticket("PROJ", date(2024, 1, 8), "v2")

    # The ticket already exists, so nothing may be created ...
    create_task_mock.assert_not_awaited()
    # ... and each call must have updated the description once.
    assert update_desc_mock.await_count == 2
    assert first_key == "PROJ-42"
    assert second_key == "PROJ-42"
@pytest.mark.asyncio
async def test_notification_comment_posted(self):
    """notify_report_ready awaits add_comment exactly once for two recipients."""
    from forge.workflow.stats.notifications import notify_report_ready

    jira_stub = MagicMock()
    jira_stub.close = AsyncMock()
    jira_stub.add_comment = AsyncMock()

    recipients = ["user1", "user2"]
    client_patch = patch(
        "forge.workflow.stats.notifications.JiraClient",
        return_value=jira_stub,
    )
    with client_patch:
        await notify_report_ready(
            "PROJ-42",
            recipients,
            jira_base_url="https://test.atlassian.net",
        )

    jira_stub.add_comment.assert_awaited_once()
test_notification_posted_to_correct_ticket(self): + """notify_report_ready posts the comment to the specified ticket key.""" + from forge.workflow.stats.notifications import notify_report_ready + + mock_jira = MagicMock() + mock_jira.close = AsyncMock() + mock_jira.add_comment = AsyncMock() + + with patch( + "forge.workflow.stats.notifications.JiraClient", + return_value=mock_jira, + ): + await notify_report_ready( + "PROJ-99", + ["user1"], + jira_base_url="https://test.atlassian.net", + ) + + call_args = mock_jira.add_comment.call_args + ticket_arg = call_args[0][0] + assert ticket_arg == "PROJ-99" + + @pytest.mark.asyncio + async def test_notification_comment_mentions_recipients(self): + """The notification comment body mentions each recipient.""" + from forge.workflow.stats.notifications import notify_report_ready + + mock_jira = MagicMock() + mock_jira.close = AsyncMock() + mock_jira.add_comment = AsyncMock() + + with patch( + "forge.workflow.stats.notifications.JiraClient", + return_value=mock_jira, + ): + await notify_report_ready( + "PROJ-42", + ["abc123", "def456"], + jira_base_url="https://test.atlassian.net", + ) + + call_args = mock_jira.add_comment.call_args + comment_body = call_args[0][1] + assert "abc123" in comment_body + assert "def456" in comment_body + + @pytest.mark.asyncio + async def test_no_notification_for_empty_recipients(self): + """notify_report_ready does not post when recipients list is empty.""" + from forge.workflow.stats.notifications import notify_report_ready + + mock_jira = MagicMock() + mock_jira.close = AsyncMock() + mock_jira.add_comment = AsyncMock() + + with patch( + "forge.workflow.stats.notifications.JiraClient", + return_value=mock_jira, + ): + await notify_report_ready( + "PROJ-42", + [], + jira_base_url="https://test.atlassian.net", + ) + + mock_jira.add_comment.assert_not_awaited() + + @pytest.mark.asyncio + async def test_jira_client_closed_after_notification(self): + """JiraClient.close() is always called after 
notification delivery.""" + from forge.workflow.stats.notifications import notify_report_ready + + mock_jira = MagicMock() + mock_jira.close = AsyncMock() + mock_jira.add_comment = AsyncMock() + + with patch( + "forge.workflow.stats.notifications.JiraClient", + return_value=mock_jira, + ): + await notify_report_ready( + "PROJ-42", + ["user1"], + jira_base_url="https://test.atlassian.net", + ) + + mock_jira.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_notification_comment_includes_ticket_link(self): + """The notification comment body includes a link to the report ticket.""" + from forge.workflow.stats.notifications import notify_report_ready + + mock_jira = MagicMock() + mock_jira.close = AsyncMock() + mock_jira.add_comment = AsyncMock() + + with patch( + "forge.workflow.stats.notifications.JiraClient", + return_value=mock_jira, + ): + await notify_report_ready( + "PROJ-42", + ["user1"], + jira_base_url="https://test.atlassian.net", + ) + + call_args = mock_jira.add_comment.call_args + comment_body = call_args[0][1] + assert "PROJ-42" in comment_body + + @pytest.mark.asyncio + async def test_invalid_account_ids_are_skipped(self): + """Account IDs containing spaces or commas are skipped with a warning.""" + from forge.workflow.stats.notifications import notify_report_ready + + mock_jira = MagicMock() + mock_jira.close = AsyncMock() + mock_jira.add_comment = AsyncMock() + + with patch( + "forge.workflow.stats.notifications.JiraClient", + return_value=mock_jira, + ): + # "bad id" has a space and "another,bad" has a comma — both invalid + await notify_report_ready( + "PROJ-42", + ["bad id", "another,bad"], + jira_base_url="https://test.atlassian.net", + ) + + # All recipients are invalid so no comment is posted + mock_jira.add_comment.assert_not_awaited() diff --git a/tests/scripts/create-test-feature.py b/tests/scripts/create-test-feature.py old mode 100755 new mode 100644 diff --git a/tests/test_sandbox_runner.py b/tests/test_sandbox_runner.py 
index e4e02c24..c10e5655 100644 --- a/tests/test_sandbox_runner.py +++ b/tests/test_sandbox_runner.py @@ -1,6 +1,7 @@ """Quick tests for container sandbox runner.""" import asyncio +import shutil import tempfile from pathlib import Path @@ -13,17 +14,19 @@ class TestContainerRunner: """Tests for ContainerRunner.""" + @pytest.mark.skipif(not shutil.which("podman"), reason="podman not found") def test_runner_init(self): """Test runner initializes correctly.""" runner = ContainerRunner() assert runner is not None + @pytest.mark.skipif(not shutil.which("podman"), reason="podman not found") def test_podman_exists(self): """Test podman is available.""" - import shutil assert shutil.which("podman") is not None @pytest.mark.asyncio + @pytest.mark.skipif(not shutil.which("podman"), reason="podman not found") async def test_image_exists_returns_false_for_missing(self): """Test image_exists returns False for non-existent image.""" runner = ContainerRunner() @@ -31,6 +34,7 @@ async def test_image_exists_returns_false_for_missing(self): assert exists is False @pytest.mark.asyncio + @pytest.mark.skipif(not shutil.which("podman"), reason="podman not found") async def test_simple_container_run(self): """Test running a simple container with alpine.""" # Create a minimal test workspace @@ -46,10 +50,14 @@ async def test_simple_container_run(self): result = subprocess.run( [ - "podman", "run", "--rm", - "-v", f"{workspace}:/workspace:Z", + "podman", + "run", + "--rm", + "-v", + f"{workspace}:/workspace:Z", "alpine:latest", - "cat", "/workspace/test.txt", + "cat", + "/workspace/test.txt", ], capture_output=True, text=True, diff --git a/tests/unit/integrations/jira/test_client.py b/tests/unit/integrations/jira/test_client.py index 0b011f9b..688b233e 100644 --- a/tests/unit/integrations/jira/test_client.py +++ b/tests/unit/integrations/jira/test_client.py @@ -220,9 +220,7 @@ async def test_archive_issue_updates_labels_unlinks_parent_and_archives_natively assert 
mock_http.put.await_args_list[1].kwargs["json"] == {"fields": {"parent": None}} assert mock_http.put.await_args_list[2].args[0] == "/issue/archive" - assert mock_http.put.await_args_list[2].kwargs["json"] == { - "issueIdsOrKeys": ["TEST-123"] - } + assert mock_http.put.await_args_list[2].kwargs["json"] == {"issueIdsOrKeys": ["TEST-123"]} @pytest.mark.asyncio async def test_archive_issue_logs_native_archive_body_errors(self, mock_client, caplog): @@ -779,3 +777,48 @@ async def test_parses_json_string_value(self, jira_client): assert len(result) == 1 assert isinstance(result[0], SkillEntry) assert result[0].source == "https://github.com/acme/skills" + + +class TestJiraClientGetServiceAccountId: + """Tests for get_service_account_id method.""" + + @pytest.mark.asyncio + async def test_get_service_account_id_success(self, jira_client): + """Returns the accountId and caches it.""" + import forge.integrations.jira.client as client_module + + client_module._service_account_id_cache = None + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "accountId": "resolved-id-123", + "displayName": "Service Account", + } + mock_response.raise_for_status = MagicMock() + + with patch.object(jira_client, "_get_client") as mock_get_client: + mock_http = AsyncMock() + mock_http.request = AsyncMock(return_value=mock_response) + mock_get_client.return_value = mock_http + + result = await jira_client.get_service_account_id() + + assert result == "resolved-id-123" + assert client_module._service_account_id_cache == "resolved-id-123" + + @pytest.mark.asyncio + async def test_get_service_account_id_cached(self, jira_client): + """Returns cached accountId without making an HTTP request.""" + import forge.integrations.jira.client as client_module + + client_module._service_account_id_cache = "cached-id-456" + + with patch.object(jira_client, "_get_client") as mock_get_client: + mock_http = AsyncMock() + mock_get_client.return_value = 
mock_http + + result = await jira_client.get_service_account_id() + + mock_http.request.assert_not_called() + assert result == "cached-id-456" diff --git a/tests/unit/orchestrator/test_worker_forge_stats.py b/tests/unit/orchestrator/test_worker_forge_stats.py new file mode 100644 index 00000000..e8d6f2a3 --- /dev/null +++ b/tests/unit/orchestrator/test_worker_forge_stats.py @@ -0,0 +1,468 @@ +"""Unit tests for the /forge stats Jira comment command handler.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from forge.models.events import EventSource +from forge.orchestrator.worker import OrchestratorWorker +from forge.queue.models import QueueMessage + + +def _make_jira_message(ticket_key: str, comment_body: str) -> QueueMessage: + """Create a Jira comment QueueMessage.""" + return QueueMessage( + message_id="1234567890-0", + event_id="test-event-001", + source=EventSource.JIRA, + event_type="comment_created", + ticket_key=ticket_key, + payload={ + "issue": { + "key": ticket_key, + "fields": { + "issuetype": {"name": "Feature"}, + "labels": [], + }, + }, + "comment": {"body": comment_body}, + "changelog": {"items": []}, + }, + ) + + +def _base_state(ticket_key: str = "TEST-123", **overrides) -> dict: + """Return a minimal workflow state dict.""" + return { + "ticket_key": ticket_key, + "ticket_type": "Feature", + "current_node": "prd_approval_gate", + "is_paused": True, + "context": {}, + "stage_timestamps": { + "prd": { + "stage_name": "prd", + "iteration_count": 1, + "machine_time_seconds": 30.0, + "human_time_seconds": 120.0, + "input_tokens": 500, + "output_tokens": 800, + } + }, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "workflow_outcome": None, + "stats_outcome_reason": None, + **overrides, + } + + +@pytest.fixture +def worker() -> OrchestratorWorker: + return OrchestratorWorker(consumer_name="test-worker") + + +@pytest.fixture +def mock_jira(): + """Return a mock JiraClient that is also an async context manager.""" + jira = 
AsyncMock() + jira.add_comment = AsyncMock() + jira.close = AsyncMock() + return jira + + +class TestForgeStatsCommandDetection: + """Tests that /forge stats is detected case-insensitively.""" + + @pytest.mark.asyncio + async def test_forge_stats_detected_lowercase(self, worker: OrchestratorWorker, mock_jira): + """/forge stats (lowercase) triggers stats posting.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = _base_state() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, state) + + assert result is state, "State must be returned unchanged" + mock_jira.add_comment.assert_awaited_once() + mock_jira.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_forge_stats_detected_uppercase(self, worker: OrchestratorWorker, mock_jira): + """/FORGE STATS (uppercase) triggers stats posting.""" + message = _make_jira_message("TEST-123", "/FORGE STATS") + state = _base_state() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, state) + + assert result is state + + @pytest.mark.asyncio + async def test_forge_stats_detected_mixed_case(self, worker: OrchestratorWorker, mock_jira): + """/Forge Stats (mixed case) triggers stats posting.""" + message = _make_jira_message("TEST-123", "/Forge Stats") + state = _base_state() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, state) + + assert result is state + + @pytest.mark.asyncio + async def test_forge_stats_with_trailing_text(self, worker: OrchestratorWorker, mock_jira): + """/forge stats with unknown trailing subcommand is treated as informational (no post).""" + message = _make_jira_message("TEST-123", "/forge stats please show me") + state = _base_state() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = 
await worker._handle_resume_event(message, state) + + # Unknown subcommand is informational — state is returned unchanged, no comment posted + assert result is state + mock_jira.add_comment.assert_not_awaited() + + @pytest.mark.asyncio + async def test_forge_stats_with_leading_whitespace(self, worker: OrchestratorWorker, mock_jira): + """Leading whitespace before /forge stats is stripped before matching.""" + message = _make_jira_message("TEST-123", " /forge stats") + state = _base_state() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, state) + + assert result is state + mock_jira.add_comment.assert_awaited_once() + + @pytest.mark.asyncio + async def test_non_forge_stats_comment_not_intercepted( + self, worker: OrchestratorWorker, mock_jira + ): + """Comments not starting with /forge stats are processed normally.""" + message = _make_jira_message("TEST-123", "!Please revise the PRD") + state = _base_state() + + with ( + patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira), + patch("forge.orchestrator.worker.post_status_comment", new_callable=AsyncMock), + ): + result = await worker._handle_resume_event(message, state) + + # Should be treated as a revision request, not a stats command + assert result is not state or result.get("revision_requested") is True + + +class TestForgeStatsReturnStateUnchanged: + """Tests that /forge stats returns the current state without modification.""" + + @pytest.mark.asyncio + async def test_state_identity_returned(self, worker: OrchestratorWorker, mock_jira): + """The exact same state object is returned (identity check).""" + message = _make_jira_message("TEST-123", "/forge stats") + state = _base_state() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, state) + + assert result is state + + @pytest.mark.asyncio + async def 
test_is_paused_not_modified(self, worker: OrchestratorWorker, mock_jira): + """is_paused flag is not changed by /forge stats command.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = _base_state(is_paused=True) + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, state) + + assert result["is_paused"] is True + + @pytest.mark.asyncio + async def test_current_node_not_modified(self, worker: OrchestratorWorker, mock_jira): + """current_node is not changed by /forge stats command.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = _base_state(current_node="spec_approval_gate") + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, state) + + assert result["current_node"] == "spec_approval_gate" + + +class TestForgeStatsRetrieval: + """Tests for stats retrieval and formatting.""" + + @pytest.mark.asyncio + async def test_posts_formatted_stats_to_correct_ticket( + self, worker: OrchestratorWorker, mock_jira + ): + """The stats comment is posted to the ticket from the message.""" + message = _make_jira_message("PROJ-456", "/forge stats") + state = _base_state(ticket_key="PROJ-456") + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, state) + + mock_jira.add_comment.assert_awaited_once() + call_args = mock_jira.add_comment.await_args + assert call_args.args[0] == "PROJ-456" + + @pytest.mark.asyncio + async def test_posted_comment_contains_stats_heading( + self, worker: OrchestratorWorker, mock_jira + ): + """The posted comment includes a workflow statistics section.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = _base_state() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, state) + + comment_body = 
mock_jira.add_comment.await_args.args[1] + assert "Workflow Statistics" in comment_body + + @pytest.mark.asyncio + async def test_stats_uses_pre_set_outcome(self, worker: OrchestratorWorker, mock_jira): + """When workflow_outcome is set in state, it is used in the formatted output.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = _base_state(workflow_outcome="Completed") + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, state) + + comment_body = mock_jira.add_comment.await_args.args[1] + assert "Completed" in comment_body + + @pytest.mark.asyncio + async def test_stats_derives_blocked_outcome(self, worker: OrchestratorWorker, mock_jira): + """When is_blocked=True and no pre-set outcome, outcome is 'Blocked'.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = _base_state(is_blocked=True, workflow_outcome=None) + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, state) + + comment_body = mock_jira.add_comment.await_args.args[1] + assert "Blocked" in comment_body + + @pytest.mark.asyncio + async def test_stats_derives_failed_outcome(self, worker: OrchestratorWorker, mock_jira): + """When last_error is set and no pre-set outcome, outcome is 'Failed'.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = _base_state(last_error="Something went wrong", workflow_outcome=None) + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, state) + + comment_body = mock_jira.add_comment.await_args.args[1] + assert "Failed" in comment_body + + @pytest.mark.asyncio + async def test_stats_in_progress_outcome_for_active_workflow( + self, worker: OrchestratorWorker, mock_jira + ): + """Active workflow with no error/blocked status uses 'In Progress' outcome.""" + message = _make_jira_message("TEST-123", "/forge 
stats") + state = _base_state(workflow_outcome=None, is_blocked=False, last_error=None) + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, state) + + comment_body = mock_jira.add_comment.await_args.args[1] + assert "In Progress" in comment_body + + +class TestForgeStatsMissingCheckpoint: + """Tests for graceful handling when no stats data is present.""" + + @pytest.mark.asyncio + async def test_no_stage_timestamps_posts_no_data_message( + self, worker: OrchestratorWorker, mock_jira + ): + """When stage_timestamps key is missing, posts 'No workflow data found.' message.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = { + "ticket_key": "TEST-123", + "current_node": "prd_approval_gate", + "is_paused": True, + "context": {}, + # stage_timestamps is absent entirely + } + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, state) + + assert result is state + mock_jira.add_comment.assert_awaited_once() + comment_body = mock_jira.add_comment.await_args.args[1] + assert "No workflow data found" in comment_body + + @pytest.mark.asyncio + async def test_empty_stage_timestamps_still_formats( + self, worker: OrchestratorWorker, mock_jira + ): + """Empty stage_timestamps dict (workflow just started) still produces formatted output.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = _base_state(stage_timestamps={}) + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, state) + + assert result is state + mock_jira.add_comment.assert_awaited_once() + comment_body = mock_jira.add_comment.await_args.args[1] + # Should contain the stats table, not the "no data" message + assert "Workflow Statistics" in comment_body + + @pytest.mark.asyncio + async def test_no_stats_returns_state_unchanged(self, worker: 
OrchestratorWorker, mock_jira): + """Even when no data is found, current state is returned unchanged.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = { + "ticket_key": "TEST-123", + "current_node": "prd_approval_gate", + "is_paused": True, + } + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, state) + + assert result is state + + +class TestForgeStatsErrorHandling: + """Tests for error resilience in the stats command handler.""" + + @pytest.mark.asyncio + async def test_jira_add_comment_failure_does_not_raise(self, worker: OrchestratorWorker): + """JiraClient.add_comment failure is caught and does not propagate.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = _base_state() + + mock_jira = AsyncMock() + mock_jira.add_comment = AsyncMock(side_effect=Exception("Jira API error")) + mock_jira.close = AsyncMock() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + # Should not raise + result = await worker._handle_resume_event(message, state) + + assert result is state + + @pytest.mark.asyncio + async def test_formatter_failure_posts_fallback_message( + self, worker: OrchestratorWorker, mock_jira + ): + """When the formatter raises, a fallback message is posted.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = _base_state() + + with ( + patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira), + patch( + "forge.orchestrator.worker.format_stats_summary", + side_effect=RuntimeError("formatter error"), + ), + ): + result = await worker._handle_resume_event(message, state) + + assert result is state + mock_jira.add_comment.assert_awaited_once() + comment_body = mock_jira.add_comment.await_args.args[1] + assert "Unable to format" in comment_body + + @pytest.mark.asyncio + async def test_jira_close_always_called_on_success(self, worker: OrchestratorWorker, mock_jira): + 
"""JiraClient.close() is called even after a successful add_comment.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = _base_state() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, state) + + mock_jira.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_jira_close_called_even_after_no_data_path( + self, worker: OrchestratorWorker, mock_jira + ): + """JiraClient.close() is called in the 'no data' path too.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = {"ticket_key": "TEST-123", "current_node": "prd_approval_gate"} + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_resume_event(message, state) + + mock_jira.close.assert_awaited_once() + + +class TestHandleStatsCommandDirect: + """Direct unit tests for _handle_stats_command.""" + + @pytest.mark.asyncio + async def test_direct_call_with_stats(self, worker: OrchestratorWorker, mock_jira): + """Direct call with stats data posts a formatted comment.""" + state = _base_state() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_stats_command("TEST-123", state) + + mock_jira.add_comment.assert_awaited_once() + args = mock_jira.add_comment.await_args.args + assert args[0] == "TEST-123" + assert "Workflow Statistics" in args[1] + + @pytest.mark.asyncio + async def test_direct_call_without_stage_timestamps( + self, worker: OrchestratorWorker, mock_jira + ): + """Direct call when stage_timestamps is missing posts 'No workflow data found.'.""" + state = {"ticket_key": "TEST-123", "current_node": "prd_approval_gate"} + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + await worker._handle_stats_command("TEST-123", state) + + mock_jira.add_comment.assert_awaited_once() + body = mock_jira.add_comment.await_args.args[1] + assert "No workflow data found" in body 
+ + @pytest.mark.asyncio + async def test_uses_stats_outcome_reason_as_detail(self, worker: OrchestratorWorker, mock_jira): + """stats_outcome_reason is passed as outcome_detail to the formatter.""" + state = _base_state( + workflow_outcome="Blocked", + stats_outcome_reason="Waiting for security review", + ) + + with ( + patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira), + patch("forge.orchestrator.worker.format_stats_summary") as mock_format, + ): + mock_format.return_value = "formatted stats" + await worker._handle_stats_command("TEST-123", state) + + mock_format.assert_called_once_with( + state, "Blocked", "Waiting for security review", pricing=worker.settings.llm_pricing + ) + + @pytest.mark.asyncio + async def test_uses_last_error_as_detail_when_no_reason( + self, worker: OrchestratorWorker, mock_jira + ): + """last_error is used as outcome_detail when stats_outcome_reason is absent.""" + state = _base_state( + workflow_outcome=None, + last_error="Connection timeout", + stats_outcome_reason=None, + ) + + with ( + patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira), + patch("forge.orchestrator.worker.format_stats_summary") as mock_format, + ): + mock_format.return_value = "formatted stats" + await worker._handle_stats_command("TEST-123", state) + + _, called_outcome, called_detail = mock_format.call_args.args + assert called_outcome == "Failed" + assert called_detail == "Connection timeout" diff --git a/tests/unit/orchestrator/test_worker_forge_stats_retry.py b/tests/unit/orchestrator/test_worker_forge_stats_retry.py new file mode 100644 index 00000000..3cff13da --- /dev/null +++ b/tests/unit/orchestrator/test_worker_forge_stats_retry.py @@ -0,0 +1,512 @@ +"""Unit tests for the /forge stats retry subcommand handler.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from forge.models.events import EventSource +from forge.orchestrator.worker import OrchestratorWorker +from forge.queue.models import 
QueueMessage + + +def _make_jira_message(ticket_key: str, comment_body: str) -> QueueMessage: + """Create a Jira comment QueueMessage.""" + return QueueMessage( + message_id="1234567890-0", + event_id="test-event-001", + source=EventSource.JIRA, + event_type="comment_created", + ticket_key=ticket_key, + payload={ + "issue": { + "key": ticket_key, + "fields": { + "issuetype": {"name": "Feature"}, + "labels": [], + }, + }, + "comment": {"body": comment_body}, + "changelog": {"items": []}, + }, + ) + + +def _base_state(ticket_key: str = "TEST-123", **overrides) -> dict: + """Return a minimal workflow state dict with stats data.""" + return { + "ticket_key": ticket_key, + "ticket_type": "Feature", + "current_node": "prd_approval_gate", + "is_paused": True, + "context": {}, + "stage_timestamps": { + "prd": { + "stage_name": "prd", + "iteration_count": 1, + "machine_time_seconds": 30.0, + "human_time_seconds": 120.0, + "input_tokens": 500, + "output_tokens": 800, + } + }, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "workflow_outcome": None, + "stats_outcome_reason": None, + **overrides, + } + + +@pytest.fixture +def worker() -> OrchestratorWorker: + return OrchestratorWorker(consumer_name="test-worker") + + +@pytest.fixture +def mock_jira(): + """Return a mock JiraClient that is also an async context manager.""" + jira = AsyncMock() + jira.add_comment = AsyncMock() + jira.close = AsyncMock() + return jira + + +class TestForgeStatsRetryDetection: + """Tests that /forge stats retry is detected distinctly from base /forge stats.""" + + @pytest.mark.asyncio + async def test_retry_detected_lowercase(self, worker: OrchestratorWorker): + """/forge stats retry (lowercase) triggers the retry handler.""" + message = _make_jira_message("TEST-123", "/forge stats retry") + state = _base_state() + + with patch.object( + worker, "_handle_stats_retry_command", new_callable=AsyncMock + ) as mock_retry: + result = await worker._handle_resume_event(message, state) + + 
mock_retry.assert_awaited_once_with("TEST-123", state) + assert result is state + + @pytest.mark.asyncio + async def test_retry_detected_uppercase(self, worker: OrchestratorWorker): + """/FORGE STATS RETRY (uppercase) triggers the retry handler.""" + message = _make_jira_message("TEST-123", "/FORGE STATS RETRY") + state = _base_state() + + with patch.object( + worker, "_handle_stats_retry_command", new_callable=AsyncMock + ) as mock_retry: + result = await worker._handle_resume_event(message, state) + + mock_retry.assert_awaited_once_with("TEST-123", state) + assert result is state + + @pytest.mark.asyncio + async def test_retry_detected_mixed_case(self, worker: OrchestratorWorker): + """/Forge Stats Retry (mixed case) triggers the retry handler.""" + message = _make_jira_message("TEST-123", "/Forge Stats Retry") + state = _base_state() + + with patch.object( + worker, "_handle_stats_retry_command", new_callable=AsyncMock + ) as mock_retry: + result = await worker._handle_resume_event(message, state) + + mock_retry.assert_awaited_once_with("TEST-123", state) + assert result is state + + @pytest.mark.asyncio + async def test_retry_returns_state_unchanged(self, worker: OrchestratorWorker): + """/forge stats retry returns current state without modification.""" + message = _make_jira_message("TEST-123", "/forge stats retry") + state = _base_state(current_node="spec_approval_gate", is_paused=True) + + with patch.object(worker, "_handle_stats_retry_command", new_callable=AsyncMock): + result = await worker._handle_resume_event(message, state) + + assert result is state + assert result["current_node"] == "spec_approval_gate" + assert result["is_paused"] is True + + @pytest.mark.asyncio + async def test_base_stats_uses_base_handler(self, worker: OrchestratorWorker): + """Plain /forge stats (no subcommand) uses the base handler, not retry.""" + message = _make_jira_message("TEST-123", "/forge stats") + state = _base_state() + + base_called = [] + retry_called = [] + + with 
( + patch.object( + worker, + "_handle_stats_command", + new_callable=AsyncMock, + side_effect=lambda *_a, **_kw: base_called.append(True), + ), + patch.object( + worker, + "_handle_stats_retry_command", + new_callable=AsyncMock, + side_effect=lambda *_a, **_kw: retry_called.append(True), + ), + ): + result = await worker._handle_resume_event(message, state) + + assert len(base_called) == 1, "Base handler should be called once" + assert len(retry_called) == 0, "Retry handler should NOT be called" + assert result is state + + @pytest.mark.asyncio + async def test_retry_does_not_call_base_handler(self, worker: OrchestratorWorker): + """/forge stats retry does not invoke the base stats handler.""" + message = _make_jira_message("TEST-123", "/forge stats retry") + state = _base_state() + + base_called = [] + retry_called = [] + + with ( + patch.object( + worker, + "_handle_stats_command", + new_callable=AsyncMock, + side_effect=lambda *_a, **_kw: base_called.append(True), + ), + patch.object( + worker, + "_handle_stats_retry_command", + new_callable=AsyncMock, + side_effect=lambda *_a, **_kw: retry_called.append(True), + ), + ): + result = await worker._handle_resume_event(message, state) + + assert len(retry_called) == 1, "Retry handler should be called once" + assert len(base_called) == 0, "Base handler should NOT be called" + assert result is state + + +class TestForgeStatsUnknownSubcommand: + """Tests that unknown /forge stats subcommands are handled gracefully.""" + + @pytest.mark.asyncio + async def test_unknown_subcommand_returns_state_unchanged(self, worker: OrchestratorWorker): + """Unknown /forge stats subcommand returns current state without posting.""" + message = _make_jira_message("TEST-123", "/forge stats unknown-command") + state = _base_state() + + with ( + patch.object(worker, "_handle_stats_command", new_callable=AsyncMock) as mock_base, + patch.object( + worker, "_handle_stats_retry_command", new_callable=AsyncMock + ) as mock_retry, + ): + result = 
await worker._handle_resume_event(message, state) + + # Neither handler should be called for an unknown subcommand + mock_base.assert_not_awaited() + mock_retry.assert_not_awaited() + assert result is state + + @pytest.mark.asyncio + async def test_unknown_subcommand_does_not_post_comment( + self, worker: OrchestratorWorker, mock_jira + ): + """Unknown subcommand does not post any comment to Jira.""" + message = _make_jira_message("TEST-123", "/forge stats foobar") + state = _base_state() + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, state) + + mock_jira.add_comment.assert_not_awaited() + assert result is state + + @pytest.mark.asyncio + async def test_unknown_subcommand_is_informational_not_error( + self, worker: OrchestratorWorker, mock_jira + ): + """Unknown subcommand does not trigger revision request or any workflow change.""" + message = _make_jira_message("TEST-123", "/forge stats bogus") + state = _base_state(is_paused=True, current_node="prd_approval_gate") + + with patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira): + result = await worker._handle_resume_event(message, state) + + # State must be returned unchanged — workflow not resumed + assert result is state + assert result["is_paused"] is True + assert result["current_node"] == "prd_approval_gate" + + +class TestForgeStatsRetryRepostBehavior: + """Tests that /forge stats retry uses the re-post mechanism.""" + + @pytest.mark.asyncio + async def test_retry_calls_ensure_stats_is_final_comment(self, worker: OrchestratorWorker): + """/forge stats retry calls ensure_stats_is_final_comment for re-posting.""" + state = _base_state() + + with patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + return_value=True, + ) as mock_ensure: + await worker._handle_stats_retry_command("TEST-123", state) + + mock_ensure.assert_awaited_once() + # args: (ticket_key, stats, 
outcome, outcome_detail) + assert mock_ensure.await_args.args[0] == "TEST-123" + assert mock_ensure.await_args.args[1] is state + + @pytest.mark.asyncio + async def test_retry_does_not_call_add_comment_directly( + self, worker: OrchestratorWorker, mock_jira + ): + """/forge stats retry does not call JiraClient.add_comment directly.""" + state = _base_state() + + with ( + patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira), + patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + return_value=True, + ), + ): + await worker._handle_stats_retry_command("TEST-123", state) + + # The retry path goes through ensure_stats_is_final_comment, not direct add_comment + mock_jira.add_comment.assert_not_awaited() + + @pytest.mark.asyncio + async def test_retry_passes_correct_outcome_to_ensure(self, worker: OrchestratorWorker): + """Retry derives outcome correctly and passes it to ensure_stats_is_final_comment.""" + state = _base_state(workflow_outcome="Completed") + + with patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + return_value=True, + ) as mock_ensure: + await worker._handle_stats_retry_command("TEST-123", state) + + # args: (ticket_key, stats, outcome, outcome_detail) + assert mock_ensure.await_args.args[0] == "TEST-123" + assert mock_ensure.await_args.args[2] == "Completed" + + @pytest.mark.asyncio + async def test_retry_derives_blocked_outcome(self, worker: OrchestratorWorker): + """Retry correctly derives 'Blocked' outcome when is_blocked=True.""" + state = _base_state(is_blocked=True, workflow_outcome=None) + + with patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + return_value=True, + ) as mock_ensure: + await worker._handle_stats_retry_command("TEST-123", state) + + assert mock_ensure.await_args.args[2] == "Blocked" + + @pytest.mark.asyncio + async def test_retry_derives_failed_outcome(self, worker: 
OrchestratorWorker): + """Retry correctly derives 'Failed' outcome when last_error is set.""" + state = _base_state(last_error="Something went wrong", workflow_outcome=None) + + with patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + return_value=True, + ) as mock_ensure: + await worker._handle_stats_retry_command("TEST-123", state) + + assert mock_ensure.await_args.args[2] == "Failed" + + @pytest.mark.asyncio + async def test_retry_derives_in_progress_outcome(self, worker: OrchestratorWorker): + """Retry uses 'In Progress' outcome for active workflows.""" + state = _base_state(workflow_outcome=None, is_blocked=False, last_error=None) + + with patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + return_value=True, + ) as mock_ensure: + await worker._handle_stats_retry_command("TEST-123", state) + + assert mock_ensure.await_args.args[2] == "In Progress" + + @pytest.mark.asyncio + async def test_retry_passes_outcome_detail(self, worker: OrchestratorWorker): + """Retry passes stats_outcome_reason as outcome_detail.""" + state = _base_state( + workflow_outcome="Blocked", + stats_outcome_reason="Waiting for review", + ) + + with patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + return_value=True, + ) as mock_ensure: + await worker._handle_stats_retry_command("TEST-123", state) + + # args: (ticket_key, stats, outcome, outcome_detail) + assert mock_ensure.await_args.args[3] == "Waiting for review" + + @pytest.mark.asyncio + async def test_retry_uses_last_error_as_detail(self, worker: OrchestratorWorker): + """Retry passes last_error as outcome_detail when no stats_outcome_reason.""" + state = _base_state( + workflow_outcome=None, + last_error="Connection timeout", + stats_outcome_reason=None, + ) + + with patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + return_value=True, + ) as mock_ensure: + 
await worker._handle_stats_retry_command("TEST-123", state) + + # args: (ticket_key, stats, outcome, outcome_detail) + assert mock_ensure.await_args.args[3] == "Connection timeout" + + +class TestForgeStatsRetryNoData: + """Tests for retry behaviour when no stats data is present.""" + + @pytest.mark.asyncio + async def test_retry_with_no_stage_timestamps_posts_no_data( + self, worker: OrchestratorWorker, mock_jira + ): + """/forge stats retry without stage_timestamps posts 'No workflow data found.'.""" + state = { + "ticket_key": "TEST-123", + "current_node": "prd_approval_gate", + "is_paused": True, + "context": {}, + # stage_timestamps key is absent + } + + with ( + patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira), + patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + ) as mock_ensure, + ): + await worker._handle_stats_retry_command("TEST-123", state) + + # Should fall back to the "no data" path before reaching ensure_stats_is_final_comment + mock_ensure.assert_not_awaited() + mock_jira.add_comment.assert_awaited_once() + body = mock_jira.add_comment.await_args.args[1] + assert "No workflow data found" in body + + @pytest.mark.asyncio + async def test_retry_ensure_failure_does_not_raise(self, worker: OrchestratorWorker): + """/forge stats retry failure in ensure_stats_is_final_comment is non-raising.""" + state = _base_state() + + with patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + side_effect=Exception("network error"), + ): + # Should not raise + await worker._handle_stats_retry_command("TEST-123", state) + + +class TestPostStatsCommentHelper: + """Direct unit tests for _post_stats_comment helper.""" + + @pytest.mark.asyncio + async def test_force_repost_true_uses_ensure_stats(self, worker: OrchestratorWorker): + """force_repost=True routes through ensure_stats_is_final_comment.""" + state = _base_state() + + with patch( + 
"forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + return_value=True, + ) as mock_ensure: + await worker._post_stats_comment("TEST-123", state, force_repost=True) + + mock_ensure.assert_awaited_once() + + @pytest.mark.asyncio + async def test_force_repost_false_uses_add_comment(self, worker: OrchestratorWorker, mock_jira): + """force_repost=False uses direct JiraClient.add_comment.""" + state = _base_state() + + with ( + patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira), + patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + ) as mock_ensure, + ): + await worker._post_stats_comment("TEST-123", state, force_repost=False) + + mock_jira.add_comment.assert_awaited_once() + mock_ensure.assert_not_awaited() + + @pytest.mark.asyncio + async def test_force_repost_default_is_false(self, worker: OrchestratorWorker, mock_jira): + """Default force_repost=False uses add_comment (not ensure_stats).""" + state = _base_state() + + with ( + patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira), + patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + ) as mock_ensure, + ): + await worker._post_stats_comment("TEST-123", state) + + mock_jira.add_comment.assert_awaited_once() + mock_ensure.assert_not_awaited() + + @pytest.mark.asyncio + async def test_handle_stats_command_delegates_to_post_helper(self, worker: OrchestratorWorker): + """_handle_stats_command delegates to _post_stats_comment with force_repost=False.""" + state = _base_state() + + with patch.object(worker, "_post_stats_comment", new_callable=AsyncMock) as mock_post: + await worker._handle_stats_command("TEST-123", state) + + mock_post.assert_awaited_once_with("TEST-123", state, force_repost=False) + + @pytest.mark.asyncio + async def test_handle_stats_retry_command_delegates_to_post_helper( + self, worker: OrchestratorWorker + ): + """_handle_stats_retry_command 
delegates to _post_stats_comment with force_repost=True.""" + state = _base_state() + + with patch.object(worker, "_post_stats_comment", new_callable=AsyncMock) as mock_post: + await worker._handle_stats_retry_command("TEST-123", state) + + mock_post.assert_awaited_once_with("TEST-123", state, force_repost=True) + + @pytest.mark.asyncio + async def test_retry_via_full_resume_event_calls_ensure(self, worker: OrchestratorWorker): + """/forge stats retry via _handle_resume_event triggers ensure_stats_is_final_comment.""" + message = _make_jira_message("TEST-123", "/forge stats retry") + state = _base_state() + + with patch( + "forge.orchestrator.worker.ensure_stats_is_final_comment", + new_callable=AsyncMock, + return_value=True, + ) as mock_ensure: + result = await worker._handle_resume_event(message, state) + + mock_ensure.assert_awaited_once() + assert result is state diff --git a/tests/unit/stats/__init__.py b/tests/unit/stats/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/stats/test_cli_formatter.py b/tests/unit/stats/test_cli_formatter.py new file mode 100644 index 00000000..8623822f --- /dev/null +++ b/tests/unit/stats/test_cli_formatter.py @@ -0,0 +1,756 @@ +"""Unit tests for forge.stats.cli_formatter. + +All tests exercise the public API (format_stats_table, format_stats_json) +and the internal helpers without any I/O or external dependencies. 
+""" + +from __future__ import annotations + +import json + +from forge.stats.cli_formatter import ( + _COLOR_GREEN, + _COLOR_RED, + _COLOR_RESET, + _DASH, + _colorize, + _determine_display_stages, + _fmt_seconds, + _fmt_tokens, + _stage_row_values, + _totals_row_values, + _truncate, + format_stats_json, + format_stats_table, +) +from forge.stats.retrieval import WorkflowStats + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + +_TICKET = "AISOS-999" + + +def _make_stage( + *, + stage_name: str = "prd", + iteration_count: int = 1, + machine_time_seconds: float = 60.0, + human_time_seconds: float = 120.0, + input_tokens: int = 1000, + output_tokens: int = 500, + started_at: str | None = "2024-01-01T00:00:00+00:00", + ended_at: str | None = "2024-01-01T00:01:00+00:00", +) -> dict: + return { + "stage_name": stage_name, + "iteration_count": iteration_count, + "machine_time_seconds": machine_time_seconds, + "human_time_seconds": human_time_seconds, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "started_at": started_at, + "ended_at": ended_at, + } + + +def _make_stats(**kwargs) -> WorkflowStats: + """Construct a WorkflowStats with sensible defaults.""" + defaults: dict = { + "ticket_key": _TICKET, + "stages": {}, + "pr_urls": [], + "ci_cycles": 0, + "outcome": None, + "outcome_reason": None, + "comment_posted": False, + "workflow_run_id": "", + } + defaults.update(kwargs) + return WorkflowStats(**defaults) + + +# --------------------------------------------------------------------------- +# _fmt_seconds +# --------------------------------------------------------------------------- + + +class TestFmtSeconds: + def test_seconds_only(self): + assert _fmt_seconds(45.0) == "45s" + + def test_minutes_and_seconds(self): + assert _fmt_seconds(90.0) == "1m 30s" + + def test_hours_minutes_seconds(self): + assert 
_fmt_seconds(3661.0) == "1h 1m 1s" + + def test_zero(self): + assert _fmt_seconds(0.0) == "0s" + + def test_exact_hour(self): + assert _fmt_seconds(3600.0) == "1h 0m 0s" + + def test_truncates_fractional(self): + # fractional seconds are truncated + assert _fmt_seconds(1.9) == "1s" + + +# --------------------------------------------------------------------------- +# _fmt_tokens +# --------------------------------------------------------------------------- + + +class TestFmtTokens: + def test_small_number(self): + assert _fmt_tokens(500) == "500" + + def test_thousands(self): + assert _fmt_tokens(1_000) == "1,000" + + def test_millions(self): + assert _fmt_tokens(1_234_567) == "1,234,567" + + def test_zero(self): + assert _fmt_tokens(0) == "0" + + +# --------------------------------------------------------------------------- +# _truncate +# --------------------------------------------------------------------------- + + +class TestTruncate: + def test_short_string_unchanged(self): + assert _truncate("hello", 10) == "hello" + + def test_exact_length_unchanged(self): + assert _truncate("12345", 5) == "12345" + + def test_long_string_truncated(self): + result = _truncate("abcdefghij", 7) + assert result == "abcd..." + assert len(result) == 7 + + def test_max_len_three_gives_ellipsis(self): + result = _truncate("hello", 3) + assert result == "..." 
+ + +# --------------------------------------------------------------------------- +# _colorize +# --------------------------------------------------------------------------- + + +class TestColorize: + def test_no_color_returns_text(self): + assert _colorize("hello", _COLOR_GREEN, use_color=False) == "hello" + + def test_color_wraps_text(self): + result = _colorize("OK", _COLOR_GREEN, use_color=True) + assert _COLOR_GREEN in result + assert "OK" in result + assert _COLOR_RESET in result + + def test_color_reset_appended(self): + result = _colorize("ERR", _COLOR_RED, use_color=True) + assert result.endswith(_COLOR_RESET) + + +# --------------------------------------------------------------------------- +# _stage_row_values +# --------------------------------------------------------------------------- + + +class TestStageRowValues: + def test_none_stage_returns_dashes(self): + label, itr, mt, ti, to = _stage_row_values("PRD", None) + assert label == "PRD" + assert itr == _DASH + assert mt == _DASH + assert ti == _DASH + assert to == _DASH + + def test_executed_stage_returns_values(self): + stage = _make_stage( + iteration_count=2, + machine_time_seconds=90.0, + human_time_seconds=30.0, + input_tokens=1000, + output_tokens=500, + ) + label, itr, mt, ti, to = _stage_row_values("PRD", stage) + assert label == "PRD" + assert itr == "2" + assert mt == "1m 30s" + assert ti == "1,000" + assert to == "500" + + def test_zero_iteration_count(self): + stage = _make_stage(iteration_count=0) + label, itr, *_ = _stage_row_values("Spec", stage) + assert itr == "0" + + def test_missing_stage_fields_default_to_zero(self): + stage: dict = {} + label, itr, mt, ti, to = _stage_row_values("CI", stage) + assert itr == "0" + assert mt == "0s" + assert ti == "0" + assert to == "0" + + +# --------------------------------------------------------------------------- +# _totals_row_values +# --------------------------------------------------------------------------- + + +class 
TestTotalsRowValues: + def test_empty_stages_gives_zeros(self): + label, itr, mt, ti, to = _totals_row_values({}) + assert label == "TOTAL" + assert itr == "" + assert mt == "0s" + assert ti == "0" + assert to == "0" + + def test_sums_across_stages(self): + stages = { + "prd": _make_stage( + machine_time_seconds=60.0, + human_time_seconds=30.0, + input_tokens=1000, + output_tokens=500, + ), + "spec": _make_stage( + machine_time_seconds=120.0, + human_time_seconds=60.0, + input_tokens=2000, + output_tokens=1000, + ), + } + label, _, mt, ti, to = _totals_row_values(stages) + assert label == "TOTAL" + assert mt == "3m 0s" + assert ti == "3,000" + assert to == "1,500" + + +# --------------------------------------------------------------------------- +# _determine_display_stages +# --------------------------------------------------------------------------- + + +class TestDetermineDisplayStages: + def test_empty_stages_returns_feature_stages(self): + from forge.workflow.stats import ALL_FEATURE_STAGES + + result = _determine_display_stages({}) + assert result == ALL_FEATURE_STAGES + + def test_feature_stages_returns_feature_list(self): + from forge.workflow.stats import ALL_FEATURE_STAGES + + stages = {"prd": {}, "spec": {}} + result = _determine_display_stages(stages) + assert result == ALL_FEATURE_STAGES + + def test_bug_stages_returns_bug_list(self): + from forge.workflow.stats import ALL_BUG_STAGES + + stages = {"triage": {}, "rca": {}} + result = _determine_display_stages(stages) + assert result == ALL_BUG_STAGES + + def test_planning_triggers_bug_list(self): + from forge.workflow.stats import ALL_BUG_STAGES + + stages = {"planning": {}, "implementation": {}} + result = _determine_display_stages(stages) + assert result == ALL_BUG_STAGES + + +# --------------------------------------------------------------------------- +# format_stats_table — basic structure +# --------------------------------------------------------------------------- + + +class 
TestFormatStatsTableBasicStructure: + def test_returns_string(self): + stats = _make_stats() + result = format_stats_table(stats) + assert isinstance(result, str) + + def test_contains_ticket_key(self): + stats = _make_stats() + result = format_stats_table(stats) + assert _TICKET in result + + def test_contains_header_columns(self): + stats = _make_stats() + result = format_stats_table(stats) + assert "Stage" in result + assert "Iterations" in result + assert "Machine Time" in result + assert "Tokens In" in result + assert "Tokens Out" in result + + def test_contains_totals_row(self): + stats = _make_stats() + result = format_stats_table(stats) + assert "TOTAL" in result + + def test_contains_outcome(self): + stats = _make_stats(outcome="Completed") + result = format_stats_table(stats) + assert "Completed" in result + + def test_contains_ci_cycles(self): + stats = _make_stats(ci_cycles=3) + result = format_stats_table(stats) + assert "3" in result + + def test_run_id_included_when_present(self): + stats = _make_stats(workflow_run_id="abc-123-def") + result = format_stats_table(stats) + assert "abc-123-def" in result + + def test_run_id_omitted_when_empty(self): + stats = _make_stats(workflow_run_id="") + result = format_stats_table(stats) + assert "Run ID" not in result + + def test_workflow_statistics_heading(self): + stats = _make_stats() + result = format_stats_table(stats) + assert "Workflow Statistics" in result + + +# --------------------------------------------------------------------------- +# format_stats_table — unexecuted stages +# --------------------------------------------------------------------------- + + +class TestFormatStatsTableUnexecutedStages: + def test_empty_stages_shows_dashes(self): + stats = _make_stats(stages={}) + result = format_stats_table(stats) + # All feature stages should show dash + assert _DASH in result + + def test_feature_stages_with_one_executed(self): + stats = _make_stats(stages={"prd": _make_stage()}) + result = 
format_stats_table(stats) + # PRD shows metrics; other stages show dashes + assert _DASH in result + # PRD row should have "1m 0s" (machine_time_seconds=60) + assert "1m 0s" in result + + def test_dash_present_for_each_unexecuted_stage(self): + """With no executed stages, every feature stage row shows a dash in each metric column.""" + stats = _make_stats(stages={}) + result = format_stats_table(stats) + count = result.count(_DASH) + # 7 feature stages × 4 metric columns = 28 dashes + assert count == 28 + + +# --------------------------------------------------------------------------- +# format_stats_table — stage metrics accuracy +# --------------------------------------------------------------------------- + + +class TestFormatStatsTableMetrics: + def test_iterations_displayed(self): + stage = _make_stage(iteration_count=3) + stats = _make_stats(stages={"prd": stage}) + result = format_stats_table(stats) + assert "3" in result + + def test_machine_time_displayed(self): + stage = _make_stage(machine_time_seconds=3661.0) + stats = _make_stats(stages={"prd": stage}) + result = format_stats_table(stats) + assert "1h 1m 1s" in result + + def test_input_tokens_displayed(self): + stage = _make_stage(input_tokens=1_234_000) + stats = _make_stats(stages={"prd": stage}) + result = format_stats_table(stats) + assert "1,234,000" in result + + def test_output_tokens_displayed(self): + """The output token count of an executed stage is rendered in the table.""" + # Use a value whose rendering cannot also match the ticket key: + # "AISOS-999" contains "999", which made the previous assertion vacuous. + stage = _make_stage(output_tokens=7_654) + stats = _make_stats(stages={"prd": stage}) + result = format_stats_table(stats) + assert "7,654" in result + + +# --------------------------------------------------------------------------- +# format_stats_table — summary totals +# --------------------------------------------------------------------------- + + +class TestFormatStatsTableTotals: + def test_totals_row_sums_tokens(self): + stages = { + "prd": _make_stage(input_tokens=1000, output_tokens=500), + "spec": _make_stage(input_tokens=2000, output_tokens=1000), + } + stats = _make_stats(stages=stages) + result = 
format_stats_table(stats) + # Total input = 3,000; total output = 1,500 + assert "3,000" in result + assert "1,500" in result + + def test_totals_row_label(self): + stats = _make_stats() + result = format_stats_table(stats) + assert "TOTAL" in result + + +# --------------------------------------------------------------------------- +# format_stats_table — PR links +# --------------------------------------------------------------------------- + + +class TestFormatStatsTablePrLinks: + def test_pr_links_included_when_present(self): + pr_url = "https://github.com/org/repo/pull/42" + stats = _make_stats(pr_urls=[pr_url]) + result = format_stats_table(stats) + assert pr_url in result + assert "Pull Requests" in result + + def test_pr_links_omitted_when_empty(self): + stats = _make_stats(pr_urls=[]) + result = format_stats_table(stats) + assert "Pull Requests" not in result + + def test_multiple_pr_links(self): + urls = [ + "https://github.com/org/repo/pull/1", + "https://github.com/org/repo/pull/2", + ] + stats = _make_stats(pr_urls=urls) + result = format_stats_table(stats) + for url in urls: + assert url in result + + +# --------------------------------------------------------------------------- +# format_stats_table — metadata +# --------------------------------------------------------------------------- + + +class TestFormatStatsTableMetadata: + def test_started_from_earliest_stage(self): + stages = { + "prd": _make_stage(started_at="2024-01-01T01:00:00+00:00"), + "spec": _make_stage(started_at="2024-01-01T00:00:00+00:00"), + } + stats = _make_stats(stages=stages) + result = format_stats_table(stats) + # Earliest started_at should appear as "Started" + assert "2024-01-01T00:00:00+00:00" in result + + def test_last_updated_from_latest_ended(self): + stages = { + "prd": _make_stage(ended_at="2024-01-01T01:00:00+00:00"), + "spec": _make_stage(ended_at="2024-01-01T02:00:00+00:00"), + } + stats = _make_stats(stages=stages) + result = format_stats_table(stats) + assert 
"2024-01-01T02:00:00+00:00" in result + + def test_started_omitted_when_no_stages(self): + stats = _make_stats(stages={}) + result = format_stats_table(stats) + assert "Started" not in result + + def test_outcome_reason_included(self): + stats = _make_stats(outcome="Blocked", outcome_reason="Waiting for approval") + result = format_stats_table(stats) + assert "Waiting for approval" in result + + def test_outcome_reason_omitted_when_none(self): + stats = _make_stats(outcome="Completed", outcome_reason=None) + result = format_stats_table(stats) + assert "Reason" not in result + + def test_outcome_reason_truncated(self): + """A very long outcome reason is shortened with an ellipsis on the Reason line.""" + long_reason = "X" * 200 + stats = _make_stats(outcome="Failed", outcome_reason=long_reason) + result = format_stats_table(stats) + assert "..." in result + # The full 200-char reason must not survive anywhere in the output. The length + # bound below can also pass for an untruncated reason whenever the label and + # padding take fewer than 20 characters, so it is not sufficient on its own. + assert long_reason not in result + # Reason line should exist and be truncated + reason_line = next((line for line in result.splitlines() if "Reason" in line), None) + assert reason_line is not None, "Reason line missing from table output" + assert len(reason_line) < 200 + 20 # padded with label + + +# --------------------------------------------------------------------------- +# format_stats_table — outcome display +# --------------------------------------------------------------------------- + + +class TestFormatStatsTableOutcome: + def test_in_progress_when_outcome_none(self): + stats = _make_stats(outcome=None) + result = format_stats_table(stats) + assert "In Progress" in result + + def test_completed_outcome(self): + stats = _make_stats(outcome="Completed") + result = format_stats_table(stats) + assert "Completed" in result + + def test_failed_outcome(self): + stats = _make_stats(outcome="Failed: some error") + result = format_stats_table(stats) + assert "Failed" in result + + def test_blocked_outcome(self): + stats = _make_stats(outcome="Blocked") + result = format_stats_table(stats) + assert "Blocked" in result + + +# --------------------------------------------------------------------------- +# format_stats_table — color support +# 
--------------------------------------------------------------------------- + + +class TestFormatStatsTableColor: + def test_no_color_by_default(self): + stats = _make_stats(outcome="Completed") + result = format_stats_table(stats) + assert "\033[" not in result + + def test_color_completed_green(self): + stats = _make_stats(outcome="Completed") + result = format_stats_table(stats, use_color=True) + assert _COLOR_GREEN in result + + def test_color_failed_red(self): + stats = _make_stats(outcome="Failed: err") + result = format_stats_table(stats, use_color=True) + assert _COLOR_RED in result + + def test_color_reset_present(self): + stats = _make_stats(outcome="Completed") + result = format_stats_table(stats, use_color=True) + assert _COLOR_RESET in result + + +# --------------------------------------------------------------------------- +# format_stats_table — bug workflow stages +# --------------------------------------------------------------------------- + + +class TestFormatStatsTableBugWorkflow: + def test_bug_stages_displayed(self): + stages = { + "triage": _make_stage(stage_name="triage"), + "rca": _make_stage(stage_name="rca"), + } + stats = _make_stats(stages=stages) + result = format_stats_table(stats) + assert "Triage" in result + assert "RCA" in result + # Bug-specific stages + assert "Planning" in result # unexecuted but in bug list + + def test_bug_workflow_does_not_show_prd(self): + """Bug workflows should not display PRD/Spec/Epics/Tasks stages.""" + stages = {"triage": _make_stage(stage_name="triage")} + stats = _make_stats(stages=stages) + result = format_stats_table(stats) + assert "PRD" not in result + assert "Epics" not in result + + +# --------------------------------------------------------------------------- +# format_stats_table — column width truncation +# --------------------------------------------------------------------------- + + +class TestFormatStatsTableColumnWidth: + def test_long_values_truncated(self): + """Very long values 
should be truncated to max_col_width.""" + stage = _make_stage(stage_name="implementation" * 5) # absurdly long + stats = _make_stats(stages={"implementation": stage}) + result = format_stats_table(stats, max_col_width=10) + # No single cell should exceed the max width significantly + for line in result.splitlines(): + if "|" in line: + # Each cell within pipes should respect max width (with ...suffix) + parts = [p.strip() for p in line.strip("|").split("|")] + for part in parts: + assert len(part) <= 10 + 5 # allow some padding tolerance + + +# --------------------------------------------------------------------------- +# format_stats_json — basic validity +# --------------------------------------------------------------------------- + + +class TestFormatStatsJsonBasicValidity: + def test_returns_string(self): + stats = _make_stats() + result = format_stats_json(stats) + assert isinstance(result, str) + + def test_valid_json(self): + stats = _make_stats() + result = format_stats_json(stats) + parsed = json.loads(result) + assert isinstance(parsed, dict) + + def test_pretty_printed(self): + stats = _make_stats() + result = format_stats_json(stats) + # Pretty-printed JSON contains newlines and indentation + assert "\n" in result + assert " " in result + + +# --------------------------------------------------------------------------- +# format_stats_json — field presence and typing +# --------------------------------------------------------------------------- + + +class TestFormatStatsJsonFields: + def setup_method(self): + stage = _make_stage( + stage_name="prd", + iteration_count=2, + machine_time_seconds=90.0, + human_time_seconds=30.0, + input_tokens=1000, + output_tokens=500, + started_at="2024-01-01T00:00:00+00:00", + ended_at="2024-01-01T01:00:00+00:00", + ) + self.stats = _make_stats( + stages={"prd": stage}, + pr_urls=["https://github.com/org/repo/pull/1"], + ci_cycles=2, + outcome="Completed", + outcome_reason=None, + comment_posted=True, + 
workflow_run_id="abc-123", + ) + self.parsed = json.loads(format_stats_json(self.stats)) + + def test_ticket_key_field(self): + assert self.parsed["ticket_key"] == _TICKET + + def test_outcome_field(self): + assert self.parsed["outcome"] == "Completed" + + def test_outcome_reason_field(self): + assert self.parsed["outcome_reason"] is None + + def test_ci_cycles_field(self): + assert self.parsed["ci_cycles"] == 2 + + def test_comment_posted_field(self): + assert self.parsed["comment_posted"] is True + + def test_workflow_run_id_field(self): + assert self.parsed["workflow_run_id"] == "abc-123" + + def test_pr_urls_field(self): + assert self.parsed["pr_urls"] == ["https://github.com/org/repo/pull/1"] + + def test_stages_field_present(self): + assert "stages" in self.parsed + + def test_stage_has_all_fields(self): + prd = self.parsed["stages"]["prd"] + assert "stage_name" in prd + assert "iteration_count" in prd + assert "machine_time_seconds" in prd + assert "input_tokens" in prd + assert "output_tokens" in prd + assert "started_at" in prd + assert "ended_at" in prd + + def test_stage_field_types(self): + prd = self.parsed["stages"]["prd"] + assert isinstance(prd["stage_name"], str) + assert isinstance(prd["iteration_count"], int) + assert isinstance(prd["machine_time_seconds"], float) + assert isinstance(prd["input_tokens"], int) + assert isinstance(prd["output_tokens"], int) + assert isinstance(prd["started_at"], str) + assert prd["ended_at"] is not None + + def test_stage_name_value(self): + assert self.parsed["stages"]["prd"]["stage_name"] == "prd" + + def test_stage_metrics_values(self): + prd = self.parsed["stages"]["prd"] + assert prd["iteration_count"] == 2 + assert prd["input_tokens"] == 1000 + assert prd["output_tokens"] == 500 + + +# --------------------------------------------------------------------------- +# format_stats_json — edge cases +# --------------------------------------------------------------------------- + + +class 
TestFormatStatsJsonEdgeCases: + def test_empty_stages(self): + stats = _make_stats(stages={}) + parsed = json.loads(format_stats_json(stats)) + assert parsed["stages"] == {} + + def test_none_outcome(self): + stats = _make_stats(outcome=None) + parsed = json.loads(format_stats_json(stats)) + assert parsed["outcome"] is None + + def test_empty_pr_urls(self): + stats = _make_stats(pr_urls=[]) + parsed = json.loads(format_stats_json(stats)) + assert parsed["pr_urls"] == [] + + def test_multiple_stages(self): + stages = { + "prd": _make_stage(stage_name="prd"), + "spec": _make_stage(stage_name="spec"), + } + stats = _make_stats(stages=stages) + parsed = json.loads(format_stats_json(stats)) + assert set(parsed["stages"].keys()) == {"prd", "spec"} + + def test_sorted_keys(self): + stats = _make_stats( + stages={"prd": _make_stage()}, + pr_urls=["https://example.com"], + ci_cycles=1, + outcome="Completed", + ) + result = format_stats_json(stats) + parsed_keys = list(json.loads(result).keys()) + assert parsed_keys == sorted(parsed_keys) + + def test_started_at_none_serialized(self): + stage = _make_stage(started_at=None, ended_at=None) + stats = _make_stats(stages={"prd": stage}) + parsed = json.loads(format_stats_json(stats)) + assert parsed["stages"]["prd"]["started_at"] is None + assert parsed["stages"]["prd"]["ended_at"] is None + + def test_missing_stage_fields_use_defaults(self): + """Stages with missing fields should use zero/None defaults.""" + stats = _make_stats(stages={"prd": {}}) + parsed = json.loads(format_stats_json(stats)) + prd = parsed["stages"]["prd"] + assert prd["iteration_count"] == 0 + assert prd["machine_time_seconds"] == 0.0 + assert prd["input_tokens"] == 0 + assert prd["started_at"] is None diff --git a/tests/unit/stats/test_notifications.py b/tests/unit/stats/test_notifications.py new file mode 100644 index 00000000..af818910 --- /dev/null +++ b/tests/unit/stats/test_notifications.py @@ -0,0 +1,645 @@ +"""Unit tests for 
forge.workflow.stats.notifications. + +All Jira API calls are mocked; no real HTTP connections are made. +""" + +from __future__ import annotations + +import argparse +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.workflow.stats.notifications import ( + _format_mention, + _parse_account_ids, + get_notification_recipients, + notify_report_ready, +) + +# --------------------------------------------------------------------------- +# Tests for _format_mention +# --------------------------------------------------------------------------- + + +class TestFormatMention: + """Tests for the _format_mention() helper.""" + + def test_basic_account_id(self): + """Account ID is wrapped in Jira mention syntax.""" + assert _format_mention("abc123") == "[~accountid:abc123]" + + def test_long_account_id(self): + """Longer account IDs (real Jira IDs) are formatted correctly.""" + long_id = "5e7e3b1a8c9d2f0b4a6e8c12" + assert _format_mention(long_id) == f"[~accountid:{long_id}]" + + def test_alphanumeric_account_id(self): + """Alphanumeric account IDs are formatted correctly.""" + assert _format_mention("user-id-456") == "[~accountid:user-id-456]" + + def test_format_produces_valid_jira_syntax(self): + """The output should start with [~accountid: and end with ].""" + result = _format_mention("someuser") + assert result.startswith("[~accountid:") + assert result.endswith("]") + + def test_empty_string(self): + """Empty string is formatted (caller is responsible for filtering).""" + assert _format_mention("") == "[~accountid:]" + + +# --------------------------------------------------------------------------- +# Tests for _parse_account_ids +# --------------------------------------------------------------------------- + + +class TestParseAccountIds: + """Tests for the _parse_account_ids() helper.""" + + def test_list_of_strings(self): + """A list of strings is returned as-is (stripped).""" + assert _parse_account_ids(["abc", "def", "ghi"]) == 
["abc", "def", "ghi"] + + def test_list_with_whitespace(self): + """Items with leading/trailing whitespace are stripped.""" + assert _parse_account_ids([" abc ", " def"]) == ["abc", "def"] + + def test_comma_separated_string(self): + """Comma-separated string is split into individual IDs.""" + assert _parse_account_ids("abc,def,ghi") == ["abc", "def", "ghi"] + + def test_comma_separated_with_spaces(self): + """Spaces around commas are stripped.""" + assert _parse_account_ids("abc, def , ghi") == ["abc", "def", "ghi"] + + def test_single_string(self): + """A single account ID string (no commas) is returned as a one-item list.""" + assert _parse_account_ids("abc123") == ["abc123"] + + def test_empty_string(self): + """Empty string returns empty list.""" + assert _parse_account_ids("") == [] + + def test_empty_list(self): + """Empty list returns empty list.""" + assert _parse_account_ids([]) == [] + + def test_list_with_empty_entries(self): + """Empty strings in a list are filtered out.""" + assert _parse_account_ids(["abc", "", "def"]) == ["abc", "def"] + + def test_comma_string_with_empty_parts(self): + """Consecutive commas produce empty parts that are filtered out.""" + assert _parse_account_ids("abc,,def") == ["abc", "def"] + + def test_deduplication(self): + """Duplicate IDs are removed, first occurrence wins.""" + assert _parse_account_ids(["abc", "def", "abc"]) == ["abc", "def"] + + def test_deduplication_in_string(self): + """Duplicate IDs in comma-separated string are deduplicated.""" + assert _parse_account_ids("abc,def,abc") == ["abc", "def"] + + def test_unsupported_type(self): + """Non-string, non-list input returns empty list.""" + assert _parse_account_ids(None) == [] # type: ignore[arg-type] + assert _parse_account_ids(42) == [] # type: ignore[arg-type] + assert _parse_account_ids({}) == [] # type: ignore[arg-type] + + def test_list_of_non_strings(self): + """Non-string items in list are coerced to strings.""" + result = _parse_account_ids([123, 456]) 
+ assert result == ["123", "456"] + + +# --------------------------------------------------------------------------- +# Tests for get_notification_recipients +# --------------------------------------------------------------------------- + + +class TestGetNotificationRecipients: + """Tests for the async get_notification_recipients() function.""" + + @pytest.mark.asyncio + async def test_project_property_takes_precedence(self): + """Project property overrides env var when both are set.""" + mock_jira = MagicMock() + mock_jira.get_project_property = AsyncMock(return_value=["prop_user1", "prop_user2"]) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(weekly_report_notify="env_user1,env_user2"), + ), + ): + result = await get_notification_recipients("PROJ") + + assert result == ["prop_user1", "prop_user2"] + + @pytest.mark.asyncio + async def test_falls_back_to_env_var_when_no_property(self): + """Env var is used when the project property is not set.""" + mock_jira = MagicMock() + mock_jira.get_project_property = AsyncMock(return_value=None) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(weekly_report_notify="env_user1,env_user2"), + ), + ): + result = await get_notification_recipients("PROJ") + + assert result == ["env_user1", "env_user2"] + + @pytest.mark.asyncio + async def test_empty_when_no_config(self): + """Returns empty list when no env var and no project property.""" + mock_jira = MagicMock() + mock_jira.get_project_property = AsyncMock(return_value=None) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", 
+ return_value=MagicMock(weekly_report_notify=""), + ), + ): + result = await get_notification_recipients("PROJ") + + assert result == [] + + @pytest.mark.asyncio + async def test_project_leads_sentinel_with_no_property(self): + """'project-leads' sentinel returns empty list when property is absent.""" + mock_jira = MagicMock() + mock_jira.get_project_property = AsyncMock(return_value=None) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(weekly_report_notify="project-leads"), + ), + ): + result = await get_notification_recipients("PROJ") + + assert result == [] + + @pytest.mark.asyncio + async def test_project_property_as_string(self): + """Project property value as comma-separated string is parsed correctly.""" + mock_jira = MagicMock() + mock_jira.get_project_property = AsyncMock(return_value="user1,user2") + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(weekly_report_notify=""), + ), + ): + result = await get_notification_recipients("PROJ") + + assert result == ["user1", "user2"] + + @pytest.mark.asyncio + async def test_project_property_error_falls_back_to_env(self): + """When the project property lookup fails, env var is used.""" + mock_jira = MagicMock() + mock_jira.get_project_property = AsyncMock(side_effect=Exception("Network error")) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(weekly_report_notify="fallback_user"), + ), + ): + result = await get_notification_recipients("PROJ") + + assert result == ["fallback_user"] + + @pytest.mark.asyncio + async def 
test_jira_client_is_closed_after_property_lookup(self): + """The JiraClient is always closed after the project property lookup.""" + mock_jira = MagicMock() + mock_jira.get_project_property = AsyncMock(return_value=None) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(weekly_report_notify=""), + ), + ): + await get_notification_recipients("PROJ") + + mock_jira.close.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# Tests for notify_report_ready +# --------------------------------------------------------------------------- + + +class TestNotifyReportReady: + """Tests for the async notify_report_ready() function.""" + + @pytest.mark.asyncio + async def test_posts_comment_with_mentions(self): + """A comment containing mentions is posted to the ticket.""" + from forge.integrations.jira.models import JiraComment + + mock_jira = MagicMock() + mock_jira.add_comment = AsyncMock( + return_value=JiraComment( + id="10001", + author_id="forge-bot", + author_name="Forge", + body="test", + ) + ) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(jira_base_url="https://example.atlassian.net"), + ), + ): + await notify_report_ready("PROJ-42", ["user1", "user2"]) + + mock_jira.add_comment.assert_awaited_once() + call_args = mock_jira.add_comment.call_args + assert call_args[0][0] == "PROJ-42" + comment_body = call_args[0][1] + assert "[~accountid:user1]" in comment_body + assert "[~accountid:user2]" in comment_body + + @pytest.mark.asyncio + async def test_comment_includes_ticket_link(self): + """The notification comment contains a link to the report ticket.""" + from forge.integrations.jira.models import 
JiraComment + + mock_jira = MagicMock() + mock_jira.add_comment = AsyncMock( + return_value=JiraComment( + id="10001", + author_id="forge-bot", + author_name="Forge", + body="test", + ) + ) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(jira_base_url="https://example.atlassian.net"), + ), + ): + await notify_report_ready( + "PROJ-42", + ["user1"], + jira_base_url="https://example.atlassian.net", + ) + + comment_body = mock_jira.add_comment.call_args[0][1] + assert "PROJ-42" in comment_body + assert "https://example.atlassian.net/browse/PROJ-42" in comment_body + + @pytest.mark.asyncio + async def test_no_comment_when_recipients_empty(self): + """No comment is posted when the recipients list is empty.""" + mock_jira = MagicMock() + mock_jira.add_comment = AsyncMock() + mock_jira.close = AsyncMock() + + with patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira): + await notify_report_ready("PROJ-42", []) + + mock_jira.add_comment.assert_not_awaited() + + @pytest.mark.asyncio + async def test_skips_invalid_account_ids_with_spaces(self): + """Account IDs containing spaces are skipped with a warning.""" + from forge.integrations.jira.models import JiraComment + + mock_jira = MagicMock() + mock_jira.add_comment = AsyncMock( + return_value=JiraComment( + id="10001", + author_id="forge-bot", + author_name="Forge", + body="test", + ) + ) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(jira_base_url="https://example.atlassian.net"), + ), + ): + await notify_report_ready("PROJ-42", ["valid_user", "bad user"]) + + comment_body = mock_jira.add_comment.call_args[0][1] + assert "[~accountid:valid_user]" in comment_body + assert "bad 
user" not in comment_body + + @pytest.mark.asyncio + async def test_skips_account_ids_with_commas(self): + """Account IDs containing commas are treated as malformed and skipped.""" + from forge.integrations.jira.models import JiraComment + + mock_jira = MagicMock() + mock_jira.add_comment = AsyncMock( + return_value=JiraComment( + id="10001", + author_id="forge-bot", + author_name="Forge", + body="test", + ) + ) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(jira_base_url="https://example.atlassian.net"), + ), + ): + await notify_report_ready("PROJ-42", ["valid_user", "bad,user"]) + + comment_body = mock_jira.add_comment.call_args[0][1] + assert "[~accountid:valid_user]" in comment_body + assert "bad,user" not in comment_body + + @pytest.mark.asyncio + async def test_no_comment_when_all_recipients_invalid(self): + """No comment is posted when all recipients are invalid.""" + mock_jira = MagicMock() + mock_jira.add_comment = AsyncMock() + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(jira_base_url="https://example.atlassian.net"), + ), + ): + await notify_report_ready("PROJ-42", ["bad user", "also,bad"]) + + mock_jira.add_comment.assert_not_awaited() + + @pytest.mark.asyncio + async def test_jira_client_closed_on_success(self): + """JiraClient.close() is called after a successful comment post.""" + from forge.integrations.jira.models import JiraComment + + mock_jira = MagicMock() + mock_jira.add_comment = AsyncMock( + return_value=JiraComment( + id="10001", + author_id="forge-bot", + author_name="Forge", + body="test", + ) + ) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", 
return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(jira_base_url="https://example.atlassian.net"), + ), + ): + await notify_report_ready("PROJ-42", ["user1"]) + + mock_jira.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_jira_client_closed_on_error(self): + """JiraClient.close() is called even when add_comment raises.""" + mock_jira = MagicMock() + mock_jira.add_comment = AsyncMock(side_effect=Exception("API error")) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(jira_base_url="https://example.atlassian.net"), + ), + pytest.raises(Exception, match="API error"), + ): + await notify_report_ready("PROJ-42", ["user1"]) + + mock_jira.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_uses_jira_base_url_override(self): + """jira_base_url parameter overrides the settings value.""" + from forge.integrations.jira.models import JiraComment + + mock_jira = MagicMock() + mock_jira.add_comment = AsyncMock( + return_value=JiraComment( + id="10001", + author_id="forge-bot", + author_name="Forge", + body="test", + ) + ) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(jira_base_url="https://wrong.atlassian.net"), + ), + ): + await notify_report_ready( + "PROJ-1", + ["user1"], + jira_base_url="https://correct.atlassian.net", + ) + + comment_body = mock_jira.add_comment.call_args[0][1] + assert "https://correct.atlassian.net/browse/PROJ-1" in comment_body + assert "wrong" not in comment_body + + @pytest.mark.asyncio + async def test_trailing_slash_stripped_from_base_url(self): + """Trailing slashes in jira_base_url are stripped before building the link.""" + 
from forge.integrations.jira.models import JiraComment + + mock_jira = MagicMock() + mock_jira.add_comment = AsyncMock( + return_value=JiraComment( + id="10001", + author_id="forge-bot", + author_name="Forge", + body="test", + ) + ) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.notifications.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.notifications.get_settings", + return_value=MagicMock(jira_base_url="https://example.atlassian.net/"), + ), + ): + await notify_report_ready("PROJ-5", ["user1"]) + + comment_body = mock_jira.add_comment.call_args[0][1] + # Should not have double slash + assert "//browse" not in comment_body + assert "https://example.atlassian.net/browse/PROJ-5" in comment_body + + +# --------------------------------------------------------------------------- +# Tests for CLI --notify integration +# --------------------------------------------------------------------------- + + +class TestCLINotifyFlag: + """Tests for the --notify flag in cmd_weekly_report.""" + + def _make_args(self, **kwargs) -> argparse.Namespace: + defaults = { + "project": "PROJ", + "days": 7, + "output": None, + "format": "text", + "create_ticket": False, + "notify": False, + } + defaults.update(kwargs) + return argparse.Namespace(**defaults) + + @pytest.mark.asyncio + async def test_notify_without_create_ticket_returns_error(self): + """--notify without --create-ticket returns exit code 1.""" + from forge.cli import cmd_weekly_report + from forge.workflow.stats.weekly_report import ( + TicketSummary, + WeeklyReportData, + ) + + report = WeeklyReportData( + project="PROJ", + period_days=7, + report_start="2024-01-01T00:00:00+00:00", + report_end="2024-01-08T00:00:00+00:00", + completed_tickets=[TicketSummary(ticket_key="PROJ-1", status="completed")], + ) + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new_callable=AsyncMock, + return_value=report, + ): + args = self._make_args(notify=True, 
create_ticket=False) + result = await cmd_weekly_report(args) + + assert result == 1 + + @pytest.mark.asyncio + async def test_notify_sends_notification_when_create_ticket_succeeds(self): + """--notify posts a notification after successfully creating the ticket.""" + from forge.cli import cmd_weekly_report + from forge.workflow.stats.weekly_report import ( + TicketSummary, + WeeklyReportData, + ) + + report = WeeklyReportData( + project="PROJ", + period_days=7, + report_start="2024-01-01T00:00:00+00:00", + report_end="2024-01-08T00:00:00+00:00", + completed_tickets=[TicketSummary(ticket_key="PROJ-1", status="completed")], + ) + + with ( + patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new_callable=AsyncMock, + return_value=report, + ), + patch( + "forge.workflow.stats.report_ticket.ensure_report_ticket", + new_callable=AsyncMock, + return_value="PROJ-99", + ), + patch( + "forge.workflow.stats.notifications.get_notification_recipients", + new_callable=AsyncMock, + return_value=["user1"], + ), + patch( + "forge.workflow.stats.notifications.notify_report_ready", + new_callable=AsyncMock, + ) as mock_notify, + ): + args = self._make_args(notify=True, create_ticket=True) + result = await cmd_weekly_report(args) + + assert result == 0 + mock_notify.assert_awaited_once_with("PROJ-99", ["user1"]) + + @pytest.mark.asyncio + async def test_no_notification_when_notify_flag_not_set(self): + """Without --notify, no notification functions are called.""" + from forge.cli import cmd_weekly_report + from forge.workflow.stats.weekly_report import ( + TicketSummary, + WeeklyReportData, + ) + + report = WeeklyReportData( + project="PROJ", + period_days=7, + report_start="2024-01-01T00:00:00+00:00", + report_end="2024-01-08T00:00:00+00:00", + completed_tickets=[TicketSummary(ticket_key="PROJ-1", status="completed")], + ) + + with ( + patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new_callable=AsyncMock, + return_value=report, + ), + patch( + 
"forge.workflow.stats.notifications.notify_report_ready", + new_callable=AsyncMock, + ) as mock_notify, + ): + args = self._make_args(notify=False, create_ticket=False) + result = await cmd_weekly_report(args) + + mock_notify.assert_not_awaited() + assert result == 0 diff --git a/tests/unit/stats/test_retrieval.py b/tests/unit/stats/test_retrieval.py new file mode 100644 index 00000000..9a956a8f --- /dev/null +++ b/tests/unit/stats/test_retrieval.py @@ -0,0 +1,576 @@ +"""Unit tests for forge.stats.retrieval. + +All checkpoint access is mocked; no Redis or LangGraph connections are +made. Tests cover the public API (get_workflow_stats and +get_workflow_stats_or_error) as well as the internal _extract_stats helper. +""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from forge.stats.retrieval import ( + WorkflowStats, + _extract_stats, + get_workflow_stats, + get_workflow_stats_or_error, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + +_TICKET = "AISOS-123" + + +def _make_stage( + *, + stage_name: str = "prd", + iteration_count: int = 1, + machine_time_seconds: float = 60.0, + human_time_seconds: float = 0.0, + input_tokens: int = 1000, + output_tokens: int = 500, + started_at: str | None = "2024-01-01T00:00:00+00:00", + ended_at: str | None = "2024-01-01T00:01:00+00:00", +) -> dict: + return { + "stage_name": stage_name, + "iteration_count": iteration_count, + "machine_time_seconds": machine_time_seconds, + "human_time_seconds": human_time_seconds, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "started_at": started_at, + "ended_at": ended_at, + } + + +def _full_state(**overrides) -> dict: + """Return a well-formed checkpoint state dict with stats fields.""" + base: dict = { + "ticket_key": _TICKET, + "ticket_type": "Feature", + "current_node": "prd_approval_gate", + "is_paused": False, + 
"is_blocked": False, + "last_error": None, + "feedback_comment": None, + "context": {}, + "stage_timestamps": { + "prd": _make_stage(stage_name="prd"), + }, + "stats_pr_urls": ["https://github.com/org/repo/pull/1"], + "stats_ci_cycles": 2, + "workflow_outcome": "Completed", + "stats_outcome_reason": None, + "stats_comment_posted": True, + "workflow_run_id": "abc-123", + } + base.update(overrides) + return base + + +def _patch_checkpoint(return_value): + """Patch get_checkpoint_state in the retrieval module.""" + return patch( + "forge.stats.retrieval.get_checkpoint_state", + new=AsyncMock(return_value=return_value), + ) + + +# --------------------------------------------------------------------------- +# WorkflowStats dataclass +# --------------------------------------------------------------------------- + + +class TestWorkflowStatsDataclass: + """Tests for the WorkflowStats dataclass itself.""" + + def test_default_construction(self): + """WorkflowStats can be constructed with only ticket_key.""" + ws = WorkflowStats(ticket_key=_TICKET) + assert ws.ticket_key == _TICKET + assert ws.stages == {} + assert ws.pr_urls == [] + assert ws.ci_cycles == 0 + assert ws.outcome is None + assert ws.outcome_reason is None + assert ws.comment_posted is False + assert ws.workflow_run_id == "" + + def test_full_construction(self): + """WorkflowStats accepts all fields.""" + stage = _make_stage() + ws = WorkflowStats( + ticket_key=_TICKET, + stages={"prd": stage}, + pr_urls=["https://github.com/org/repo/pull/1"], + ci_cycles=3, + outcome="Completed", + outcome_reason=None, + comment_posted=True, + workflow_run_id="uuid-xyz", + ) + assert ws.stages == {"prd": stage} + assert ws.pr_urls == ["https://github.com/org/repo/pull/1"] + assert ws.ci_cycles == 3 + assert ws.outcome == "Completed" + assert ws.comment_posted is True + assert ws.workflow_run_id == "uuid-xyz" + + def test_stages_default_is_independent_per_instance(self): + """Each WorkflowStats instance gets its own stages dict 
(not shared).""" + ws1 = WorkflowStats(ticket_key="AISOS-1") + ws2 = WorkflowStats(ticket_key="AISOS-2") + ws1.stages["prd"] = _make_stage() + assert "prd" not in ws2.stages + + def test_pr_urls_default_is_independent_per_instance(self): + """Each WorkflowStats instance gets its own pr_urls list (not shared).""" + ws1 = WorkflowStats(ticket_key="AISOS-1") + ws2 = WorkflowStats(ticket_key="AISOS-2") + ws1.pr_urls.append("https://example.com") + assert ws2.pr_urls == [] + + +# --------------------------------------------------------------------------- +# _extract_stats internal helper +# --------------------------------------------------------------------------- + + +class TestExtractStats: + """Tests for the _extract_stats helper.""" + + def test_returns_none_when_stage_timestamps_absent(self): + """Returns None when stage_timestamps key is missing (legacy workflow).""" + state = { + "ticket_key": _TICKET, + "ticket_type": "Feature", + "current_node": "prd_generation", + } + result = _extract_stats(_TICKET, state) + assert result is None + + def test_returns_workflow_stats_with_stages_present(self): + """Returns WorkflowStats when stage_timestamps key is present.""" + state = _full_state() + result = _extract_stats(_TICKET, state) + assert result is not None + assert isinstance(result, WorkflowStats) + + def test_ticket_key_is_passed_through(self): + """The ticket_key from the argument is stored on the result.""" + state = _full_state() + result = _extract_stats("MYPROJ-999", state) + assert result is not None + assert result.ticket_key == "MYPROJ-999" + + def test_stages_are_extracted(self): + """stages dict contains the stages from the checkpoint.""" + stage = _make_stage(stage_name="prd") + state = _full_state(stage_timestamps={"prd": stage}) + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.stages == {"prd": stage} + + def test_empty_stages_dict_is_valid(self): + """An empty stage_timestamps dict is returned as an empty stages 
dict.""" + state = _full_state(stage_timestamps={}) + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.stages == {} + + def test_pr_urls_are_extracted(self): + """pr_urls are extracted from stats_pr_urls.""" + urls = ["https://github.com/org/repo/pull/1", "https://github.com/org/repo/pull/2"] + state = _full_state(stats_pr_urls=urls) + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.pr_urls == urls + + def test_missing_pr_urls_defaults_to_empty_list(self): + """Missing stats_pr_urls key yields an empty pr_urls list.""" + state = _full_state() + del state["stats_pr_urls"] + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.pr_urls == [] + + def test_null_pr_urls_defaults_to_empty_list(self): + """stats_pr_urls=None is treated as empty list.""" + state = _full_state(stats_pr_urls=None) + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.pr_urls == [] + + def test_ci_cycles_extracted(self): + """ci_cycles is extracted from stats_ci_cycles.""" + state = _full_state(stats_ci_cycles=5) + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.ci_cycles == 5 + + def test_missing_ci_cycles_defaults_to_zero(self): + """Missing stats_ci_cycles yields ci_cycles=0.""" + state = _full_state() + del state["stats_ci_cycles"] + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.ci_cycles == 0 + + def test_null_ci_cycles_defaults_to_zero(self): + """stats_ci_cycles=None yields ci_cycles=0.""" + state = _full_state(stats_ci_cycles=None) + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.ci_cycles == 0 + + def test_outcome_extracted(self): + """outcome is extracted from workflow_outcome.""" + state = _full_state(workflow_outcome="Completed") + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.outcome == "Completed" + + 
def test_outcome_none_when_missing(self): + """Missing workflow_outcome yields outcome=None.""" + state = _full_state() + del state["workflow_outcome"] + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.outcome is None + + def test_outcome_reason_extracted(self): + """outcome_reason is extracted from stats_outcome_reason.""" + state = _full_state(stats_outcome_reason="Deployment gate failed") + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.outcome_reason == "Deployment gate failed" + + def test_comment_posted_true(self): + """comment_posted is True when stats_comment_posted=True.""" + state = _full_state(stats_comment_posted=True) + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.comment_posted is True + + def test_comment_posted_false_by_default(self): + """Missing stats_comment_posted yields comment_posted=False.""" + state = _full_state() + del state["stats_comment_posted"] + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.comment_posted is False + + def test_workflow_run_id_extracted(self): + """workflow_run_id is extracted from the state.""" + state = _full_state(workflow_run_id="run-uuid-4567") + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.workflow_run_id == "run-uuid-4567" + + def test_missing_workflow_run_id_defaults_to_empty_string(self): + """Missing workflow_run_id yields empty string (pre-idempotency checkpoint).""" + state = _full_state() + del state["workflow_run_id"] + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.workflow_run_id == "" + + def test_malformed_stages_dict_treated_as_empty(self): + """Malformed stage_timestamps (not a dict) is treated as empty dict.""" + state = _full_state(stage_timestamps="not-a-dict") + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.stages == {} + + def 
test_malformed_pr_urls_treated_as_empty(self): + """Malformed stats_pr_urls (not a list) is treated as empty list.""" + state = _full_state(stats_pr_urls="not-a-list") + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.pr_urls == [] + + def test_partial_state_in_progress_workflow(self): + """Partial stats for an in-progress workflow are returned as-is.""" + stage = _make_stage(stage_name="prd", ended_at=None) + state = _full_state( + stage_timestamps={"prd": stage}, + workflow_outcome=None, + stats_outcome_reason=None, + stats_comment_posted=False, + ) + result = _extract_stats(_TICKET, state) + assert result is not None + assert result.stages["prd"]["ended_at"] is None + assert result.outcome is None + assert result.comment_posted is False + + +# --------------------------------------------------------------------------- +# get_workflow_stats +# --------------------------------------------------------------------------- + + +class TestGetWorkflowStats: + """Tests for the public get_workflow_stats() function.""" + + @pytest.mark.asyncio + async def test_returns_none_when_no_checkpoint(self): + """Returns None when get_checkpoint_state returns None.""" + with _patch_checkpoint(None): + result = await get_workflow_stats(_TICKET) + assert result is None + + @pytest.mark.asyncio + async def test_returns_workflow_stats_for_valid_checkpoint(self): + """Returns WorkflowStats for a checkpoint with stats data.""" + state = _full_state() + with _patch_checkpoint(state): + result = await get_workflow_stats(_TICKET) + assert result is not None + assert isinstance(result, WorkflowStats) + + @pytest.mark.asyncio + async def test_ticket_key_propagated(self): + """WorkflowStats.ticket_key matches the requested ticket key.""" + state = _full_state() + with _patch_checkpoint(state): + result = await get_workflow_stats("MYPROJ-42") + assert result is not None + assert result.ticket_key == "MYPROJ-42" + + @pytest.mark.asyncio + async def 
test_returns_none_for_legacy_checkpoint_without_stats(self): + """Returns None when checkpoint exists but has no stage_timestamps key.""" + legacy_state = { + "ticket_key": _TICKET, + "ticket_type": "Feature", + "current_node": "done", + } + with _patch_checkpoint(legacy_state): + result = await get_workflow_stats(_TICKET) + assert result is None + + @pytest.mark.asyncio + async def test_stages_populated_from_checkpoint(self): + """stages dict contains the stages stored in the checkpoint.""" + stage = _make_stage(stage_name="spec") + state = _full_state(stage_timestamps={"spec": stage}) + with _patch_checkpoint(state): + result = await get_workflow_stats(_TICKET) + assert result is not None + assert "spec" in result.stages + + @pytest.mark.asyncio + async def test_empty_stages_valid(self): + """Workflow with empty stage_timestamps is returned (not treated as missing).""" + state = _full_state(stage_timestamps={}) + with _patch_checkpoint(state): + result = await get_workflow_stats(_TICKET) + assert result is not None + assert result.stages == {} + + @pytest.mark.asyncio + async def test_partial_in_progress_workflow_returned(self): + """Partial stats for an in-progress workflow are returned with available data.""" + stage = _make_stage(ended_at=None) + state = _full_state( + stage_timestamps={"prd": stage}, + workflow_outcome=None, + stats_pr_urls=[], + stats_ci_cycles=0, + ) + with _patch_checkpoint(state): + result = await get_workflow_stats(_TICKET) + assert result is not None + assert result.outcome is None + assert result.stages["prd"]["ended_at"] is None + + @pytest.mark.asyncio + async def test_calls_get_checkpoint_state_with_ticket_key(self): + """get_checkpoint_state is called with the supplied ticket_key.""" + state = _full_state() + mock = AsyncMock(return_value=state) + with patch("forge.stats.retrieval.get_checkpoint_state", new=mock): + await get_workflow_stats("PROJ-55") + mock.assert_called_once_with("PROJ-55") + + @pytest.mark.asyncio + async def 
test_pr_urls_extracted_correctly(self): + """pr_urls from the checkpoint appear in the returned WorkflowStats.""" + urls = ["https://github.com/org/repo/pull/10"] + state = _full_state(stats_pr_urls=urls) + with _patch_checkpoint(state): + result = await get_workflow_stats(_TICKET) + assert result is not None + assert result.pr_urls == urls + + @pytest.mark.asyncio + async def test_ci_cycles_extracted_correctly(self): + """ci_cycles from the checkpoint appear in the returned WorkflowStats.""" + state = _full_state(stats_ci_cycles=7) + with _patch_checkpoint(state): + result = await get_workflow_stats(_TICKET) + assert result is not None + assert result.ci_cycles == 7 + + @pytest.mark.asyncio + async def test_propagates_exception_from_checkpointer(self): + """Exceptions from get_checkpoint_state are not swallowed.""" + with ( + patch( + "forge.stats.retrieval.get_checkpoint_state", + new=AsyncMock(side_effect=ConnectionError("Redis down")), + ), + pytest.raises(ConnectionError), + ): + await get_workflow_stats(_TICKET) + + +# --------------------------------------------------------------------------- +# get_workflow_stats_or_error +# --------------------------------------------------------------------------- + + +class TestGetWorkflowStatsOrError: + """Tests for the public get_workflow_stats_or_error() function.""" + + @pytest.mark.asyncio + async def test_returns_stats_and_none_error_on_success(self): + """Returns (WorkflowStats, None) when stats are found.""" + state = _full_state() + with _patch_checkpoint(state): + stats, error = await get_workflow_stats_or_error(_TICKET) + assert stats is not None + assert error is None + + @pytest.mark.asyncio + async def test_returns_none_stats_and_error_when_no_checkpoint(self): + """Returns (None, error_str) when no checkpoint exists.""" + with _patch_checkpoint(None): + stats, error = await get_workflow_stats_or_error(_TICKET) + assert stats is None + assert error is not None + + @pytest.mark.asyncio + async def 
test_error_message_contains_ticket_key_for_missing(self): + """Error message mentions the ticket key when no checkpoint is found.""" + with _patch_checkpoint(None): + _stats, error = await get_workflow_stats_or_error("AISOS-999") + assert error is not None + assert "AISOS-999" in error + + @pytest.mark.asyncio + async def test_returns_none_stats_when_legacy_checkpoint(self): + """Returns (None, error_str) for legacy checkpoints without stats.""" + legacy_state = { + "ticket_key": _TICKET, + "ticket_type": "Feature", + "current_node": "done", + } + with _patch_checkpoint(legacy_state): + stats, error = await get_workflow_stats_or_error(_TICKET) + assert stats is None + assert error is not None + + @pytest.mark.asyncio + async def test_error_message_is_display_ready_string(self): + """Error message is a non-empty string when stats are unavailable.""" + with _patch_checkpoint(None): + _stats, error = await get_workflow_stats_or_error(_TICKET) + assert isinstance(error, str) + assert len(error) > 0 + + @pytest.mark.asyncio + async def test_exception_from_checkpointer_returns_error_not_raises(self): + """ConnectionError from get_checkpoint_state yields (None, error_str).""" + with patch( + "forge.stats.retrieval.get_checkpoint_state", + new=AsyncMock(side_effect=ConnectionError("Redis unavailable")), + ): + stats, error = await get_workflow_stats_or_error(_TICKET) + assert stats is None + assert error is not None + + @pytest.mark.asyncio + async def test_error_message_contains_ticket_key_on_exception(self): + """Error message mentions the ticket key when an exception occurs.""" + with patch( + "forge.stats.retrieval.get_checkpoint_state", + new=AsyncMock(side_effect=RuntimeError("unexpected")), + ): + _stats, error = await get_workflow_stats_or_error("MYPROJ-77") + assert error is not None + assert "MYPROJ-77" in error + + @pytest.mark.asyncio + async def test_runtime_error_does_not_propagate(self): + """RuntimeError from checkpointer is caught; no exception raised.""" 
+ with patch( + "forge.stats.retrieval.get_checkpoint_state", + new=AsyncMock(side_effect=RuntimeError("oops")), + ): + # Should not raise + result = await get_workflow_stats_or_error(_TICKET) + assert result[0] is None + + @pytest.mark.asyncio + async def test_exactly_one_element_is_none(self): + """Exactly one of (stats, error) is always None on success.""" + state = _full_state() + with _patch_checkpoint(state): + stats, error = await get_workflow_stats_or_error(_TICKET) + # On success: stats is set, error is None + assert (stats is None) != (error is None) + + @pytest.mark.asyncio + async def test_exactly_one_element_is_none_on_failure(self): + """Exactly one of (stats, error) is always None on failure.""" + with _patch_checkpoint(None): + stats, error = await get_workflow_stats_or_error(_TICKET) + # On failure: stats is None, error is set + assert (stats is None) != (error is None) + + @pytest.mark.asyncio + async def test_stats_fields_correct_on_success(self): + """Returned WorkflowStats has correct fields populated.""" + state = _full_state( + workflow_outcome="Completed", + stats_ci_cycles=3, + stats_pr_urls=["https://github.com/org/repo/pull/5"], + ) + with _patch_checkpoint(state): + stats, _error = await get_workflow_stats_or_error(_TICKET) + assert stats is not None + assert stats.outcome == "Completed" + assert stats.ci_cycles == 3 + assert stats.pr_urls == ["https://github.com/org/repo/pull/5"] + + +# --------------------------------------------------------------------------- +# Import paths +# --------------------------------------------------------------------------- + + +class TestImportPaths: + """Verify the public API is importable from the package root.""" + + def test_workflow_stats_importable_from_package(self): + """WorkflowStats is importable from forge.stats.""" + from forge.stats import WorkflowStats as WS # noqa: F401 + + assert WS is WorkflowStats + + def test_get_workflow_stats_importable_from_package(self): + """get_workflow_stats is 
importable from forge.stats.""" + from forge.stats import get_workflow_stats as gws + + assert gws is get_workflow_stats + + def test_get_workflow_stats_or_error_importable_from_package(self): + """get_workflow_stats_or_error is importable from forge.stats.""" + from forge.stats import get_workflow_stats_or_error as gwsoe + + assert gwsoe is get_workflow_stats_or_error diff --git a/tests/unit/test_cli_stats.py b/tests/unit/test_cli_stats.py new file mode 100644 index 00000000..0e294213 --- /dev/null +++ b/tests/unit/test_cli_stats.py @@ -0,0 +1,637 @@ +"""Unit tests for the forge stats CLI command.""" + +import argparse +import json +from unittest.mock import AsyncMock, patch + +import pytest + +from forge.cli import cmd_stats + + +def _make_args(ticket: str = "AISOS-123", json_flag: bool = False) -> argparse.Namespace: + """Create a minimal argparse.Namespace for cmd_stats.""" + return argparse.Namespace(ticket=ticket, json=json_flag) + + +def _base_state(ticket_key: str = "AISOS-123", **overrides) -> dict: + """Return a minimal workflow state dict with stats data.""" + state: dict = { + "ticket_key": ticket_key, + "ticket_type": "Feature", + "current_node": "prd_approval_gate", + "is_paused": False, + "is_blocked": False, + "last_error": None, + "feedback_comment": None, + "context": {}, + "stage_timestamps": { + "prd": { + "stage_name": "prd", + "iteration_count": 1, + "machine_time_seconds": 30.0, + "human_time_seconds": 120.0, + "input_tokens": 500, + "output_tokens": 800, + } + }, + "stats_pr_urls": ["https://github.com/org/repo/pull/42"], + "stats_ci_cycles": 2, + "workflow_outcome": None, + "stats_outcome_reason": None, + } + state.update(overrides) + return state + + +# --------------------------------------------------------------------------- +# Argument parsing +# --------------------------------------------------------------------------- + + +class TestArgParsing: + """Tests for argument parsing.""" + + def test_stats_subparser_ticket_argument(self): + 
"""forge stats ticket argument is parsed correctly.""" + parser = argparse.ArgumentParser(prog="forge") + subparsers = parser.add_subparsers(dest="command") + stats_parser = subparsers.add_parser("stats") + stats_parser.add_argument("ticket") + stats_parser.add_argument("--json", action="store_true") + + args = parser.parse_args(["stats", "AISOS-123"]) + assert args.command == "stats" + assert args.ticket == "AISOS-123" + assert args.json is False + + def test_stats_json_flag_true(self): + """--json flag is parsed as True when provided.""" + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="command") + stats_parser = subparsers.add_parser("stats") + stats_parser.add_argument("ticket") + stats_parser.add_argument("--json", action="store_true") + + args = parser.parse_args(["stats", "AISOS-123", "--json"]) + assert args.json is True + + def test_stats_json_flag_default_false(self): + """--json flag defaults to False when not provided.""" + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="command") + stats_parser = subparsers.add_parser("stats") + stats_parser.add_argument("ticket") + stats_parser.add_argument("--json", action="store_true") + + args = parser.parse_args(["stats", "PROJ-99"]) + assert args.json is False + + def test_ticket_argument_is_required(self): + """ticket positional argument is required (no default).""" + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="command") + stats_parser = subparsers.add_parser("stats") + stats_parser.add_argument("ticket") + stats_parser.add_argument("--json", action="store_true") + + with pytest.raises(SystemExit): + parser.parse_args(["stats"]) + + +# --------------------------------------------------------------------------- +# Missing checkpoint +# --------------------------------------------------------------------------- + + +class TestMissingCheckpoint: + """Tests for missing or absent checkpoint state.""" + + @pytest.mark.asyncio 
+ async def test_returns_exit_code_1_when_no_checkpoint(self, capsys): + """Returns exit code 1 when get_checkpoint_state returns None.""" + args = _make_args("AISOS-123") + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=None), + ): + result = await cmd_stats(args) + + assert result == 1 + captured = capsys.readouterr() + assert "No workflow data found for AISOS-123" in captured.out + + @pytest.mark.asyncio + async def test_missing_message_includes_ticket_key(self, capsys): + """Error message mentions the specific ticket key.""" + args = _make_args("MYPROJ-999") + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=None), + ): + result = await cmd_stats(args) + + assert result == 1 + captured = capsys.readouterr() + assert "MYPROJ-999" in captured.out + + @pytest.mark.asyncio + async def test_returns_exit_code_1_when_stage_timestamps_key_absent(self, capsys): + """Returns exit code 1 when stage_timestamps key is not in state.""" + state_without_stats = { + "ticket_key": "AISOS-123", + "ticket_type": "Feature", + "current_node": "prd_approval_gate", + } + args = _make_args("AISOS-123") + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state_without_stats), + ): + result = await cmd_stats(args) + + assert result == 1 + captured = capsys.readouterr() + assert "No workflow data found for AISOS-123" in captured.out + + @pytest.mark.asyncio + async def test_connection_error_returns_exit_code_1(self, capsys): + """Returns exit code 1 when get_checkpoint_state raises an exception.""" + args = _make_args("AISOS-123") + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(side_effect=ConnectionError("Redis unavailable")), + ): + result = await cmd_stats(args) + + assert result == 1 + captured = capsys.readouterr() + assert "Error" in captured.err + + @pytest.mark.asyncio + async def 
test_generic_exception_returns_exit_code_1(self): + """Returns exit code 1 for any unexpected exception from checkpointer.""" + args = _make_args("AISOS-123") + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(side_effect=RuntimeError("unexpected")), + ): + result = await cmd_stats(args) + + assert result == 1 + + @pytest.mark.asyncio + async def test_connection_error_prints_ticket_in_stderr(self, capsys): + """Error message includes ticket key in stderr.""" + args = _make_args("AISOS-777") + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(side_effect=ConnectionError("Redis unavailable")), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + assert "AISOS-777" in captured.err + + +# --------------------------------------------------------------------------- +# Plain text output +# --------------------------------------------------------------------------- + + +class TestPlainTextOutput: + """Tests for human-readable table output (no --json flag).""" + + @pytest.mark.asyncio + async def test_returns_exit_code_0_on_success(self): + """Returns exit code 0 when stats are found and displayed.""" + args = _make_args("AISOS-123") + state = _base_state() + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + result = await cmd_stats(args) + + assert result == 0 + + @pytest.mark.asyncio + async def test_output_contains_stats_heading(self, capsys): + """Output contains the 'Workflow Statistics' heading.""" + args = _make_args("AISOS-123") + state = _base_state() + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + assert "Workflow Statistics" in captured.out + + @pytest.mark.asyncio + async def test_output_contains_outcome(self, capsys): + """Output contains the Outcome line.""" + args = 
_make_args("AISOS-123") + state = _base_state() + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + assert "Outcome" in captured.out + + @pytest.mark.asyncio + async def test_output_contains_stage_label(self, capsys): + """Output contains PRD stage label.""" + args = _make_args("AISOS-123") + state = _base_state() + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + assert "PRD" in captured.out + + @pytest.mark.asyncio + async def test_output_is_not_json(self, capsys): + """Plain text output is not valid JSON.""" + args = _make_args("AISOS-123") + state = _base_state() + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + try: + json.loads(captured.out) + is_json = True + except (json.JSONDecodeError, ValueError): + is_json = False + assert not is_json + + @pytest.mark.asyncio + async def test_empty_stages_still_returns_exit_code_0(self): + """Empty stage_timestamps dict (present key, empty value) returns exit 0.""" + state = _base_state(stage_timestamps={}) + args = _make_args("AISOS-123") + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + result = await cmd_stats(args) + + assert result == 0 + + +# --------------------------------------------------------------------------- +# JSON output +# --------------------------------------------------------------------------- + + +class TestJsonOutput: + """Tests for --json flag output.""" + + @pytest.mark.asyncio + async def test_json_flag_produces_valid_json(self, capsys): + """--json flag produces valid JSON output.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state() + with 
patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + result = await cmd_stats(args) + + assert result == 0 + captured = capsys.readouterr() + data = json.loads(captured.out) + assert isinstance(data, dict) + + @pytest.mark.asyncio + async def test_json_contains_ticket_key(self, capsys): + """JSON output includes the ticket key.""" + args = _make_args("AISOS-456", json_flag=True) + state = _base_state(ticket_key="AISOS-456") + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + data = json.loads(captured.out) + assert data["ticket"] == "AISOS-456" + + @pytest.mark.asyncio + async def test_json_contains_outcome_field(self, capsys): + """JSON output includes the outcome field.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state() + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + data = json.loads(captured.out) + assert "outcome" in data + + @pytest.mark.asyncio + async def test_json_contains_stages(self, capsys): + """JSON output includes the stages dict.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state() + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + data = json.loads(captured.out) + assert "stages" in data + assert "prd" in data["stages"] + + @pytest.mark.asyncio + async def test_json_contains_pr_urls(self, capsys): + """JSON output includes PR URLs list.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state() + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + captured = 
capsys.readouterr() + data = json.loads(captured.out) + assert "pr_urls" in data + assert data["pr_urls"] == ["https://github.com/org/repo/pull/42"] + + @pytest.mark.asyncio + async def test_json_contains_ci_cycles(self, capsys): + """JSON output includes ci_cycles.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state(stats_ci_cycles=5) + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + data = json.loads(captured.out) + assert data["ci_cycles"] == 5 + + @pytest.mark.asyncio + async def test_json_returns_exit_code_0(self): + """--json flag returns exit code 0 on success.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state() + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + result = await cmd_stats(args) + + assert result == 0 + + @pytest.mark.asyncio + async def test_json_contains_outcome_detail(self, capsys): + """JSON output includes outcome_detail.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state(last_error="build failed", workflow_outcome=None) + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + captured = capsys.readouterr() + data = json.loads(captured.out) + assert "outcome_detail" in data + assert data["outcome_detail"] == "build failed" + + @pytest.mark.asyncio + async def test_json_empty_stages(self, capsys): + """JSON output with empty stages contains empty stages dict.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state(stage_timestamps={}) + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["stages"] == {} + + +# 
--------------------------------------------------------------------------- +# Outcome derivation +# --------------------------------------------------------------------------- + + +class TestOutcomeDerivation: + """Tests for outcome derivation logic.""" + + @pytest.mark.asyncio + async def test_pre_set_workflow_outcome_used(self, capsys): + """workflow_outcome field is used when set.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state(workflow_outcome="Completed") + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["outcome"] == "Completed" + + @pytest.mark.asyncio + async def test_blocked_outcome_from_is_blocked(self, capsys): + """Outcome is 'Blocked' when is_blocked is True.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state( + is_blocked=True, + workflow_outcome=None, + feedback_comment="waiting on PM", + ) + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["outcome"] == "Blocked" + assert data["outcome_detail"] == "waiting on PM" + + @pytest.mark.asyncio + async def test_failed_outcome_from_last_error(self, capsys): + """Outcome is 'Failed' when last_error is set.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state( + is_blocked=False, + workflow_outcome=None, + last_error="connection timeout", + ) + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["outcome"] == "Failed" + assert data["outcome_detail"] == "connection timeout" + + @pytest.mark.asyncio + async def test_in_progress_outcome_when_no_signals(self, capsys): + """Outcome defaults to 'In Progress' 
when no outcome signals found.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state(is_blocked=False, workflow_outcome=None, last_error=None) + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["outcome"] == "In Progress" + assert data["outcome_detail"] is None + + @pytest.mark.asyncio + async def test_stats_outcome_reason_used_as_detail(self, capsys): + """stats_outcome_reason is used as outcome_detail when present.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state( + workflow_outcome="Blocked", + stats_outcome_reason="manual hold by PM", + ) + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["outcome_detail"] == "manual hold by PM" + + @pytest.mark.asyncio + async def test_workflow_outcome_precedence_over_is_blocked(self, capsys): + """Pre-set workflow_outcome takes precedence over is_blocked flag.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state(workflow_outcome="Completed", is_blocked=True) + with patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ): + await cmd_stats(args) + + data = json.loads(capsys.readouterr().out) + assert data["outcome"] == "Completed" + + +# --------------------------------------------------------------------------- +# Formatter integration +# --------------------------------------------------------------------------- + + +class TestFormatterIntegration: + """Tests that format_stats_summary is called correctly.""" + + @pytest.mark.asyncio + async def test_format_stats_summary_called_for_plain_text(self, capsys): + """format_stats_summary is invoked for plain text output.""" + args = _make_args("AISOS-123") + state = 
_base_state() + + with ( + patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ), + patch( + "forge.workflow.stats.formatter.format_stats_summary", + return_value="mocked summary", + ) as mock_fmt, + ): + await cmd_stats(args) + + mock_fmt.assert_called_once() + assert "mocked summary" in capsys.readouterr().out + + @pytest.mark.asyncio + async def test_format_stats_summary_receives_correct_outcome(self): + """format_stats_summary is called with derived outcome.""" + args = _make_args("AISOS-123") + state = _base_state(workflow_outcome="Completed") + + with ( + patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ), + patch( + "forge.workflow.stats.formatter.format_stats_summary", + return_value="ok", + ) as mock_fmt, + ): + await cmd_stats(args) + + call_args = mock_fmt.call_args + assert call_args[0][1] == "Completed" + + @pytest.mark.asyncio + async def test_format_stats_summary_receives_pricing(self): + """format_stats_summary is called with the pricing dictionary.""" + args = _make_args("AISOS-123") + state = _base_state() + from forge.config import get_settings + + settings = get_settings() + + with ( + patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ), + patch( + "forge.workflow.stats.formatter.format_stats_summary", + return_value="ok", + ) as mock_fmt, + ): + await cmd_stats(args) + + mock_fmt.assert_called_once() + assert mock_fmt.call_args.kwargs.get("pricing") == settings.llm_pricing + + @pytest.mark.asyncio + async def test_format_stats_summary_not_called_for_json(self): + """format_stats_summary is NOT called when --json flag is set.""" + args = _make_args("AISOS-123", json_flag=True) + state = _base_state() + + with ( + patch( + "forge.orchestrator.checkpointer.get_checkpoint_state", + new=AsyncMock(return_value=state), + ), + patch( + "forge.workflow.stats.formatter.format_stats_summary", + 
return_value="should not appear", + ) as mock_fmt, + ): + await cmd_stats(args) + + mock_fmt.assert_not_called() diff --git a/tests/unit/test_cli_weekly_report.py b/tests/unit/test_cli_weekly_report.py new file mode 100644 index 00000000..810bbd1e --- /dev/null +++ b/tests/unit/test_cli_weekly_report.py @@ -0,0 +1,501 @@ +"""Integration tests for the forge weekly-report CLI command.""" + +from __future__ import annotations + +import argparse +import json +import os +import tempfile +from unittest.mock import AsyncMock, patch + +import pytest + +from forge.cli import cmd_weekly_report +from forge.workflow.stats.weekly_report import ( + BottleneckAnalysis, + TicketSummary, + WeeklyReportData, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_args( + project: str = "PROJ", + days: int = 7, + output: str | None = None, + fmt: str = "text", +) -> argparse.Namespace: + """Create a minimal argparse.Namespace for cmd_weekly_report.""" + return argparse.Namespace(project=project, days=days, output=output, format=fmt) + + +def _make_report(project: str = "PROJ", days: int = 7, **overrides) -> WeeklyReportData: + """Return a WeeklyReportData with one completed ticket for testing.""" + completed = [ + TicketSummary( + ticket_key=f"{project}-1", + ticket_type="Feature", + status="completed", + duration_seconds=3600.0, + input_tokens=1000, + output_tokens=500, + ) + ] + data = WeeklyReportData( + project=project, + period_days=days, + report_start="2024-01-01T00:00:00+00:00", + report_end="2024-01-08T00:00:00+00:00", + completed_tickets=overrides.pop("completed_tickets", completed), + in_progress_tickets=overrides.pop("in_progress_tickets", []), + blocked_tickets=overrides.pop("blocked_tickets", []), + total_input_tokens=overrides.pop("total_input_tokens", 1000), + total_output_tokens=overrides.pop("total_output_tokens", 500), + 
avg_cycle_time=overrides.pop("avg_cycle_time", 3600.0), + bottlenecks=overrides.pop("bottlenecks", BottleneckAnalysis()), + ) + return data + + +def _empty_report(project: str = "PROJ") -> WeeklyReportData: + """Return a WeeklyReportData with no tickets.""" + return WeeklyReportData( + project=project, + period_days=7, + report_start="2024-01-01T00:00:00+00:00", + report_end="2024-01-08T00:00:00+00:00", + ) + + +# --------------------------------------------------------------------------- +# Argument parsing +# --------------------------------------------------------------------------- + + +class TestArgParsing: + """Tests for argument parsing of the weekly-report subparser.""" + + def _build_parser(self) -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(prog="forge") + subparsers = parser.add_subparsers(dest="command") + wr_parser = subparsers.add_parser("weekly-report") + wr_parser.add_argument("--project", required=True) + wr_parser.add_argument("--days", type=int, default=7) + wr_parser.add_argument("--output", default=None) + wr_parser.add_argument("--format", choices=["text", "markdown", "json"], default="text") + return parser + + def test_project_is_required(self): + """--project is required; missing it raises SystemExit.""" + parser = self._build_parser() + with pytest.raises(SystemExit): + parser.parse_args(["weekly-report"]) + + def test_project_is_parsed(self): + """--project value is captured correctly.""" + parser = self._build_parser() + args = parser.parse_args(["weekly-report", "--project", "MYPROJ"]) + assert args.project == "MYPROJ" + + def test_days_defaults_to_7(self): + """--days defaults to 7 when not provided.""" + parser = self._build_parser() + args = parser.parse_args(["weekly-report", "--project", "PROJ"]) + assert args.days == 7 + + def test_days_custom_value(self): + """--days accepts a custom integer.""" + parser = self._build_parser() + args = parser.parse_args(["weekly-report", "--project", "PROJ", "--days", "14"]) + 
assert args.days == 14 + + def test_output_defaults_to_none(self): + """--output defaults to None when not provided.""" + parser = self._build_parser() + args = parser.parse_args(["weekly-report", "--project", "PROJ"]) + assert args.output is None + + def test_output_path_captured(self): + """--output path is captured correctly.""" + parser = self._build_parser() + args = parser.parse_args(["weekly-report", "--project", "PROJ", "--output", "report.md"]) + assert args.output == "report.md" + + def test_format_defaults_to_text(self): + """--format defaults to 'text' when not provided.""" + parser = self._build_parser() + args = parser.parse_args(["weekly-report", "--project", "PROJ"]) + assert args.format == "text" + + def test_format_markdown(self): + """--format markdown is accepted.""" + parser = self._build_parser() + args = parser.parse_args(["weekly-report", "--project", "PROJ", "--format", "markdown"]) + assert args.format == "markdown" + + def test_format_json(self): + """--format json is accepted.""" + parser = self._build_parser() + args = parser.parse_args(["weekly-report", "--project", "PROJ", "--format", "json"]) + assert args.format == "json" + + def test_invalid_format_raises(self): + """An invalid --format value raises SystemExit.""" + parser = self._build_parser() + with pytest.raises(SystemExit): + parser.parse_args(["weekly-report", "--project", "PROJ", "--format", "xml"]) + + +# --------------------------------------------------------------------------- +# Text output (stdout) +# --------------------------------------------------------------------------- + + +class TestTextOutput: + """Tests for default text format output to stdout.""" + + @pytest.mark.asyncio + async def test_returns_exit_code_0_with_data(self, capsys): + """Returns 0 when data is available.""" + args = _make_args() + report = _make_report() + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + result = await 
cmd_weekly_report(args) + + assert result == 0 + + @pytest.mark.asyncio + async def test_stdout_contains_project_key(self, capsys): + """stdout contains the project key.""" + args = _make_args(project="MYPROJ") + report = _make_report(project="MYPROJ") + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + await cmd_weekly_report(args) + + captured = capsys.readouterr() + assert "MYPROJ" in captured.out + + @pytest.mark.asyncio + async def test_stdout_contains_ticket_key(self, capsys): + """stdout contains ticket keys from the report.""" + args = _make_args(project="PROJ") + report = _make_report(project="PROJ") + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + await cmd_weekly_report(args) + + captured = capsys.readouterr() + assert "PROJ-1" in captured.out + + @pytest.mark.asyncio + async def test_days_passed_to_collect(self): + """--days value is forwarded to collect_weekly_data.""" + args = _make_args(days=14) + report = _make_report(days=14) + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ) as mock_collect: + await cmd_weekly_report(args) + + mock_collect.assert_awaited_once_with("PROJ", days=14) + + +# --------------------------------------------------------------------------- +# Markdown output +# --------------------------------------------------------------------------- + + +class TestMarkdownOutput: + """Tests for markdown format output.""" + + @pytest.mark.asyncio + async def test_markdown_to_stdout(self, capsys): + """--format markdown outputs Markdown content to stdout.""" + args = _make_args(fmt="markdown") + report = _make_report() + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + result = await cmd_weekly_report(args) + + assert result == 0 + captured = capsys.readouterr() + # 
Markdown report starts with a heading + assert "# Weekly Report" in captured.out + + @pytest.mark.asyncio + async def test_markdown_contains_project(self, capsys): + """Markdown output contains the project name.""" + args = _make_args(project="ACME", fmt="markdown") + report = _make_report(project="ACME") + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + await cmd_weekly_report(args) + + captured = capsys.readouterr() + assert "ACME" in captured.out + + +# --------------------------------------------------------------------------- +# JSON output +# --------------------------------------------------------------------------- + + +class TestJsonOutput: + """Tests for JSON format output.""" + + @pytest.mark.asyncio + async def test_json_to_stdout(self, capsys): + """--format json outputs valid JSON to stdout.""" + args = _make_args(fmt="json") + report = _make_report() + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + result = await cmd_weekly_report(args) + + assert result == 0 + captured = capsys.readouterr() + # Should be valid JSON + parsed = json.loads(captured.out) + assert isinstance(parsed, dict) + + @pytest.mark.asyncio + async def test_json_contains_project_field(self, capsys): + """JSON output has a 'project' field matching the requested project.""" + args = _make_args(project="TESTPROJ", fmt="json") + report = _make_report(project="TESTPROJ") + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + await cmd_weekly_report(args) + + captured = capsys.readouterr() + parsed = json.loads(captured.out) + assert parsed["project"] == "TESTPROJ" + + +# --------------------------------------------------------------------------- +# File output +# --------------------------------------------------------------------------- + + +class TestFileOutput: + """Tests for 
writing report to a file via --output.""" + + @pytest.mark.asyncio + async def test_writes_to_file(self): + """Report is written to the specified file path.""" + report = _make_report() + + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tmp: + tmp_path = tmp.name + + try: + args = _make_args(output=tmp_path) + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + result = await cmd_weekly_report(args) + + assert result == 0 + assert os.path.exists(tmp_path) + content = open(tmp_path, encoding="utf-8").read() + assert len(content) > 0 + finally: + os.unlink(tmp_path) + + @pytest.mark.asyncio + async def test_file_output_contains_project(self): + """Written file contains the project key.""" + report = _make_report(project="FILEPROJ") + + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tmp: + tmp_path = tmp.name + + try: + args = _make_args(project="FILEPROJ", output=tmp_path) + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + await cmd_weekly_report(args) + + content = open(tmp_path, encoding="utf-8").read() + assert "FILEPROJ" in content + finally: + os.unlink(tmp_path) + + @pytest.mark.asyncio + async def test_stdout_not_written_when_output_file(self, capsys): + """stdout only contains confirmation message when --output is set.""" + report = _make_report() + + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tmp: + tmp_path = tmp.name + + try: + args = _make_args(output=tmp_path) + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + await cmd_weekly_report(args) + + captured = capsys.readouterr() + # The report body should NOT be on stdout; only the confirmation + assert "Report written to" in captured.out + assert "WEEKLY REPORT" not in captured.out + finally: + os.unlink(tmp_path) + + 
@pytest.mark.asyncio + async def test_markdown_written_to_file(self): + """Markdown report is correctly written when format=markdown.""" + report = _make_report() + + with tempfile.NamedTemporaryFile(mode="w", suffix=".md", delete=False) as tmp: + tmp_path = tmp.name + + try: + args = _make_args(output=tmp_path, fmt="markdown") + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + result = await cmd_weekly_report(args) + + assert result == 0 + content = open(tmp_path, encoding="utf-8").read() + assert "# Weekly Report" in content + finally: + os.unlink(tmp_path) + + @pytest.mark.asyncio + async def test_unwritable_path_returns_exit_code_1(self, capsys): + """Returns exit code 1 when the output file cannot be created.""" + args = _make_args(output="/nonexistent_dir/report.txt") + report = _make_report() + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + result = await cmd_weekly_report(args) + + assert result == 1 + captured = capsys.readouterr() + assert "Error" in captured.err + + +# --------------------------------------------------------------------------- +# No data / graceful failure +# --------------------------------------------------------------------------- + + +class TestNoData: + """Tests for graceful failure when project has no data.""" + + @pytest.mark.asyncio + async def test_empty_report_returns_exit_code_1(self, capsys): + """Returns exit code 1 when no tickets are found for the project.""" + args = _make_args(project="EMPTY") + report = _empty_report(project="EMPTY") + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + result = await cmd_weekly_report(args) + + assert result == 1 + + @pytest.mark.asyncio + async def test_empty_report_error_message_contains_project(self, capsys): + """Error message mentions the project key.""" + args = 
_make_args(project="EMPTY") + report = _empty_report(project="EMPTY") + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(return_value=report), + ): + await cmd_weekly_report(args) + + captured = capsys.readouterr() + assert "EMPTY" in captured.err + + @pytest.mark.asyncio + async def test_collect_exception_returns_exit_code_1(self, capsys): + """Returns exit code 1 when collect_weekly_data raises an exception.""" + args = _make_args(project="PROJ") + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(side_effect=ConnectionError("Redis unavailable")), + ): + result = await cmd_weekly_report(args) + + assert result == 1 + + @pytest.mark.asyncio + async def test_collect_exception_error_printed_to_stderr(self, capsys): + """Exception from collect_weekly_data prints an error to stderr.""" + args = _make_args(project="PROJ") + + with patch( + "forge.workflow.stats.weekly_report.collect_weekly_data", + new=AsyncMock(side_effect=RuntimeError("something went wrong")), + ): + await cmd_weekly_report(args) + + captured = capsys.readouterr() + assert "Error" in captured.err + + +# --------------------------------------------------------------------------- +# Handler registration +# --------------------------------------------------------------------------- + + +class TestHandlerRegistration: + """Verify that weekly-report is wired into the CLI handlers dict.""" + + def test_weekly_report_in_handlers(self): + """cmd_weekly_report is importable and matches the CLI handler signature.""" + # Should be an async function + import asyncio + + from forge.cli import cmd_weekly_report as handler + + assert asyncio.iscoroutinefunction(handler) diff --git a/tests/unit/test_config_cost_alert.py b/tests/unit/test_config_cost_alert.py new file mode 100644 index 00000000..8ca687b2 --- /dev/null +++ b/tests/unit/test_config_cost_alert.py @@ -0,0 +1,119 @@ +"""Tests for stats cost alert threshold configuration 
settings.""" + +import json + +from forge.config import Settings + +REQUIRED_SETTINGS = { + "jira_base_url": "https://test.atlassian.net", + "jira_api_token": "test", + "jira_user_email": "test@example.com", + "github_token": "test", + "anthropic_api_key": "test", +} + + +class TestStatsCostAlertConfig: + def test_default_cost_alert_enabled_is_true(self): + settings = Settings(**REQUIRED_SETTINGS) + assert settings.stats_alert_enabled is True + + def test_default_cost_alert_threshold_tokens(self): + settings = Settings(**REQUIRED_SETTINGS) + assert settings.stats_alert_threshold_tokens == 1_000_000 + + def test_cost_alert_enabled_can_be_disabled(self): + settings = Settings(**REQUIRED_SETTINGS, stats_alert_enabled=False) + assert settings.stats_alert_enabled is False + + def test_cost_alert_threshold_can_be_customized(self): + settings = Settings(**REQUIRED_SETTINGS, stats_alert_threshold_tokens=500_000) + assert settings.stats_alert_threshold_tokens == 500_000 + + def test_cost_alert_threshold_accepts_large_values(self): + settings = Settings(**REQUIRED_SETTINGS, stats_alert_threshold_tokens=10_000_000) + assert settings.stats_alert_threshold_tokens == 10_000_000 + + def test_cost_alert_threshold_is_int(self): + settings = Settings(**REQUIRED_SETTINGS) + assert isinstance(settings.stats_alert_threshold_tokens, int) + + def test_cost_alert_enabled_is_bool(self): + settings = Settings(**REQUIRED_SETTINGS) + assert isinstance(settings.stats_alert_enabled, bool) + + +class TestStatsCostAlertDollarThreshold: + """Tests for the new stats_alert_threshold_cost setting.""" + + def test_default_dollar_threshold_is_none(self): + settings = Settings(**REQUIRED_SETTINGS) + assert settings.stats_alert_threshold_cost is None + + def test_dollar_threshold_can_be_set(self): + settings = Settings(**REQUIRED_SETTINGS, stats_alert_threshold_cost=10.0) + assert settings.stats_alert_threshold_cost == 10.0 + + def test_dollar_threshold_accepts_small_values(self): + settings = 
Settings(**REQUIRED_SETTINGS, stats_alert_threshold_cost=0.01) + assert settings.stats_alert_threshold_cost == 0.01 + + def test_dollar_threshold_is_float_when_set(self): + settings = Settings(**REQUIRED_SETTINGS, stats_alert_threshold_cost=5.0) + assert isinstance(settings.stats_alert_threshold_cost, float) + + +class TestLLMPricingConfig: + """Tests for the llm_pricing configuration field.""" + + def test_default_pricing_contains_claude_sonnet_4(self): + settings = Settings(**REQUIRED_SETTINGS) + assert "claude-sonnet-4" in settings.llm_pricing + + def test_default_pricing_contains_claude_opus_4(self): + settings = Settings(**REQUIRED_SETTINGS) + assert "claude-opus-4" in settings.llm_pricing + + def test_default_pricing_contains_gemini_models(self): + settings = Settings(**REQUIRED_SETTINGS) + assert "gemini-2.5-flash" in settings.llm_pricing + + def test_default_pricing_has_input_and_output_rates(self): + settings = Settings(**REQUIRED_SETTINGS) + for key, rates in settings.llm_pricing.items(): + assert "input" in rates, f"Missing 'input' rate for {key}" + assert "output" in rates, f"Missing 'output' rate for {key}" + + def test_pricing_rates_are_floats(self): + settings = Settings(**REQUIRED_SETTINGS) + for key, rates in settings.llm_pricing.items(): + assert isinstance(rates["input"], float), f"Input rate for {key} is not float" + assert isinstance(rates["output"], float), f"Output rate for {key} is not float" + + def test_custom_pricing_via_direct_field(self): + custom = {"my-model": {"input": 1.0, "output": 2.0}} + settings = Settings(**REQUIRED_SETTINGS, llm_pricing=custom) + assert settings.llm_pricing == custom + + def test_pricing_is_dict(self): + settings = Settings(**REQUIRED_SETTINGS) + assert isinstance(settings.llm_pricing, dict) + + def test_custom_pricing_from_json_string(self, monkeypatch): + """Pricing can be loaded from a JSON-encoded env var.""" + custom = {"test-model": {"input": 5.0, "output": 10.0}} + monkeypatch.setenv("LLM_PRICING", 
json.dumps(custom)) + settings = Settings(**REQUIRED_SETTINGS) + assert settings.llm_pricing == custom + + def test_default_claude_sonnet_4_rates(self): + settings = Settings(**REQUIRED_SETTINGS) + rates = settings.llm_pricing["claude-sonnet-4"] + assert rates["input"] == 3.00 + assert rates["output"] == 15.00 + + def test_default_claude_opus_4_rates(self): + settings = Settings(**REQUIRED_SETTINGS) + rates = settings.llm_pricing["claude-opus-4"] + assert rates["input"] == 15.00 + assert rates["output"] == 75.00 diff --git a/tests/unit/workflow/bug/test_graph_stats.py b/tests/unit/workflow/bug/test_graph_stats.py new file mode 100644 index 00000000..02344767 --- /dev/null +++ b/tests/unit/workflow/bug/test_graph_stats.py @@ -0,0 +1,143 @@ +"""Tests for stats posting integration in the Bug workflow graph. + +Verifies that post_terminal_stats is wired into the bug graph at all terminal +paths: successful post-merge completion and blocked escalation. +""" + +from forge.models.workflow import TicketType +from forge.workflow.bug.graph import build_bug_graph + + +def _bug_state(**overrides): + """Build a minimal bug state dict for routing tests.""" + base = { + "ticket_key": "BUG-1", + "ticket_type": TicketType.BUG, + "current_node": "start", + "is_paused": False, + "retry_count": 0, + "last_error": None, + "pr_merged": False, + } + return {**base, **overrides} + + +class TestBugGraphStatsNode: + """post_terminal_stats is present in the compiled bug graph.""" + + def test_post_terminal_stats_node_present(self): + """post_terminal_stats node is registered in the compiled graph.""" + graph = build_bug_graph() + compiled = graph.compile() + assert "post_terminal_stats" in compiled.nodes + + def test_graph_compiles_with_stats_node(self): + """Bug graph compiles without error after stats node integration.""" + graph = build_bug_graph() + compiled = graph.compile() + assert compiled is not None + + +class TestBugGraphTerminalEdges: + """All terminal paths in the bug graph 
route through post_terminal_stats.""" + + def test_post_merge_summary_routes_to_stats(self): + """post_merge_summary → post_terminal_stats edge exists (success path).""" + graph = build_bug_graph() + assert ("post_merge_summary", "post_terminal_stats") in graph.edges, ( + "post_merge_summary must route to post_terminal_stats" + ) + + def test_escalate_blocked_routes_to_stats(self): + """escalate_blocked → post_terminal_stats edge exists (blocked path).""" + graph = build_bug_graph() + assert ("escalate_blocked", "post_terminal_stats") in graph.edges, ( + "escalate_blocked must route to post_terminal_stats" + ) + + def test_post_terminal_stats_routes_to_end(self): + """post_terminal_stats → __end__ edge exists.""" + graph = build_bug_graph() + assert ("post_terminal_stats", "__end__") in graph.edges, ( + "post_terminal_stats must route to END" + ) + + def test_post_merge_summary_does_not_route_directly_to_end(self): + """post_merge_summary does NOT have a direct edge to END (stats must be between).""" + graph = build_bug_graph() + assert ("post_merge_summary", "__end__") not in graph.edges, ( + "post_merge_summary must NOT edge directly to END; post_terminal_stats must be between" + ) + + def test_escalate_blocked_does_not_route_directly_to_end(self): + """escalate_blocked does NOT have a direct edge to END (stats must be between).""" + graph = build_bug_graph() + assert ("escalate_blocked", "__end__") not in graph.edges, ( + "escalate_blocked must NOT edge directly to END; post_terminal_stats must be between" + ) + + +class TestBugGraphStatsOrdering: + """Stats posting occurs AFTER other terminal actions.""" + + def test_success_path_order(self): + """Success path: post_merge_summary → post_terminal_stats → END.""" + graph = build_bug_graph() + edges = graph.edges + assert ("post_merge_summary", "post_terminal_stats") in edges, ( + "post_merge_summary must edge to post_terminal_stats" + ) + assert ("post_terminal_stats", "__end__") in edges, "post_terminal_stats 
must edge to END" + + def test_blocked_path_order(self): + """Blocked path: escalate_blocked → post_terminal_stats → END.""" + graph = build_bug_graph() + edges = graph.edges + assert ("escalate_blocked", "post_terminal_stats") in edges, ( + "escalate_blocked must edge to post_terminal_stats" + ) + assert ("post_terminal_stats", "__end__") in edges, "post_terminal_stats must edge to END" + + def test_stats_is_last_before_end(self): + """post_terminal_stats is the single gateway to END for terminal paths.""" + graph = build_bug_graph() + # Only post_terminal_stats should have a direct edge to __end__ + # (other terminal nodes go through stats first) + direct_to_end = {src for (src, dst) in graph.edges if dst == "__end__"} + # post_terminal_stats must be one such node + assert "post_terminal_stats" in direct_to_end, ( + "post_terminal_stats must have edge to __end__" + ) + # Neither escalate_blocked nor post_merge_summary should bypass stats + assert "escalate_blocked" not in direct_to_end, ( + "escalate_blocked must not directly edge to __end__" + ) + assert "post_merge_summary" not in direct_to_end, ( + "post_merge_summary must not directly edge to __end__" + ) + + +class TestBugGraphAllNodesPresent: + """Bug graph still contains all expected nodes after stats integration.""" + + def test_all_core_nodes_still_present(self): + """Core pipeline nodes are still registered after stats node addition.""" + graph = build_bug_graph() + compiled = graph.compile() + expected_nodes = { + "triage_check", + "triage_gate", + "analyze_bug", + "reflect_rca", + "rca_option_gate", + "regenerate_rca", + "plan_bug_fix", + "plan_approval_gate", + "regenerate_plan", + "decompose_plan", + "post_merge_summary", + "post_terminal_stats", + "escalate_blocked", + } + for node in expected_nodes: + assert node in compiled.nodes, f"Node '{node}' missing from compiled graph" diff --git a/tests/unit/workflow/feature/test_graph_stats.py b/tests/unit/workflow/feature/test_graph_stats.py new file 
mode 100644 index 00000000..a9de154f --- /dev/null +++ b/tests/unit/workflow/feature/test_graph_stats.py @@ -0,0 +1,194 @@ +"""Tests for stats posting integration in the Feature workflow graph. + +Verifies that: +- post_terminal_stats node is present in the compiled graph +- All terminal paths (success, blocked, failure) route through post_terminal_stats +- post_terminal_stats is the last node before END +- Unrecoverable failure routing functions return "post_terminal_stats" +""" + +from forge.models.workflow import TicketType +from forge.workflow.feature.graph import ( + _route_after_epic_decomposition, + _route_after_generation, + _route_after_spec_generation, + _route_after_task_generation, + build_feature_graph, +) + + +def _feature_state(**overrides): + """Build a minimal feature state dict for routing tests.""" + base = { + "ticket_key": "FEAT-1", + "ticket_type": TicketType.FEATURE, + "current_node": "start", + "is_paused": False, + "retry_count": 0, + "last_error": None, + "prd_content": "", + "spec_content": "", + "epic_keys": [], + "task_keys": [], + "pr_urls": [], + } + return {**base, **overrides} + + +class TestFeatureGraphStatsNode: + """post_terminal_stats is present in the compiled feature graph.""" + + def test_post_terminal_stats_node_present(self): + """post_terminal_stats node is registered in the compiled graph.""" + graph = build_feature_graph() + compiled = graph.compile() + assert "post_terminal_stats" in compiled.nodes + + def test_post_terminal_stats_node_is_reachable(self): + """post_terminal_stats appears in the compiled graph node set.""" + graph = build_feature_graph() + compiled = graph.compile() + # Node must be reachable — confirm it's not just a stub + node_keys = set(compiled.nodes.keys()) + assert "post_terminal_stats" in node_keys + + +class TestFeatureTerminalPathsRouteToStats: + """All terminal paths in the feature graph route through post_terminal_stats.""" + + def test_prd_generation_failure_routes_to_stats(self): + 
"""generate_prd failure (no prd_content, has error) routes to post_terminal_stats.""" + state = _feature_state(last_error="LLM timeout", prd_content="") + result = _route_after_generation(state) + assert result == "post_terminal_stats" + + def test_prd_generation_success_does_not_route_to_stats(self): + """generate_prd success routes to prd_approval_gate, not post_terminal_stats.""" + state = _feature_state(last_error=None, prd_content="Some PRD content") + result = _route_after_generation(state) + assert result == "prd_approval_gate" + assert result != "post_terminal_stats" + + def test_prd_generation_error_with_content_does_not_route_to_stats(self): + """generate_prd with error but existing content goes to gate (not terminal failure).""" + state = _feature_state(last_error="minor error", prd_content="Existing PRD") + result = _route_after_generation(state) + assert result == "prd_approval_gate" + + def test_spec_generation_failure_routes_to_stats(self): + """generate_spec failure (no spec_content, has error) routes to post_terminal_stats.""" + state = _feature_state(last_error="LLM timeout", spec_content="") + result = _route_after_spec_generation(state) + assert result == "post_terminal_stats" + + def test_spec_generation_success_does_not_route_to_stats(self): + """generate_spec success routes to spec_approval_gate.""" + state = _feature_state(last_error=None, spec_content="Some spec content") + result = _route_after_spec_generation(state) + assert result == "spec_approval_gate" + + def test_epic_decomposition_failure_routes_to_stats(self): + """decompose_epics failure (no epic_keys, has error) routes to post_terminal_stats.""" + state = _feature_state(last_error="Epic decomposition failed", epic_keys=[]) + result = _route_after_epic_decomposition(state) + assert result == "post_terminal_stats" + + def test_epic_decomposition_success_does_not_route_to_stats(self): + """decompose_epics success routes to plan_approval_gate.""" + state = 
_feature_state(last_error=None, epic_keys=["FEAT-10", "FEAT-11"]) + result = _route_after_epic_decomposition(state) + assert result == "plan_approval_gate" + + def test_task_generation_failure_routes_to_stats(self): + """generate_tasks failure (no task_keys, has error) routes to post_terminal_stats.""" + state = _feature_state(last_error="Task generation failed", task_keys=[]) + result = _route_after_task_generation(state) + assert result == "post_terminal_stats" + + def test_task_generation_success_does_not_route_to_stats(self): + """generate_tasks success routes to task_approval_gate.""" + state = _feature_state(last_error=None, task_keys=["FEAT-20", "FEAT-21"]) + result = _route_after_task_generation(state) + assert result == "task_approval_gate" + + +class TestFeatureGraphEdgeStructure: + """Verify graph edge structure ensures stats posting on all terminal paths.""" + + def test_escalate_blocked_has_edge_to_post_terminal_stats(self): + """escalate_blocked edges directly to post_terminal_stats (blocked terminal path).""" + graph = build_feature_graph() + # Use the uncompiled graph's edges set (tuples of (from, to)) + assert ("escalate_blocked", "post_terminal_stats") in graph.edges, ( + "escalate_blocked must route to post_terminal_stats" + ) + + def test_aggregate_feature_status_has_edge_to_post_terminal_stats(self): + """aggregate_feature_status edges to post_terminal_stats (success terminal path).""" + graph = build_feature_graph() + assert ("aggregate_feature_status", "post_terminal_stats") in graph.edges, ( + "aggregate_feature_status must route to post_terminal_stats" + ) + + def test_post_terminal_stats_has_edge_to_end(self): + """post_terminal_stats has an outgoing edge to END (__end__).""" + graph = build_feature_graph() + assert ("post_terminal_stats", "__end__") in graph.edges, ( + "post_terminal_stats must route to END" + ) + + def test_graph_compiles_successfully(self): + """build_feature_graph() compiles without error after stats node addition.""" 
+ graph = build_feature_graph() + compiled = graph.compile() + assert compiled is not None + + def test_success_path_flows_through_stats_before_end(self): + """The success path aggregate_feature_status → post_terminal_stats → END is wired.""" + graph = build_feature_graph() + edges = graph.edges + assert ("aggregate_feature_status", "post_terminal_stats") in edges, ( + "aggregate_feature_status must edge to post_terminal_stats" + ) + assert ("post_terminal_stats", "__end__") in edges, "post_terminal_stats must edge to END" + + +class TestFeatureGraphStatsOrdering: + """Stats posting occurs AFTER other terminal actions.""" + + def test_aggregate_feature_status_is_penultimate_node(self): + """Success path: complete_tasks → aggregate_epic_status → aggregate_feature_status → post_terminal_stats → END.""" + graph = build_feature_graph() + edges = graph.edges + + assert ("complete_tasks", "aggregate_epic_status") in edges, ( + "complete_tasks must edge to aggregate_epic_status" + ) + assert ("aggregate_epic_status", "aggregate_feature_status") in edges, ( + "aggregate_epic_status must edge to aggregate_feature_status" + ) + assert ("aggregate_feature_status", "post_terminal_stats") in edges, ( + "aggregate_feature_status must edge to post_terminal_stats (stats after status)" + ) + + def test_escalate_blocked_routes_directly_to_stats(self): + """escalate_blocked → post_terminal_stats (stats right after blocked action).""" + graph = build_feature_graph() + assert ("escalate_blocked", "post_terminal_stats") in graph.edges, ( + "escalate_blocked must directly edge to post_terminal_stats" + ) + + def test_aggregate_feature_status_does_not_edge_to_end_directly(self): + """aggregate_feature_status does NOT have a direct edge to END (stats must be between).""" + graph = build_feature_graph() + assert ("aggregate_feature_status", "__end__") not in graph.edges, ( + "aggregate_feature_status must NOT edge directly to END; " + "post_terminal_stats must be between" + ) + + def 
test_escalate_blocked_does_not_edge_to_end_directly(self): + """escalate_blocked does NOT have a direct edge to END (stats must be between).""" + graph = build_feature_graph() + assert ("escalate_blocked", "__end__") not in graph.edges, ( + "escalate_blocked must NOT edge directly to END; post_terminal_stats must be between" + ) diff --git a/tests/unit/workflow/nodes/test_ci_attempt_tracking.py b/tests/unit/workflow/nodes/test_ci_attempt_tracking.py index 59950ab6..e619568a 100644 --- a/tests/unit/workflow/nodes/test_ci_attempt_tracking.py +++ b/tests/unit/workflow/nodes/test_ci_attempt_tracking.py @@ -1,11 +1,11 @@ """Unit tests for CI attempt tracking (AISOS-654).""" -import pytest from unittest.mock import AsyncMock, MagicMock, patch -from forge.workflow.nodes.ci_evaluator import evaluate_ci_status -from forge.workflow.feature.state import FeatureState +import pytest +from forge.workflow.feature.state import FeatureState +from forge.workflow.nodes.ci_evaluator import attempt_ci_fix, evaluate_ci_status # ── Helpers ─────────────────────────────────────────────────────────────────── @@ -44,22 +44,26 @@ class TestCIAttemptTrackingStateFields: def test_current_attempt_in_ci_integration_state(self): """current_attempt must be a field in CIIntegrationState.""" from forge.workflow.base import CIIntegrationState + assert "ci_fix_attempt" in CIIntegrationState.__annotations__ def test_max_attempts_in_ci_integration_state(self): """max_attempts must be a field in CIIntegrationState.""" from forge.workflow.base import CIIntegrationState + assert "ci_fix_max_attempts" in CIIntegrationState.__annotations__ def test_feature_state_initializes_current_attempt_to_zero(self): """Feature state should initialize current_attempt to 0.""" from forge.workflow.feature.state import create_initial_feature_state + state = create_initial_feature_state(ticket_key="TEST-1") assert state.get("ci_fix_attempt") == 0 def test_feature_state_initializes_max_attempts_from_config(self): """Feature 
state should initialize max_attempts from config.""" from forge.workflow.feature.state import create_initial_feature_state + state = create_initial_feature_state(ticket_key="TEST-1") # Default config value is 5 assert state.get("ci_fix_max_attempts") is not None @@ -68,12 +72,14 @@ def test_feature_state_initializes_max_attempts_from_config(self): def test_bug_state_initializes_current_attempt_to_zero(self): """Bug state should initialize current_attempt to 0.""" from forge.workflow.bug.state import create_initial_bug_state + state = create_initial_bug_state(ticket_key="TEST-2") assert state.get("ci_fix_attempt") == 0 def test_bug_state_initializes_max_attempts_from_config(self): """Bug state should initialize max_attempts from config.""" from forge.workflow.bug.state import create_initial_bug_state + state = create_initial_bug_state(ticket_key="TEST-2") # Default config value is 5 assert state.get("ci_fix_max_attempts") is not None @@ -90,7 +96,7 @@ class TestCIAttemptIncrement: async def test_first_ci_failure_increments_attempt_to_one(self): """First CI failure should increment current_attempt from 0 to 1.""" state = create_base_state(ci_fix_attempt=0, ci_fix_max_attempts=3) - + github = create_mock_github_client() github.get_pull_request.return_value = {"head": {"sha": "abc123"}} github.get_check_runs.return_value = [ @@ -116,7 +122,7 @@ async def test_first_ci_failure_increments_attempt_to_one(self): async def test_second_ci_failure_increments_attempt_to_two(self): """Second CI failure should increment current_attempt from 1 to 2.""" state = create_base_state(ci_fix_attempt=1, ci_fix_max_attempts=3) - + github = create_mock_github_client() github.get_pull_request.return_value = {"head": {"sha": "abc123"}} github.get_check_runs.return_value = [ @@ -142,7 +148,7 @@ async def test_second_ci_failure_increments_attempt_to_two(self): async def test_third_ci_failure_increments_attempt_to_three(self): """Third CI failure should increment current_attempt from 2 to 3.""" 
state = create_base_state(ci_fix_attempt=2, ci_fix_max_attempts=3) - + github = create_mock_github_client() github.get_pull_request.return_value = {"head": {"sha": "abc123"}} github.get_check_runs.return_value = [ @@ -175,7 +181,7 @@ class TestCIAttemptLimitValidation: async def test_attempt_at_max_limit_blocks_further_attempts(self): """When current_attempt equals max_attempts, no more attempts should be made.""" state = create_base_state(ci_fix_attempt=3, ci_fix_max_attempts=3) - + github = create_mock_github_client() github.get_pull_request.return_value = {"head": {"sha": "abc123"}} github.get_check_runs.return_value = [ @@ -192,7 +198,9 @@ async def test_attempt_at_max_limit_blocks_further_attempts(self): with patch("forge.workflow.nodes.ci_evaluator.get_settings") as mock_settings: mock_settings.return_value.ci_fix_max_retries = 5 mock_settings.return_value.ignored_ci_checks = ["tide"] - with patch("forge.workflow.nodes.ci_evaluator.record_ci_fix_attempt") as mock_record: + with patch( + "forge.workflow.nodes.ci_evaluator.record_ci_fix_attempt" + ) as mock_record: result = await evaluate_ci_status(state) # Should not increment or route to attempt_ci_fix @@ -205,7 +213,7 @@ async def test_attempt_at_max_limit_blocks_further_attempts(self): async def test_attempt_exceeding_max_limit_blocks_further_attempts(self): """When current_attempt exceeds max_attempts, no more attempts should be made.""" state = create_base_state(ci_fix_attempt=4, ci_fix_max_attempts=3) - + github = create_mock_github_client() github.get_pull_request.return_value = {"head": {"sha": "abc123"}} github.get_check_runs.return_value = [ @@ -222,7 +230,9 @@ async def test_attempt_exceeding_max_limit_blocks_further_attempts(self): with patch("forge.workflow.nodes.ci_evaluator.get_settings") as mock_settings: mock_settings.return_value.ci_fix_max_retries = 5 mock_settings.return_value.ignored_ci_checks = ["tide"] - with patch("forge.workflow.nodes.ci_evaluator.record_ci_fix_attempt") as 
mock_record: + with patch( + "forge.workflow.nodes.ci_evaluator.record_ci_fix_attempt" + ) as mock_record: result = await evaluate_ci_status(state) # Should not increment or route to attempt_ci_fix @@ -235,7 +245,7 @@ async def test_attempt_exceeding_max_limit_blocks_further_attempts(self): async def test_attempt_one_below_max_allows_final_attempt(self): """When current_attempt is one below max, one more attempt should be allowed.""" state = create_base_state(ci_fix_attempt=2, ci_fix_max_attempts=3) - + github = create_mock_github_client() github.get_pull_request.return_value = {"head": {"sha": "abc123"}} github.get_check_runs.return_value = [ @@ -270,7 +280,7 @@ class TestCIAttemptReset: async def test_current_attempt_resets_on_ci_success(self): """When CI passes, current_attempt should reset to 0.""" state = create_base_state(ci_fix_attempt=2, ci_fix_max_attempts=3) - + github = create_mock_github_client() github.get_pull_request.return_value = {"head": {"sha": "abc123"}} github.get_check_runs.return_value = [ @@ -296,7 +306,7 @@ async def test_current_attempt_resets_on_ci_success(self): async def test_current_attempt_resets_on_workflow_completion(self): """When workflow completes (tasks complete), current_attempt should reset to 0.""" from forge.workflow.nodes.human_review import complete_tasks - + state = create_base_state( ci_fix_attempt=2, implemented_tasks=["TASK-1", "TASK-2"], @@ -327,7 +337,7 @@ async def test_missing_current_attempt_defaults_to_zero(self): state = create_base_state() # Remove current_attempt from state del state["ci_fix_attempt"] - + github = create_mock_github_client() github.get_pull_request.return_value = {"head": {"sha": "abc123"}} github.get_check_runs.return_value = [ @@ -355,7 +365,7 @@ async def test_missing_max_attempts_defaults_to_config_value(self): state = create_base_state(ci_fix_attempt=0) # Remove max_attempts from state del state["ci_fix_max_attempts"] - + github = create_mock_github_client() 
github.get_pull_request.return_value = {"head": {"sha": "abc123"}} github.get_check_runs.return_value = [ @@ -382,7 +392,7 @@ async def test_missing_max_attempts_defaults_to_config_value(self): async def test_max_attempts_one_allows_single_attempt(self): """When max_attempts is 1, only one attempt should be allowed.""" state = create_base_state(ci_fix_attempt=0, ci_fix_max_attempts=1) - + github = create_mock_github_client() github.get_pull_request.return_value = {"head": {"sha": "abc123"}} github.get_check_runs.return_value = [ @@ -419,3 +429,239 @@ async def test_max_attempts_one_allows_single_attempt(self): assert result2["ci_status"] == "failed" +# ── Token Recording and Fallback Estimation Tests ── + + +class TestCIAttemptFixTokenRecording: + """Test token recording and fallback estimation in attempt_ci_fix.""" + + @pytest.mark.asyncio + async def test_successful_phases_record_actual_tokens(self, tmp_path): + """When both phases run successfully and return valid token metrics, they are recorded and accumulated.""" + state = create_base_state( + workspace_path=str(tmp_path), + ci_fix_attempt=1, + ci_failed_checks=[{"name": "test", "conclusion": "failure"}], + ) + + # Create a mock fix plan file so Phase 2 is not skipped + fix_plan_file = tmp_path / ".forge" / "fix-plan.md" + fix_plan_file.parent.mkdir(parents=True, exist_ok=True) + fix_plan_file.write_text("apply some fix") + + mock_jira = AsyncMock() + mock_jira.close = AsyncMock() + + result_phase1 = MagicMock() + result_phase1.input_tokens = 120 + result_phase1.output_tokens = 80 + result_phase1.stdout = "phase 1 stdout" + result_phase1.stderr = "" + + result_phase2 = MagicMock() + result_phase2.input_tokens = 150 + result_phase2.output_tokens = 90 + result_phase2.stdout = "phase 2 stdout" + result_phase2.stderr = "" + + # Side effect to return result_phase1 on first run, result_phase2 on second + mock_runner = MagicMock() + mock_runner.run = AsyncMock(side_effect=[result_phase1, result_phase2]) + + with ( + 
patch("forge.workflow.nodes.ci_evaluator.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.ci_evaluator.ContainerRunner", return_value=mock_runner), + patch( + "forge.workflow.nodes.ci_evaluator.prepare_workspace", + return_value=(str(tmp_path), None), + ), + patch( + "forge.workflow.nodes.ci_evaluator._fetch_ci_logs_and_artifacts", + new_callable=AsyncMock, + ), + patch("forge.workflow.nodes.ci_evaluator.GitOperations") as mock_git_class, + patch( + "forge.workflow.nodes.ci_evaluator.run_post_change_review", new_callable=AsyncMock + ), + patch("forge.workflow.nodes.ci_evaluator.sync_pr_description", new_callable=AsyncMock), + patch("forge.workflow.nodes.ci_evaluator.get_settings") as mock_settings, + ): + mock_settings.return_value.container_model = "claude-sonnet-4-5" + + # Setup Git mock + mock_git = MagicMock() + mock_git.has_uncommitted_changes.return_value = False + mock_git._run_git.return_value.stdout = "some commit hash" + mock_git_class.return_value = mock_git + + new_state = await attempt_ci_fix(state) + + # Total expected input = 120 + 150 = 270 + # Total expected output = 80 + 90 = 170 + from forge.workflow.stats import STAGE_CI + + assert new_state["stage_token_usage"][STAGE_CI]["input_tokens"] == 270 + assert new_state["stage_token_usage"][STAGE_CI]["output_tokens"] == 170 + + @pytest.mark.asyncio + async def test_empty_or_zero_tokens_fallback_to_heuristic(self, tmp_path): + """When container returns 0 or empty token metrics, it falls back to _estimate_tokens.""" + state = create_base_state( + workspace_path=str(tmp_path), + ci_fix_attempt=1, + ci_failed_checks=[{"name": "test", "conclusion": "failure"}], + ) + + fix_plan_file = tmp_path / ".forge" / "fix-plan.md" + fix_plan_file.parent.mkdir(parents=True, exist_ok=True) + fix_plan_file.write_text("apply some fix") + + mock_jira = AsyncMock() + mock_jira.close = AsyncMock() + + result_phase1 = MagicMock() + result_phase1.input_tokens = 0 # Should trigger fallback + 
result_phase1.output_tokens = 0 # Should trigger fallback + result_phase1.stdout = "phase 1 output" + result_phase1.stderr = "some stderr" + + result_phase2 = MagicMock() + result_phase2.input_tokens = None # Should trigger fallback + result_phase2.output_tokens = None # Should trigger fallback + result_phase2.stdout = "phase 2 output" + result_phase2.stderr = "" + + mock_runner = MagicMock() + mock_runner.run = AsyncMock(side_effect=[result_phase1, result_phase2]) + + with ( + patch("forge.workflow.nodes.ci_evaluator.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.ci_evaluator.ContainerRunner", return_value=mock_runner), + patch( + "forge.workflow.nodes.ci_evaluator.prepare_workspace", + return_value=(str(tmp_path), None), + ), + patch( + "forge.workflow.nodes.ci_evaluator._fetch_ci_logs_and_artifacts", + new_callable=AsyncMock, + ), + patch("forge.workflow.nodes.ci_evaluator.GitOperations") as mock_git_class, + patch( + "forge.workflow.nodes.ci_evaluator.run_post_change_review", new_callable=AsyncMock + ), + patch("forge.workflow.nodes.ci_evaluator.sync_pr_description", new_callable=AsyncMock), + patch("forge.workflow.nodes.ci_evaluator.get_settings") as mock_settings, + ): + mock_settings.return_value.container_model = "claude-sonnet-4-5" + + mock_git = MagicMock() + mock_git.has_uncommitted_changes.return_value = False + mock_git._run_git.return_value.stdout = "some commit hash" + mock_git_class.return_value = mock_git + + new_state = await attempt_ci_fix(state) + + from forge.workflow.stats import STAGE_CI + + # Input tokens should be non-zero (estimated from prompts) + assert new_state["stage_token_usage"][STAGE_CI]["input_tokens"] > 0 + # Output tokens should be non-zero (estimated from stdout/stderr) + assert new_state["stage_token_usage"][STAGE_CI]["output_tokens"] > 0 + + @pytest.mark.asyncio + async def test_skipped_phase2_records_only_phase1_tokens(self, tmp_path): + """When Phase 2 is skipped because fix plan file does not exist, 
only Phase 1 tokens are recorded.""" + state = create_base_state( + workspace_path=str(tmp_path), + ci_fix_attempt=1, + ci_failed_checks=[{"name": "test", "conclusion": "failure"}], + ) + + # Ensure fix plan file does NOT exist + fix_plan_file = tmp_path / ".forge" / "fix-plan.md" + if fix_plan_file.exists(): + fix_plan_file.unlink() + + mock_jira = AsyncMock() + mock_jira.close = AsyncMock() + + result_phase1 = MagicMock() + result_phase1.input_tokens = 50 + result_phase1.output_tokens = 30 + result_phase1.stdout = "phase 1 stdout" + result_phase1.stderr = "" + + mock_runner = MagicMock() + mock_runner.run = AsyncMock(return_value=result_phase1) + + with ( + patch("forge.workflow.nodes.ci_evaluator.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.ci_evaluator.ContainerRunner", return_value=mock_runner), + patch( + "forge.workflow.nodes.ci_evaluator.prepare_workspace", + return_value=(str(tmp_path), None), + ), + patch( + "forge.workflow.nodes.ci_evaluator._fetch_ci_logs_and_artifacts", + new_callable=AsyncMock, + ), + patch("forge.workflow.nodes.ci_evaluator.get_settings") as mock_settings, + ): + mock_settings.return_value.container_model = "claude-sonnet-4-5" + new_state = await attempt_ci_fix(state) + + from forge.workflow.stats import STAGE_CI + + assert new_state["stage_token_usage"][STAGE_CI]["input_tokens"] == 50 + assert new_state["stage_token_usage"][STAGE_CI]["output_tokens"] == 30 + + @pytest.mark.asyncio + async def test_failure_in_subsequent_steps_preserves_recorded_tokens(self, tmp_path): + """When subsequent step (such as Phase 2 or Git operations) raises an exception, preceding recorded tokens are preserved in the returned state.""" + state = create_base_state( + workspace_path=str(tmp_path), + ci_fix_attempt=1, + ci_failed_checks=[{"name": "test", "conclusion": "failure"}], + ) + + fix_plan_file = tmp_path / ".forge" / "fix-plan.md" + fix_plan_file.parent.mkdir(parents=True, exist_ok=True) + fix_plan_file.write_text("apply some 
fix") + + mock_jira = AsyncMock() + mock_jira.close = AsyncMock() + + result_phase1 = MagicMock() + result_phase1.input_tokens = 80 + result_phase1.output_tokens = 40 + result_phase1.stdout = "phase 1 stdout" + result_phase1.stderr = "" + + mock_runner = MagicMock() + # Phase 2 run raises an Exception + mock_runner.run = AsyncMock( + side_effect=[result_phase1, Exception("Phase 2 simulated failure")] + ) + + with ( + patch("forge.workflow.nodes.ci_evaluator.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.ci_evaluator.ContainerRunner", return_value=mock_runner), + patch( + "forge.workflow.nodes.ci_evaluator.prepare_workspace", + return_value=(str(tmp_path), None), + ), + patch( + "forge.workflow.nodes.ci_evaluator._fetch_ci_logs_and_artifacts", + new_callable=AsyncMock, + ), + patch("forge.workflow.nodes.ci_evaluator.notify_error", new_callable=AsyncMock), + patch("forge.workflow.nodes.ci_evaluator.get_settings") as mock_settings, + ): + mock_settings.return_value.container_model = "claude-sonnet-4-5" + new_state = await attempt_ci_fix(state) + + from forge.workflow.stats import STAGE_CI + + # Phase 1 tokens (80 and 40) must be preserved in the final returned state + assert new_state["stage_token_usage"][STAGE_CI]["input_tokens"] == 80 + assert new_state["stage_token_usage"][STAGE_CI]["output_tokens"] == 40 diff --git a/tests/unit/workflow/nodes/test_prd_spec_stats.py b/tests/unit/workflow/nodes/test_prd_spec_stats.py new file mode 100644 index 00000000..01815482 --- /dev/null +++ b/tests/unit/workflow/nodes/test_prd_spec_stats.py @@ -0,0 +1,737 @@ +"""Unit tests for stats recording in PRD and Spec generation nodes.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.models.workflow import TicketType +from forge.workflow.feature.state import create_initial_feature_state +from forge.workflow.stats import STAGE_PRD, STAGE_SPEC + +# --------------------------------------------------------------------------- +# 
Helpers +# --------------------------------------------------------------------------- + + +def create_mock_jira( + description: str = "Raw requirements text", + summary: str = "Test Feature", + project_key: str = "TEST", +) -> MagicMock: + """Return a JiraClient mock with default async methods.""" + mock = MagicMock() + mock.close = AsyncMock() + mock.update_description = AsyncMock() + mock.add_structured_comment = AsyncMock() + mock.set_workflow_label = AsyncMock() + mock.get_prd_proposals_repo = AsyncMock(return_value=None) + mock.add_comment = AsyncMock() + mock.get_issue = AsyncMock( + return_value=MagicMock( + summary=summary, + description=description, + project_key=project_key, + ) + ) + return mock + + +def create_mock_agent( + prd_content: str = "# Generated PRD\n\nContent here.", + spec_content: str = "# Generated Spec\n\nAcceptance criteria here.", +) -> MagicMock: + """Return a ForgeAgent mock with default async methods.""" + mock = MagicMock() + mock.close = AsyncMock() + mock.generate_prd = AsyncMock(return_value=prd_content) + mock.generate_spec = AsyncMock(return_value=spec_content) + mock.regenerate_with_feedback = AsyncMock(return_value="# Revised content") + return mock + + +def _get_stage(result: dict, stage_name: str) -> dict: + """Extract a stage entry from result state, or {} if absent.""" + return (result.get("stage_timestamps") or {}).get(stage_name, {}) + + +# --------------------------------------------------------------------------- +# PRD generation stats tests +# --------------------------------------------------------------------------- + + +class TestGeneratePrdStatsRecording: + """Tests for stats recording in generate_prd node.""" + + @pytest.mark.asyncio + async def test_records_stage_start_on_entry(self): + """generate_prd should initialise the PRD stage with a started_at timestamp.""" + from forge.workflow.nodes.prd_generation import generate_prd + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = 
create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + ) + + with ( + patch("forge.workflow.nodes.prd_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.prd_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.prd_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_prd(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage, "stage_timestamps[STAGE_PRD] should be populated" + assert stage.get("started_at") is not None, "started_at must be set" + + @pytest.mark.asyncio + async def test_records_stage_end_with_machine_time(self): + """generate_prd should populate ended_at and positive machine_time_seconds.""" + from forge.workflow.nodes.prd_generation import generate_prd + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + ) + + with ( + patch("forge.workflow.nodes.prd_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.prd_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.prd_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_prd(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("ended_at") is not None, "ended_at must be set on success" + assert stage.get("machine_time_seconds", 0.0) >= 0.0, "machine_time must be non-negative" + + @pytest.mark.asyncio + async def test_records_tokens_from_llm_response(self): + """generate_prd should record non-zero token counts after LLM call.""" + from forge.workflow.nodes.prd_generation import generate_prd + + mock_jira = create_mock_jira(description="A" * 400) # 100 estimated tokens + mock_agent = create_mock_agent(prd_content="B" * 800) # 200 estimated tokens + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + 
) + + with ( + patch("forge.workflow.nodes.prd_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.prd_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.prd_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_prd(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("input_tokens", 0) > 0, "input_tokens should be positive" + assert stage.get("output_tokens", 0) > 0, "output_tokens should be positive" + + @pytest.mark.asyncio + async def test_stats_recorded_on_missing_requirements(self): + """generate_prd should record stage_end even when requirements are empty.""" + from forge.workflow.nodes.prd_generation import generate_prd + + mock_jira = create_mock_jira(description="") + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + ) + + with ( + patch("forge.workflow.nodes.prd_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.prd_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.prd_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_prd(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("started_at") is not None + assert stage.get("ended_at") is not None + + @pytest.mark.asyncio + async def test_stats_recorded_on_exception(self): + """generate_prd should record stage_end even when an exception is raised.""" + from forge.workflow.nodes.prd_generation import generate_prd + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + mock_agent.generate_prd = AsyncMock(side_effect=RuntimeError("LLM failure")) + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + ) + + with ( + patch("forge.workflow.nodes.prd_generation.JiraClient", return_value=mock_jira), + 
patch("forge.workflow.nodes.prd_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.prd_generation.post_status_comment", + new_callable=AsyncMock, + ), + patch( + "forge.workflow.nodes.error_handler.notify_error", + new_callable=AsyncMock, + ), + ): + result = await generate_prd(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("started_at") is not None + assert stage.get("ended_at") is not None + assert result.get("last_error") is not None + + +# --------------------------------------------------------------------------- +# PRD regeneration stats tests +# --------------------------------------------------------------------------- + + +class TestRegeneratePrdStatsRecording: + """Tests for stats recording in regenerate_prd_with_feedback node.""" + + @pytest.mark.asyncio + async def test_increments_revision_on_feedback(self): + """regenerate_prd_with_feedback should increment iteration_count by 1.""" + from forge.workflow.nodes.prd_generation import regenerate_prd_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + prd_content="# Original PRD", + feedback_comment="! 
Please add more detail about authentication", + ) + + with ( + patch( + "forge.workflow.nodes.prd_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.prd_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_prd_with_feedback(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("iteration_count", 0) >= 1, "iteration_count must be incremented" + + @pytest.mark.asyncio + async def test_records_stage_start_on_feedback(self): + """regenerate_prd_with_feedback should set started_at on re-entry.""" + from forge.workflow.nodes.prd_generation import regenerate_prd_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + prd_content="# Original PRD", + feedback_comment="! Needs more detail", + ) + + with ( + patch( + "forge.workflow.nodes.prd_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.prd_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_prd_with_feedback(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("started_at") is not None + + @pytest.mark.asyncio + async def test_records_stage_end_on_feedback(self): + """regenerate_prd_with_feedback should record ended_at and machine_time.""" + from forge.workflow.nodes.prd_generation import regenerate_prd_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + prd_content="# Original PRD", + feedback_comment="! 
Add more context", + ) + + with ( + patch( + "forge.workflow.nodes.prd_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.prd_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_prd_with_feedback(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("ended_at") is not None + assert stage.get("machine_time_seconds", 0.0) >= 0.0 + + @pytest.mark.asyncio + async def test_records_tokens_on_feedback(self): + """regenerate_prd_with_feedback should record tokens for the revision.""" + from forge.workflow.nodes.prd_generation import regenerate_prd_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + mock_agent.regenerate_with_feedback = AsyncMock(return_value="D" * 800) + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + prd_content="C" * 400, + feedback_comment="! " + "E" * 40, + ) + + with ( + patch( + "forge.workflow.nodes.prd_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.prd_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_prd_with_feedback(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("input_tokens", 0) > 0 + assert stage.get("output_tokens", 0) > 0 + + @pytest.mark.asyncio + async def test_no_feedback_returns_unchanged_state(self): + """regenerate_prd_with_feedback with no feedback should return state unchanged.""" + from forge.workflow.nodes.prd_generation import regenerate_prd_with_feedback + + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + prd_content="# Original PRD", + ) + + result = await regenerate_prd_with_feedback(state) + + # State returned unchanged — no stage_timestamps mutation + assert result is state + + @pytest.mark.asyncio + async def test_stats_recorded_on_exception(self): + """regenerate_prd_with_feedback records stage_end even on 
exception.""" + from forge.workflow.nodes.prd_generation import regenerate_prd_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + mock_agent.regenerate_with_feedback = AsyncMock(side_effect=RuntimeError("API error")) + state = create_initial_feature_state( + ticket_key="TEST-1", + ticket_type=TicketType.FEATURE, + prd_content="# Original PRD", + feedback_comment="! Add more detail", + ) + + with ( + patch( + "forge.workflow.nodes.prd_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.prd_generation.ForgeAgent", + return_value=mock_agent, + ), + patch( + "forge.workflow.nodes.error_handler.notify_error", + new_callable=AsyncMock, + ), + ): + result = await regenerate_prd_with_feedback(state) + + stage = _get_stage(result, STAGE_PRD) + assert stage.get("ended_at") is not None + assert result.get("last_error") is not None + + +# --------------------------------------------------------------------------- +# Spec generation stats tests +# --------------------------------------------------------------------------- + + +class TestGenerateSpecStatsRecording: + """Tests for stats recording in generate_spec node.""" + + @pytest.mark.asyncio + async def test_records_stage_start_on_entry(self): + """generate_spec should initialise the SPEC stage with a started_at timestamp.""" + from forge.workflow.nodes.spec_generation import generate_spec + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + prd_content="# Approved PRD", + ) + + with ( + patch("forge.workflow.nodes.spec_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.spec_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.spec_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_spec(state) + + stage = _get_stage(result, STAGE_SPEC) + 
assert stage, "stage_timestamps[STAGE_SPEC] should be populated" + assert stage.get("started_at") is not None + + @pytest.mark.asyncio + async def test_records_stage_end_with_machine_time(self): + """generate_spec should populate ended_at and machine_time_seconds.""" + from forge.workflow.nodes.spec_generation import generate_spec + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + prd_content="# Approved PRD", + ) + + with ( + patch("forge.workflow.nodes.spec_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.spec_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.spec_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_spec(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("ended_at") is not None + assert stage.get("machine_time_seconds", 0.0) >= 0.0 + + @pytest.mark.asyncio + async def test_records_tokens_from_llm_response(self): + """generate_spec should record non-zero token counts after LLM call.""" + from forge.workflow.nodes.spec_generation import generate_spec + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent(spec_content="F" * 800) + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + prd_content="G" * 400, + ) + + with ( + patch("forge.workflow.nodes.spec_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.spec_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.spec_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_spec(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("input_tokens", 0) > 0 + assert stage.get("output_tokens", 0) > 0 + + @pytest.mark.asyncio + async def test_stats_recorded_on_missing_prd(self): + """generate_spec should 
record stage_end even when PRD content is empty.""" + from forge.workflow.nodes.spec_generation import generate_spec + + # No prd_content in state, and Jira returns empty description + mock_jira = create_mock_jira(description="") + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + ) + + with ( + patch("forge.workflow.nodes.spec_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.spec_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.spec_generation.post_status_comment", + new_callable=AsyncMock, + ), + ): + result = await generate_spec(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("started_at") is not None + assert stage.get("ended_at") is not None + + @pytest.mark.asyncio + async def test_stats_recorded_on_exception(self): + """generate_spec should record stage_end even when an exception is raised.""" + from forge.workflow.nodes.spec_generation import generate_spec + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + mock_agent.generate_spec = AsyncMock(side_effect=RuntimeError("Spec LLM failure")) + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + prd_content="# Approved PRD", + ) + + with ( + patch("forge.workflow.nodes.spec_generation.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.spec_generation.ForgeAgent", return_value=mock_agent), + patch( + "forge.workflow.nodes.spec_generation.post_status_comment", + new_callable=AsyncMock, + ), + patch( + "forge.workflow.nodes.error_handler.notify_error", + new_callable=AsyncMock, + ), + ): + result = await generate_spec(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("started_at") is not None + assert stage.get("ended_at") is not None + assert result.get("last_error") is not None + + +# 
--------------------------------------------------------------------------- +# Spec regeneration stats tests +# --------------------------------------------------------------------------- + + +class TestRegenerateSpecStatsRecording: + """Tests for stats recording in regenerate_spec_with_feedback node.""" + + @pytest.mark.asyncio + async def test_increments_revision_on_feedback(self): + """regenerate_spec_with_feedback should increment iteration_count.""" + from forge.workflow.nodes.spec_generation import regenerate_spec_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + spec_content="# Original Spec", + feedback_comment="! Please add more Given/When/Then scenarios", + ) + + with ( + patch( + "forge.workflow.nodes.spec_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.spec_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_spec_with_feedback(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("iteration_count", 0) >= 1 + + @pytest.mark.asyncio + async def test_records_stage_start_on_feedback(self): + """regenerate_spec_with_feedback should set started_at on re-entry.""" + from forge.workflow.nodes.spec_generation import regenerate_spec_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + spec_content="# Original Spec", + feedback_comment="! 
Needs more detail", + ) + + with ( + patch( + "forge.workflow.nodes.spec_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.spec_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_spec_with_feedback(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("started_at") is not None + + @pytest.mark.asyncio + async def test_records_stage_end_on_feedback(self): + """regenerate_spec_with_feedback should record ended_at and machine_time.""" + from forge.workflow.nodes.spec_generation import regenerate_spec_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + spec_content="# Original Spec", + feedback_comment="! Add edge cases", + ) + + with ( + patch( + "forge.workflow.nodes.spec_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.spec_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_spec_with_feedback(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("ended_at") is not None + assert stage.get("machine_time_seconds", 0.0) >= 0.0 + + @pytest.mark.asyncio + async def test_records_tokens_on_feedback(self): + """regenerate_spec_with_feedback should record tokens for the revision.""" + from forge.workflow.nodes.spec_generation import regenerate_spec_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + mock_agent.regenerate_with_feedback = AsyncMock(return_value="H" * 800) + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + spec_content="I" * 400, + feedback_comment="! 
" + "J" * 40, + ) + + with ( + patch( + "forge.workflow.nodes.spec_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.spec_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + result = await regenerate_spec_with_feedback(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("input_tokens", 0) > 0 + assert stage.get("output_tokens", 0) > 0 + + @pytest.mark.asyncio + async def test_no_feedback_returns_unchanged_state(self): + """regenerate_spec_with_feedback with no feedback should return state unchanged.""" + from forge.workflow.nodes.spec_generation import regenerate_spec_with_feedback + + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + spec_content="# Original Spec", + ) + + result = await regenerate_spec_with_feedback(state) + + assert result is state + + @pytest.mark.asyncio + async def test_stats_recorded_on_exception(self): + """regenerate_spec_with_feedback records stage_end even on exception.""" + from forge.workflow.nodes.spec_generation import regenerate_spec_with_feedback + + mock_jira = create_mock_jira() + mock_agent = create_mock_agent() + mock_agent.regenerate_with_feedback = AsyncMock(side_effect=RuntimeError("API error")) + state = create_initial_feature_state( + ticket_key="TEST-2", + ticket_type=TicketType.FEATURE, + spec_content="# Original Spec", + feedback_comment="! 
Add more detail", + ) + + with ( + patch( + "forge.workflow.nodes.spec_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.spec_generation.ForgeAgent", + return_value=mock_agent, + ), + patch( + "forge.workflow.nodes.error_handler.notify_error", + new_callable=AsyncMock, + ), + ): + result = await regenerate_spec_with_feedback(state) + + stage = _get_stage(result, STAGE_SPEC) + assert stage.get("ended_at") is not None + assert result.get("last_error") is not None + + +# --------------------------------------------------------------------------- +# Token estimation helper tests +# --------------------------------------------------------------------------- + + +class TestEstimateTokens: + """Tests for the _estimate_tokens helper.""" + + def test_empty_string_returns_one(self): + from forge.workflow.nodes.prd_generation import _estimate_tokens + + assert _estimate_tokens("") == 1 + + def test_four_chars_returns_one(self): + from forge.workflow.nodes.prd_generation import _estimate_tokens + + assert _estimate_tokens("abcd") == 1 + + def test_estimate_scales_with_length(self): + from forge.workflow.nodes.prd_generation import _estimate_tokens + + assert _estimate_tokens("a" * 400) == 100 + + def test_spec_module_helper_matches(self): + from forge.workflow.nodes.prd_generation import _estimate_tokens as prd_est + from forge.workflow.nodes.spec_generation import _estimate_tokens as spec_est + + text = "Hello world test" + assert prd_est(text) == spec_est(text) diff --git a/tests/unit/workflow/nodes/test_stats_posting.py b/tests/unit/workflow/nodes/test_stats_posting.py new file mode 100644 index 00000000..0093168c --- /dev/null +++ b/tests/unit/workflow/nodes/test_stats_posting.py @@ -0,0 +1,388 @@ +"""Unit tests for the post_terminal_stats node (stats_posting.py). 
+ +Tests cover: +- Outcome classification for Completed / Blocked / Failed states +- Outcome detail extraction (last_error, block reason, stats_outcome_reason) +- Integration with post_stats_comment and ensure_stats_is_final_comment +- Handling of both FeatureState and BugState +- Non-blocking behaviour on Jira API failures +""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from forge.workflow.bug.state import create_initial_bug_state +from forge.workflow.feature.state import create_initial_feature_state +from forge.workflow.nodes.stats_posting import ( + _determine_outcome, + _extract_outcome_detail, + post_terminal_stats, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def feature_state(): + """Minimal FeatureState with no terminal conditions set.""" + return create_initial_feature_state("FEAT-1") + + +@pytest.fixture() +def bug_state(): + """Minimal BugState with no terminal conditions set.""" + return create_initial_bug_state("BUG-1") + + +# --------------------------------------------------------------------------- +# _determine_outcome tests +# --------------------------------------------------------------------------- + + +class TestDetermineOutcome: + """Tests for the _determine_outcome helper.""" + + def test_completed_when_no_flags_set(self, feature_state): + """Returns 'Completed' when no error or block flag is set.""" + assert _determine_outcome(feature_state) == "Completed" + + def test_failed_when_last_error_set(self, feature_state): + """Returns 'Failed' when last_error contains a message.""" + feature_state["last_error"] = "Something went wrong" + assert _determine_outcome(feature_state) == "Failed" + + def test_blocked_when_is_blocked_true(self, feature_state): + """Returns 'Blocked' when is_blocked flag is True.""" + feature_state["is_blocked"] = True + assert 
_determine_outcome(feature_state) == "Blocked" + + def test_blocked_takes_precedence_over_last_error(self, feature_state): + """'Blocked' takes precedence over 'Failed' when both flags are set.""" + feature_state["is_blocked"] = True + feature_state["last_error"] = "Some error" + assert _determine_outcome(feature_state) == "Blocked" + + def test_existing_workflow_outcome_returned_directly(self, feature_state): + """If workflow_outcome is already set, it is returned without re-deriving.""" + feature_state["workflow_outcome"] = "Completed" + feature_state["last_error"] = "Some error" # would normally produce 'Failed' + assert _determine_outcome(feature_state) == "Completed" + + def test_existing_workflow_outcome_blocked(self, feature_state): + """Pre-set workflow_outcome of 'Blocked' is honoured directly.""" + feature_state["workflow_outcome"] = "Blocked" + assert _determine_outcome(feature_state) == "Blocked" + + def test_completed_for_bug_state(self, bug_state): + """Bug workflow: returns 'Completed' when no error or block.""" + assert _determine_outcome(bug_state) == "Completed" + + def test_failed_for_bug_state(self, bug_state): + """Bug workflow: returns 'Failed' when last_error is set.""" + bug_state["last_error"] = "container exited with code 1" + assert _determine_outcome(bug_state) == "Failed" + + def test_blocked_for_bug_state(self, bug_state): + """Bug workflow: returns 'Blocked' when is_blocked is True.""" + bug_state["is_blocked"] = True + assert _determine_outcome(bug_state) == "Blocked" + + +# --------------------------------------------------------------------------- +# _extract_outcome_detail tests +# --------------------------------------------------------------------------- + + +class TestExtractOutcomeDetail: + """Tests for the _extract_outcome_detail helper.""" + + def test_completed_returns_none(self, feature_state): + """Completed outcome has no detail.""" + assert _extract_outcome_detail(feature_state, "Completed") is None + + def 
test_failed_returns_last_error(self, feature_state): + """Failed outcome uses last_error as the detail string.""" + feature_state["last_error"] = "NullPointerException in validate()" + detail = _extract_outcome_detail(feature_state, "Failed") + assert detail == "NullPointerException in validate()" + + def test_failed_no_last_error_returns_none(self, feature_state): + """Failed outcome returns None when last_error is not set.""" + assert _extract_outcome_detail(feature_state, "Failed") is None + + def test_blocked_returns_feedback_comment(self, feature_state): + """Blocked outcome uses feedback_comment as the block reason.""" + feature_state["feedback_comment"] = "Waiting for third-party API key" + detail = _extract_outcome_detail(feature_state, "Blocked") + assert detail == "Waiting for third-party API key" + + def test_blocked_no_reason_returns_none(self, feature_state): + """Blocked outcome returns None when no reason is available.""" + assert _extract_outcome_detail(feature_state, "Blocked") is None + + def test_stats_outcome_reason_takes_precedence(self, feature_state): + """Pre-recorded stats_outcome_reason overrides derived detail.""" + feature_state["stats_outcome_reason"] = "Pre-recorded reason" + feature_state["last_error"] = "Some other error" + detail = _extract_outcome_detail(feature_state, "Failed") + assert detail == "Pre-recorded reason" + + def test_stats_outcome_reason_for_blocked(self, feature_state): + """Pre-recorded stats_outcome_reason is used for Blocked outcome too.""" + feature_state["stats_outcome_reason"] = "External dependency unavailable" + feature_state["feedback_comment"] = "Other comment" + detail = _extract_outcome_detail(feature_state, "Blocked") + assert detail == "External dependency unavailable" + + def test_failed_for_bug_state(self, bug_state): + """Bug workflow: Failed outcome extracts last_error.""" + bug_state["last_error"] = "RCA container timed out" + assert _extract_outcome_detail(bug_state, "Failed") == "RCA container 
timed out" + + +# --------------------------------------------------------------------------- +# post_terminal_stats integration tests +# --------------------------------------------------------------------------- + + +class TestPostTerminalStats: + """Tests for the post_terminal_stats async node function.""" + + @pytest.mark.asyncio + async def test_returns_empty_dict(self, feature_state): + """Node always returns an empty dict (state unchanged).""" + with ( + patch( + "forge.workflow.nodes.stats_posting.post_stats_comment", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "forge.workflow.nodes.stats_posting.ensure_stats_is_final_comment", + new_callable=AsyncMock, + return_value=True, + ), + ): + result = await post_terminal_stats(feature_state) + + assert result == {} + + @pytest.mark.asyncio + async def test_calls_post_stats_comment_with_correct_args(self, feature_state): + """post_stats_comment is called with ticket_key, state, and derived outcome.""" + mock_post = AsyncMock(return_value=True) + mock_ensure = AsyncMock(return_value=True) + + with ( + patch("forge.workflow.nodes.stats_posting.post_stats_comment", mock_post), + patch("forge.workflow.nodes.stats_posting.ensure_stats_is_final_comment", mock_ensure), + ): + await post_terminal_stats(feature_state) + + mock_post.assert_awaited_once_with( + ticket_key="FEAT-1", + stats=feature_state, + outcome="Completed", + outcome_detail=None, + ) + + @pytest.mark.asyncio + async def test_calls_ensure_stats_is_final_comment(self, feature_state): + """ensure_stats_is_final_comment is called with correct args.""" + mock_post = AsyncMock(return_value=True) + mock_ensure = AsyncMock(return_value=True) + + with ( + patch("forge.workflow.nodes.stats_posting.post_stats_comment", mock_post), + patch("forge.workflow.nodes.stats_posting.ensure_stats_is_final_comment", mock_ensure), + ): + await post_terminal_stats(feature_state) + + mock_ensure.assert_awaited_once_with( + ticket_key="FEAT-1", + 
stats=feature_state, + outcome="Completed", + outcome_detail=None, + ) + + @pytest.mark.asyncio + async def test_completed_outcome_for_clean_state(self, feature_state): + """Completed outcome is passed when state has no errors or blocks.""" + mock_post = AsyncMock(return_value=True) + mock_ensure = AsyncMock(return_value=True) + + with ( + patch("forge.workflow.nodes.stats_posting.post_stats_comment", mock_post), + patch("forge.workflow.nodes.stats_posting.ensure_stats_is_final_comment", mock_ensure), + ): + await post_terminal_stats(feature_state) + + _call_kwargs = mock_post.call_args.kwargs + assert _call_kwargs["outcome"] == "Completed" + assert _call_kwargs["outcome_detail"] is None + + @pytest.mark.asyncio + async def test_blocked_outcome_for_blocked_state(self, feature_state): + """Blocked outcome posts stats.""" + feature_state["is_blocked"] = True + feature_state["feedback_comment"] = "Waiting on legal approval" + + mock_post = AsyncMock(return_value=True) + mock_ensure = AsyncMock(return_value=True) + + with ( + patch("forge.workflow.nodes.stats_posting.post_stats_comment", mock_post), + patch("forge.workflow.nodes.stats_posting.ensure_stats_is_final_comment", mock_ensure), + ): + await post_terminal_stats(feature_state) + + mock_post.assert_awaited_once_with( + ticket_key="FEAT-1", + stats=feature_state, + outcome="Blocked", + outcome_detail="Waiting on legal approval", + ) + + @pytest.mark.asyncio + async def test_failed_outcome_for_error_state(self, feature_state): + """Failed outcome posts stats.""" + feature_state["last_error"] = "container exited with code 137" + + mock_post = AsyncMock(return_value=True) + mock_ensure = AsyncMock(return_value=True) + + with ( + patch("forge.workflow.nodes.stats_posting.post_stats_comment", mock_post), + patch("forge.workflow.nodes.stats_posting.ensure_stats_is_final_comment", mock_ensure), + ): + await post_terminal_stats(feature_state) + + mock_post.assert_awaited_once_with( + ticket_key="FEAT-1", + 
stats=feature_state, + outcome="Failed", + outcome_detail="container exited with code 137", + ) + + @pytest.mark.asyncio + async def test_handles_bug_state(self, bug_state): + """Node works with BugState as well as FeatureState.""" + bug_state["last_error"] = "triage failed" + + mock_post = AsyncMock(return_value=True) + mock_ensure = AsyncMock(return_value=True) + + with ( + patch("forge.workflow.nodes.stats_posting.post_stats_comment", mock_post), + patch("forge.workflow.nodes.stats_posting.ensure_stats_is_final_comment", mock_ensure), + ): + result = await post_terminal_stats(bug_state) + + assert result == {} + mock_post.assert_awaited_once_with( + ticket_key="BUG-1", + stats=bug_state, + outcome="Failed", + outcome_detail="triage failed", + ) + + @pytest.mark.asyncio + async def test_non_blocking_on_post_stats_failure(self, feature_state): + """post_stats_comment raising an exception does not propagate.""" + mock_post = AsyncMock(side_effect=RuntimeError("Jira is down")) + mock_ensure = AsyncMock(return_value=True) + + with ( + patch("forge.workflow.nodes.stats_posting.post_stats_comment", mock_post), + patch("forge.workflow.nodes.stats_posting.ensure_stats_is_final_comment", mock_ensure), + ): + # Should not raise + result = await post_terminal_stats(feature_state) + + assert result == {} + + @pytest.mark.asyncio + async def test_non_blocking_on_ensure_final_comment_failure(self, feature_state): + """ensure_stats_is_final_comment raising does not propagate.""" + mock_post = AsyncMock(return_value=True) + mock_ensure = AsyncMock(side_effect=RuntimeError("network timeout")) + + with ( + patch("forge.workflow.nodes.stats_posting.post_stats_comment", mock_post), + patch("forge.workflow.nodes.stats_posting.ensure_stats_is_final_comment", mock_ensure), + ): + result = await post_terminal_stats(feature_state) + + assert result == {} + + @pytest.mark.asyncio + async def test_non_blocking_when_both_services_fail(self, feature_state): + """Node returns empty dict even 
when both posting services raise.""" + mock_post = AsyncMock(side_effect=Exception("boom")) + mock_ensure = AsyncMock(side_effect=Exception("crash")) + + with ( + patch("forge.workflow.nodes.stats_posting.post_stats_comment", mock_post), + patch("forge.workflow.nodes.stats_posting.ensure_stats_is_final_comment", mock_ensure), + ): + result = await post_terminal_stats(feature_state) + + assert result == {} + + @pytest.mark.asyncio + async def test_skips_posting_when_no_ticket_key(self): + """Node skips posting gracefully when ticket_key is absent.""" + state_without_key = {"is_blocked": False, "last_error": None} + + mock_post = AsyncMock(return_value=True) + mock_ensure = AsyncMock(return_value=True) + + with ( + patch("forge.workflow.nodes.stats_posting.post_stats_comment", mock_post), + patch("forge.workflow.nodes.stats_posting.ensure_stats_is_final_comment", mock_ensure), + ): + result = await post_terminal_stats(state_without_key) # type: ignore[arg-type] + + assert result == {} + mock_post.assert_not_awaited() + mock_ensure.assert_not_awaited() + + @pytest.mark.asyncio + async def test_post_stats_comment_false_does_not_skip_ensure(self, feature_state): + """ensure_stats_is_final_comment is still called even when post returns False.""" + mock_post = AsyncMock(return_value=False) + mock_ensure = AsyncMock(return_value=True) + + with ( + patch("forge.workflow.nodes.stats_posting.post_stats_comment", mock_post), + patch("forge.workflow.nodes.stats_posting.ensure_stats_is_final_comment", mock_ensure), + ): + await post_terminal_stats(feature_state) + + mock_ensure.assert_awaited_once() + + @pytest.mark.asyncio + async def test_uses_pre_set_workflow_outcome(self, feature_state): + """If workflow_outcome is already set in state it is checked.""" + feature_state["workflow_outcome"] = "Blocked" + feature_state["stats_outcome_reason"] = "Awaiting vendor API" + feature_state["last_error"] = None # would normally produce 'Completed' + + mock_post = 
AsyncMock(return_value=True) + mock_ensure = AsyncMock(return_value=True) + + with ( + patch("forge.workflow.nodes.stats_posting.post_stats_comment", mock_post), + patch("forge.workflow.nodes.stats_posting.ensure_stats_is_final_comment", mock_ensure), + ): + await post_terminal_stats(feature_state) + + mock_post.assert_awaited_once_with( + ticket_key="FEAT-1", + stats=feature_state, + outcome="Blocked", + outcome_detail="Awaiting vendor API", + ) diff --git a/tests/unit/workflow/stats/__init__.py b/tests/unit/workflow/stats/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/workflow/stats/test_ensure_stats_final.py b/tests/unit/workflow/stats/test_ensure_stats_final.py new file mode 100644 index 00000000..c2c3fa88 --- /dev/null +++ b/tests/unit/workflow/stats/test_ensure_stats_final.py @@ -0,0 +1,561 @@ +"""Unit tests for ensure_stats_is_final_comment() in forge.workflow.stats.poster. + +Tests verify: +- No Forge comments exist → posts new stats comment +- Most recent Forge comment IS a stats comment → no re-post (returns True) +- Most recent Forge comment is NOT a stats comment → re-posts stats +- Service account ID filtering: only Forge comments are considered +- When service_account_id is empty, all comments are treated as Forge comments +- JiraClient.get_comments() failure → returns False gracefully +- JiraClient is always closed after fetching comments +- _is_stats_comment() correctly identifies stats comments by marker +""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.workflow.stats.poster import ( + _STATS_BODY_MARKER, + _is_stats_comment, + ensure_stats_is_final_comment, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + +TICKET_KEY = "PROJ-99" +OUTCOME = "completed" +SERVICE_ACCOUNT_ID = "forge-bot-123" + +# A body 
that looks like a stats comment (contains the marker) +STATS_BODY = f"h2. Workflow Stats\n...\n{_STATS_BODY_MARKER}run-abc -->" + +# A body that does NOT look like a stats comment +OTHER_BODY = "This is a regular error notification comment." + + +def _minimal_stats(**overrides) -> dict: + base = { + "stage_timestamps": {}, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "workflow_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, + } + base.update(overrides) + return base + + +def _make_comment( + comment_id: str, + body: str, + author_id: str = SERVICE_ACCOUNT_ID, +) -> MagicMock: + """Build a mock JiraComment with the given attributes.""" + comment = MagicMock() + comment.id = comment_id + comment.body = body + comment.author_id = author_id + comment.created = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC) + return comment + + +def _make_jira_mock(comments: list) -> MagicMock: + """Return a mock JiraClient with get_comments returning *comments*.""" + mock = MagicMock() + mock.get_comments = AsyncMock(return_value=comments) + mock.add_comment = AsyncMock(return_value=MagicMock()) + mock.close = AsyncMock() + return mock + + +def _patch_service_account(account_id: str = SERVICE_ACCOUNT_ID): + """Context manager that patches get_settings to return account_id.""" + mock_settings = MagicMock() + mock_settings.jira_service_account_id = account_id + return patch("forge.workflow.stats.poster.get_settings", return_value=mock_settings) + + +# --------------------------------------------------------------------------- +# _is_stats_comment() helper +# --------------------------------------------------------------------------- + + +class TestIsStatsComment: + """Unit tests for the _is_stats_comment() detection helper.""" + + def test_returns_true_for_body_with_marker(self): + assert _is_stats_comment(STATS_BODY) is True + + def test_returns_true_for_minimal_marker(self): + assert _is_stats_comment("") is True + + def 
test_returns_false_for_plain_comment(self): + assert _is_stats_comment("Just a regular comment.") is False + + def test_returns_false_for_empty_body(self): + assert _is_stats_comment("") is False + + def test_returns_false_for_similar_but_wrong_marker(self): + # Must match the exact prefix _STATS_BODY_MARKER + assert _is_stats_comment("") is False + assert _is_stats_comment("") is False + + def test_marker_constant_starts_with_expected_prefix(self): + assert _STATS_BODY_MARKER == "") + + def test_includes_run_id(self): + marker = build_run_marker(RUN_ID) + assert RUN_ID in marker + + def test_contains_forge_stats_prefix(self): + marker = build_run_marker(RUN_ID) + assert "forge:stats:" in marker + + def test_format(self): + marker = build_run_marker("abc-123") + assert marker == "" + + def test_different_run_ids_produce_different_markers(self): + assert build_run_marker("run-1") != build_run_marker("run-2") + + +# --------------------------------------------------------------------------- +# TTL constant +# --------------------------------------------------------------------------- + + +class TestTtlConstant: + """Verify the 7-day TTL value.""" + + def test_seven_days_in_seconds(self): + assert STATS_IDEMPOTENCY_TTL_SECONDS == 7 * 24 * 60 * 60 + + def test_is_integer(self): + assert isinstance(STATS_IDEMPOTENCY_TTL_SECONDS, int) + + +# --------------------------------------------------------------------------- +# has_stats_been_posted +# --------------------------------------------------------------------------- + + +class TestHasStatsBeenPosted: + """has_stats_been_posted() checks Redis for the marker key.""" + + @pytest.mark.asyncio + async def test_returns_false_when_key_absent(self): + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + + result = await has_stats_been_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + assert result is False + + @pytest.mark.asyncio + async def test_returns_true_when_key_present(self): + mock_redis = 
AsyncMock() + mock_redis.exists = AsyncMock(return_value=1) + + result = await has_stats_been_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + assert result is True + + @pytest.mark.asyncio + async def test_calls_exists_with_correct_key(self): + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + + await has_stats_been_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + expected_key = _make_key(TICKET_KEY, RUN_ID) + mock_redis.exists.assert_called_once_with(expected_key) + + @pytest.mark.asyncio + async def test_uses_shared_client_when_none_provided(self): + """When redis_client is None, get_redis_client() is called.""" + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + + with patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ): + result = await has_stats_been_posted(TICKET_KEY, RUN_ID) + + assert result is False + mock_redis.exists.assert_called_once() + + @pytest.mark.asyncio + async def test_truthy_redis_value_returns_true(self): + """Any non-zero integer from exists() is treated as True.""" + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=2) + + result = await has_stats_been_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + assert result is True + + +# --------------------------------------------------------------------------- +# mark_stats_posted +# --------------------------------------------------------------------------- + + +class TestMarkStatsPosted: + """mark_stats_posted() writes the marker key with correct TTL.""" + + @pytest.mark.asyncio + async def test_calls_setex(self): + mock_redis = AsyncMock() + mock_redis.setex = AsyncMock() + + await mark_stats_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + mock_redis.setex.assert_called_once() + + @pytest.mark.asyncio + async def test_setex_uses_correct_key(self): + mock_redis = AsyncMock() + mock_redis.setex = AsyncMock() + + await 
mark_stats_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + call_args = mock_redis.setex.call_args + key = call_args.args[0] + assert key == _make_key(TICKET_KEY, RUN_ID) + + @pytest.mark.asyncio + async def test_setex_uses_correct_ttl(self): + mock_redis = AsyncMock() + mock_redis.setex = AsyncMock() + + await mark_stats_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + call_args = mock_redis.setex.call_args + ttl = call_args.args[1] + assert ttl == STATS_IDEMPOTENCY_TTL_SECONDS + + @pytest.mark.asyncio + async def test_setex_stores_truthy_value(self): + mock_redis = AsyncMock() + mock_redis.setex = AsyncMock() + + await mark_stats_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + call_args = mock_redis.setex.call_args + value = call_args.args[2] + assert value # any truthy value is fine + + @pytest.mark.asyncio + async def test_uses_shared_client_when_none_provided(self): + mock_redis = AsyncMock() + mock_redis.setex = AsyncMock() + + with patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ): + await mark_stats_posted(TICKET_KEY, RUN_ID) + + mock_redis.setex.assert_called_once() + + @pytest.mark.asyncio + async def test_returns_none(self): + mock_redis = AsyncMock() + mock_redis.setex = AsyncMock() + + result = await mark_stats_posted(TICKET_KEY, RUN_ID, redis_client=mock_redis) + + assert result is None + + +# --------------------------------------------------------------------------- +# Integration with post_stats_comment +# --------------------------------------------------------------------------- + + +class TestPostStatsCommentIdempotency: + """post_stats_comment() integrates idempotency guard correctly.""" + + def _minimal_stats(self, **overrides) -> dict: + base = { + "stage_timestamps": {}, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "workflow_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, + "workflow_run_id": RUN_ID, + } + 
base.update(overrides) + return base + + def _make_jira_mock(self, side_effect=None) -> MagicMock: + mock = MagicMock() + if side_effect is not None: + mock.add_comment = AsyncMock(side_effect=side_effect) + else: + mock.add_comment = AsyncMock(return_value=MagicMock()) + mock.close = AsyncMock() + return mock + + @pytest.mark.asyncio + async def test_skips_posting_when_already_posted(self): + """Returns True immediately without calling Jira when Redis marker exists.""" + from forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock() + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=1) # already posted + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + result = await post_stats_comment( + TICKET_KEY, self._minimal_stats(), "completed", run_id=RUN_ID + ) + + assert result is True + mock_jira.add_comment.assert_not_called() + + @pytest.mark.asyncio + async def test_posts_and_marks_when_not_yet_posted(self): + """Posts the comment and writes the marker when Redis key is absent.""" + from forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock() + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) # not yet posted + mock_redis.setex = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + result = await post_stats_comment( + TICKET_KEY, self._minimal_stats(), "completed", run_id=RUN_ID + ) + + assert result is True + mock_jira.add_comment.assert_called_once() + mock_redis.setex.assert_called_once() + + @pytest.mark.asyncio + async def test_comment_body_includes_run_marker(self): + """The posted comment body contains the hidden HTML marker.""" + from 
forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock() + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + mock_redis.setex = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + await post_stats_comment(TICKET_KEY, self._minimal_stats(), "completed", run_id=RUN_ID) + + args, _ = mock_jira.add_comment.call_args + comment_body = args[1] + assert f"" in comment_body + + @pytest.mark.asyncio + async def test_uses_workflow_run_id_from_stats_when_no_explicit_run_id(self): + """Falls back to stats['workflow_run_id'] when run_id not passed explicitly.""" + from forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock() + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + mock_redis.setex = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + # Note: no explicit run_id — should pick up workflow_run_id from stats + result = await post_stats_comment(TICKET_KEY, self._minimal_stats(), "completed") + + assert result is True + args, _ = mock_jira.add_comment.call_args + comment_body = args[1] + assert f"" in comment_body + + @pytest.mark.asyncio + async def test_redis_check_failure_does_not_block_post(self): + """If the Redis pre-check raises, the comment is still attempted.""" + from forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock() + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(side_effect=ConnectionError("redis down")) + mock_redis.setex = AsyncMock(side_effect=ConnectionError("redis down")) + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + 
patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + result = await post_stats_comment( + TICKET_KEY, self._minimal_stats(), "completed", run_id=RUN_ID + ) + + # Comment should still be posted even if Redis is unavailable + assert result is True + mock_jira.add_comment.assert_called_once() + + @pytest.mark.asyncio + async def test_marker_write_failure_does_not_affect_return_value(self): + """If the Redis marker write fails after a successful post, True is still returned.""" + from forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock() + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + mock_redis.setex = AsyncMock(side_effect=ConnectionError("redis down")) + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + result = await post_stats_comment( + TICKET_KEY, self._minimal_stats(), "completed", run_id=RUN_ID + ) + + assert result is True + + @pytest.mark.asyncio + async def test_no_marker_when_run_id_absent(self): + """When no run_id is available, the comment body has no HTML marker.""" + from forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock() + # Stats without workflow_run_id + stats = { + "stage_timestamps": {}, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "workflow_outcome": None, + "stats_outcome_reason": None, + "stats_comment_posted": False, + } + + with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira): + await post_stats_comment(TICKET_KEY, stats, "completed") + + args, _ = mock_jira.add_comment.call_args + comment_body = args[1] + assert "forge:stats:" not in comment_body + + @pytest.mark.asyncio + async def test_does_not_mark_when_post_fails(self): + """Redis marker is NOT written if the Jira post fails.""" + from 
forge.workflow.stats.poster import post_stats_comment + + mock_jira = self._make_jira_mock(side_effect=Exception("API down")) + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + mock_redis.setex = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock), + patch( + "forge.workflow.stats.idempotency.get_redis_client", + new=AsyncMock(return_value=mock_redis), + ), + ): + result = await post_stats_comment( + TICKET_KEY, self._minimal_stats(), "completed", run_id=RUN_ID + ) + + assert result is False + mock_redis.setex.assert_not_called() diff --git a/tests/unit/workflow/stats/test_poster.py b/tests/unit/workflow/stats/test_poster.py new file mode 100644 index 00000000..5189d5d8 --- /dev/null +++ b/tests/unit/workflow/stats/test_poster.py @@ -0,0 +1,435 @@ +"""Unit tests for forge.workflow.stats.poster. + +Tests verify: +- Successful comment posting returns True +- Jira API failures are handled gracefully (return False, log error) +- Retry logic with exponential backoff fires on transient failures +- Timeout handling returns False within the SLA +- JiraClient is always closed after use (resource cleanup) +- The correct comment body is passed to JiraClient.add_comment() +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.workflow.stats.poster import ( + _INITIAL_BACKOFF_SECONDS, + _MAX_ATTEMPTS, + post_stats_comment, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + +TICKET_KEY = "PROJ-42" +OUTCOME = "completed" +OUTCOME_DETAIL = None + + +def _minimal_stats(**overrides) -> dict: + base = { + "stage_timestamps": {}, + "stats_pr_urls": [], + "stats_ci_cycles": 0, + "workflow_outcome": None, + "stats_outcome_reason": None, + 
"stats_comment_posted": False, + } + base.update(overrides) + return base + + +def _make_jira_mock(side_effect=None) -> MagicMock: + """Return a mock JiraClient instance with add_comment and close as coroutines.""" + mock = MagicMock() + if side_effect is not None: + mock.add_comment = AsyncMock(side_effect=side_effect) + else: + mock.add_comment = AsyncMock(return_value=MagicMock()) + mock.close = AsyncMock() + return mock + + +# --------------------------------------------------------------------------- +# Success scenario +# --------------------------------------------------------------------------- + + +class TestPostStatsCommentSuccess: + """post_stats_comment() returns True when the comment is posted successfully.""" + + @pytest.mark.asyncio + async def test_returns_true_on_success(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira): + result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + assert result is True + + @pytest.mark.asyncio + async def test_calls_add_comment_with_correct_ticket(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira): + await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + mock_jira.add_comment.assert_called_once() + args, _ = mock_jira.add_comment.call_args + assert args[0] == TICKET_KEY + + @pytest.mark.asyncio + async def test_comment_body_contains_outcome(self): + """The comment body produced by the formatter should mention 'Completed'.""" + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira): + await post_stats_comment(TICKET_KEY, _minimal_stats(), "completed") + + args, _ = mock_jira.add_comment.call_args + comment_body = args[1] + assert "Completed" in comment_body + + @pytest.mark.asyncio + async def test_comment_body_contains_outcome_detail(self): + mock_jira = _make_jira_mock() + detail = "deployment 
succeeded" + with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira): + await post_stats_comment(TICKET_KEY, _minimal_stats(), "blocked", detail) + + args, _ = mock_jira.add_comment.call_args + comment_body = args[1] + assert detail in comment_body + + @pytest.mark.asyncio + async def test_jira_client_closed_on_success(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira): + await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + mock_jira.close.assert_called_once() + + @pytest.mark.asyncio + async def test_only_one_attempt_on_success(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira): + await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + assert mock_jira.add_comment.call_count == 1 + + +# --------------------------------------------------------------------------- +# Jira API failure scenarios +# --------------------------------------------------------------------------- + + +class TestPostStatsCommentApiFailure: + """post_stats_comment() is non-blocking: logs errors and returns False.""" + + @pytest.mark.asyncio + async def test_returns_false_on_persistent_failure(self): + mock_jira = _make_jira_mock(side_effect=Exception("API down")) + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock), + ): + result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + assert result is False + + @pytest.mark.asyncio + async def test_does_not_raise_on_api_error(self): + """post_stats_comment must never propagate exceptions to callers.""" + mock_jira = _make_jira_mock(side_effect=RuntimeError("connection refused")) + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock), + ): + # 
Should not raise + result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + assert result is False + + @pytest.mark.asyncio + async def test_jira_client_closed_on_failure(self): + """JiraClient.close() must be called even when add_comment raises.""" + mock_jira = _make_jira_mock(side_effect=Exception("API down")) + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock), + ): + await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + # close() is called once per attempt + assert mock_jira.close.call_count == _MAX_ATTEMPTS + + @pytest.mark.asyncio + async def test_http_status_error_returns_false(self): + import httpx + + mock_request = MagicMock(spec=httpx.Request) + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 500 + http_error = httpx.HTTPStatusError( + "Internal Server Error", request=mock_request, response=mock_response + ) + + mock_jira = _make_jira_mock(side_effect=http_error) + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock), + ): + result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + assert result is False + + +# --------------------------------------------------------------------------- +# Retry logic +# --------------------------------------------------------------------------- + + +class TestRetryLogic: + """Verify exponential backoff and retry behaviour.""" + + @pytest.mark.asyncio + async def test_retries_up_to_max_attempts_on_failure(self): + mock_jira = _make_jira_mock(side_effect=Exception("transient")) + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock), + ): + await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + assert 
mock_jira.add_comment.call_count == _MAX_ATTEMPTS + + @pytest.mark.asyncio + async def test_succeeds_on_second_attempt(self): + """Returns True when the first attempt fails but the second succeeds.""" + mock_jira = MagicMock() + mock_jira.add_comment = AsyncMock(side_effect=[Exception("transient"), MagicMock()]) + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock), + ): + result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + assert result is True + assert mock_jira.add_comment.call_count == 2 + + @pytest.mark.asyncio + async def test_exponential_backoff_sleep_calls(self): + """sleep() is called between retries with exponentially increasing delays.""" + mock_jira = _make_jira_mock(side_effect=Exception("transient")) + mock_sleep = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch("forge.workflow.stats.poster.asyncio.sleep", mock_sleep), + ): + await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + # With _MAX_ATTEMPTS=3 there are 2 sleeps (after attempt 1 and 2) + expected_sleep_count = _MAX_ATTEMPTS - 1 + assert mock_sleep.call_count == expected_sleep_count + + # Verify delays grow (first < second for default backoff) + if expected_sleep_count >= 2: + delays = [c.args[0] for c in mock_sleep.call_args_list] + assert delays[1] > delays[0], "Second backoff should be larger than first" + + @pytest.mark.asyncio + async def test_initial_backoff_value(self): + """First retry uses _INITIAL_BACKOFF_SECONDS as the wait duration.""" + mock_jira = _make_jira_mock( + side_effect=[Exception("fail"), Exception("fail"), Exception("fail")] + ) + mock_sleep = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch("forge.workflow.stats.poster.asyncio.sleep", mock_sleep), + ): + await 
post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + first_delay = mock_sleep.call_args_list[0].args[0] + assert first_delay == _INITIAL_BACKOFF_SECONDS + + @pytest.mark.asyncio + async def test_jira_client_instantiated_per_attempt(self): + """A fresh JiraClient is created for each attempt.""" + mock_jira = _make_jira_mock(side_effect=Exception("transient")) + mock_cls = MagicMock(return_value=mock_jira) + + with ( + patch("forge.workflow.stats.poster.JiraClient", mock_cls), + patch("forge.workflow.stats.poster.asyncio.sleep", new_callable=AsyncMock), + ): + await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + assert mock_cls.call_count == _MAX_ATTEMPTS + + @pytest.mark.asyncio + async def test_no_sleep_after_last_attempt(self): + """No sleep is issued after the final (exhausted) attempt.""" + mock_jira = _make_jira_mock(side_effect=Exception("transient")) + mock_sleep = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch("forge.workflow.stats.poster.asyncio.sleep", mock_sleep), + ): + await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + # sleeps = attempts - 1 + assert mock_sleep.call_count == _MAX_ATTEMPTS - 1 + + +# --------------------------------------------------------------------------- +# Timeout scenario +# --------------------------------------------------------------------------- + + +class TestTimeoutHandling: + """post_stats_comment() respects the 5-minute SLA timeout.""" + + @pytest.mark.asyncio + async def test_returns_false_on_timeout(self): + async def slow_add_comment(*_args, **_kwargs): + await asyncio.sleep(999) + + mock_jira = MagicMock() + mock_jira.add_comment = slow_add_comment + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.poster._OPERATION_TIMEOUT_SECONDS", + 0.05, # Use a very short timeout for the test + ), + ): + result = await 
post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + assert result is False + + @pytest.mark.asyncio + async def test_does_not_raise_on_timeout(self): + """TimeoutError must be swallowed and False returned.""" + + async def slow_add_comment(*_args, **_kwargs): + await asyncio.sleep(999) + + mock_jira = MagicMock() + mock_jira.add_comment = slow_add_comment + mock_jira.close = AsyncMock() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.poster._OPERATION_TIMEOUT_SECONDS", + 0.05, + ), + ): + # Should not raise TimeoutError + result = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME) + + assert result is False + + +# --------------------------------------------------------------------------- +# Comment content +# --------------------------------------------------------------------------- + + +class TestCommentContent: + """Verify the formatted comment body is constructed from stats correctly.""" + + @pytest.mark.asyncio + async def test_comment_includes_workflow_statistics_header(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira): + await post_stats_comment(TICKET_KEY, _minimal_stats(), "completed") + + args, _ = mock_jira.add_comment.call_args + assert "Workflow Statistics" in args[1] + + @pytest.mark.asyncio + async def test_comment_includes_ci_cycles(self): + stats = _minimal_stats(stats_ci_cycles=3) + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira): + await post_stats_comment(TICKET_KEY, stats, "completed") + + args, _ = mock_jira.add_comment.call_args + assert "3" in args[1] + + @pytest.mark.asyncio + async def test_comment_failed_outcome_with_detail(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira): + await post_stats_comment(TICKET_KEY, _minimal_stats(), "failed", "disk full") 
+ + args, _ = mock_jira.add_comment.call_args + body = args[1] + assert "Failed" in body + assert "disk full" in body + + @pytest.mark.asyncio + async def test_format_stats_summary_called_with_correct_args(self): + """Ensure the formatter is invoked with the right stats, outcome, and detail.""" + mock_jira = _make_jira_mock() + stats = _minimal_stats(stats_ci_cycles=1) + detail = "some detail" + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + patch( + "forge.workflow.stats.poster.format_stats_summary", + wraps=__import__( + "forge.workflow.stats.formatter", fromlist=["format_stats_summary"] + ).format_stats_summary, + ) as mock_fmt, + ): + await post_stats_comment(TICKET_KEY, stats, "blocked", detail) + + mock_fmt.assert_called_once() + call_kwargs = mock_fmt.call_args.kwargs + # Token-based threshold is passed when dollar threshold is not configured + assert call_kwargs.get("token_threshold") == 1_000_000 + assert call_kwargs.get("dollar_threshold") is None + + @pytest.mark.asyncio + async def test_dollar_threshold_passed_to_formatter_when_configured(self): + """When stats_alert_threshold_cost is set, it is passed to the formatter.""" + from unittest.mock import patch as _patch + + mock_jira = _make_jira_mock() + stats = _minimal_stats() + + with ( + patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira), + _patch( + "forge.workflow.stats.poster.get_settings", + return_value=MagicMock( + stats_alert_enabled=True, + stats_alert_threshold_cost=5.0, + stats_alert_threshold_tokens=1_000_000, + llm_pricing={"claude-sonnet-4": {"input": 3.0, "output": 15.0}}, + ), + ), + patch( + "forge.workflow.stats.poster.format_stats_summary", + wraps=__import__( + "forge.workflow.stats.formatter", fromlist=["format_stats_summary"] + ).format_stats_summary, + ) as mock_fmt, + ): + await post_stats_comment(TICKET_KEY, stats, "completed") + + mock_fmt.assert_called_once() + call_kwargs = mock_fmt.call_args.kwargs + assert 
call_kwargs.get("dollar_threshold") == 5.0 + assert call_kwargs.get("token_threshold") is None diff --git a/tests/unit/workflow/stats/test_report_ticket.py b/tests/unit/workflow/stats/test_report_ticket.py new file mode 100644 index 00000000..01ab55c0 --- /dev/null +++ b/tests/unit/workflow/stats/test_report_ticket.py @@ -0,0 +1,419 @@ +"""Unit tests for forge.workflow.stats.report_ticket. + +Tests verify: +- resolve_report_ticket() uses the correct JQL and returns the first match key +- resolve_report_ticket() returns None when no issues are found +- create_report_ticket() calls create_task() with the correct args +- update_report_ticket() calls update_description() with the correct args +- ensure_report_ticket() creates a ticket when none exists +- ensure_report_ticket() updates an existing ticket (no duplicate) +- ensure_report_ticket() is idempotent — second call updates, not duplicates +- JiraClient is always closed after each operation +""" + +from __future__ import annotations + +from datetime import date +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.workflow.stats.report_ticket import ( + REPORT_LABELS, + _report_jql, + _report_summary, + create_report_ticket, + ensure_report_ticket, + resolve_report_ticket, + update_report_ticket, +) + +# --------------------------------------------------------------------------- +# Fixtures / helpers +# --------------------------------------------------------------------------- + +PROJECT = "PROJ" +WEEK_START = date(2024, 1, 8) +REPORT_MARKDOWN = "## Weekly Report\n\nAll good." 
+TICKET_KEY = "PROJ-42" + + +def _make_jira_mock( + search_return: list | None = None, + create_task_return: str = TICKET_KEY, +) -> MagicMock: + """Return a mock JiraClient with async search_issues, create_task, update_description.""" + mock = MagicMock() + mock.search_issues = AsyncMock(return_value=search_return or []) + mock.create_task = AsyncMock(return_value=create_task_return) + mock.update_description = AsyncMock(return_value=None) + mock.close = AsyncMock() + return mock + + +def _make_issue(key: str = TICKET_KEY) -> MagicMock: + issue = MagicMock() + issue.key = key + return issue + + +# --------------------------------------------------------------------------- +# _report_summary +# --------------------------------------------------------------------------- + + +class TestReportSummary: + def test_format(self): + summary = _report_summary("PROJ", date(2024, 1, 8)) + assert summary == "Forge Weekly Report - PROJ - Week of 2024-01-08" + + def test_different_project(self): + summary = _report_summary("MYPROJ", date(2024, 6, 3)) + assert summary == "Forge Weekly Report - MYPROJ - Week of 2024-06-03" + + def test_contains_week_of_fragment(self): + summary = _report_summary("X", date(2024, 12, 30)) + assert "Week of 2024-12-30" in summary + + +# --------------------------------------------------------------------------- +# _report_jql +# --------------------------------------------------------------------------- + + +class TestReportJql: + def test_contains_project(self): + jql = _report_jql("PROJ", date(2024, 1, 8)) + assert '"PROJ"' in jql + + def test_contains_label(self): + jql = _report_jql("PROJ", date(2024, 1, 8)) + assert '"forge:weekly-report"' in jql + + def test_contains_week_of(self): + jql = _report_jql("PROJ", date(2024, 1, 8)) + assert "Week of 2024-01-08" in jql + + def test_full_jql(self): + jql = _report_jql("PROJ", date(2024, 1, 8)) + assert 'project = "PROJ"' in jql + assert 'labels = "forge:weekly-report"' in jql + assert 'summary ~ "Week 
of 2024-01-08"' in jql + + +# --------------------------------------------------------------------------- +# resolve_report_ticket +# --------------------------------------------------------------------------- + + +class TestResolveReportTicket: + @pytest.mark.asyncio + async def test_returns_none_when_no_issues(self): + mock_jira = _make_jira_mock(search_return=[]) + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + result = await resolve_report_ticket(PROJECT, WEEK_START) + + assert result is None + + @pytest.mark.asyncio + async def test_returns_first_issue_key(self): + issues = [_make_issue("PROJ-42"), _make_issue("PROJ-43")] + mock_jira = _make_jira_mock(search_return=issues) + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + result = await resolve_report_ticket(PROJECT, WEEK_START) + + assert result == "PROJ-42" + + @pytest.mark.asyncio + async def test_calls_search_issues_with_correct_jql(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + await resolve_report_ticket(PROJECT, WEEK_START) + + mock_jira.search_issues.assert_called_once() + call_kwargs = mock_jira.search_issues.call_args + jql = call_kwargs[1].get("jql") or call_kwargs[0][0] + assert "PROJ" in jql + assert "forge:weekly-report" in jql + assert "2024-01-08" in jql + + @pytest.mark.asyncio + async def test_limits_results(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + await resolve_report_ticket(PROJECT, WEEK_START) + + _, kwargs = mock_jira.search_issues.call_args + assert kwargs.get("max_results", 50) <= 10 + + @pytest.mark.asyncio + async def test_closes_client_on_success(self): + mock_jira = _make_jira_mock(search_return=[_make_issue()]) + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + await resolve_report_ticket(PROJECT, 
WEEK_START) + + mock_jira.close.assert_called_once() + + @pytest.mark.asyncio + async def test_closes_client_on_empty_result(self): + mock_jira = _make_jira_mock(search_return=[]) + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + await resolve_report_ticket(PROJECT, WEEK_START) + + mock_jira.close.assert_called_once() + + @pytest.mark.asyncio + async def test_closes_client_on_error(self): + mock_jira = _make_jira_mock() + mock_jira.search_issues = AsyncMock(side_effect=RuntimeError("network error")) + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + with pytest.raises(RuntimeError): + await resolve_report_ticket(PROJECT, WEEK_START) + + mock_jira.close.assert_called_once() + + +# --------------------------------------------------------------------------- +# create_report_ticket +# --------------------------------------------------------------------------- + + +class TestCreateReportTicket: + @pytest.mark.asyncio + async def test_returns_ticket_key(self): + mock_jira = _make_jira_mock(create_task_return="PROJ-42") + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + result = await create_report_ticket(PROJECT, WEEK_START, REPORT_MARKDOWN) + + assert result == "PROJ-42" + + @pytest.mark.asyncio + async def test_calls_create_task_with_correct_project(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + await create_report_ticket(PROJECT, WEEK_START, REPORT_MARKDOWN) + + _, kwargs = mock_jira.create_task.call_args + assert kwargs.get("project_key") == PROJECT + + @pytest.mark.asyncio + async def test_calls_create_task_with_correct_summary(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + await create_report_ticket(PROJECT, WEEK_START, REPORT_MARKDOWN) + + _, kwargs = mock_jira.create_task.call_args + 
expected_summary = "Forge Weekly Report - PROJ - Week of 2024-01-08" + assert kwargs.get("summary") == expected_summary + + @pytest.mark.asyncio + async def test_calls_create_task_with_correct_description(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + await create_report_ticket(PROJECT, WEEK_START, REPORT_MARKDOWN) + + _, kwargs = mock_jira.create_task.call_args + assert kwargs.get("description") == REPORT_MARKDOWN + + @pytest.mark.asyncio + async def test_calls_create_task_with_correct_labels(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + await create_report_ticket(PROJECT, WEEK_START, REPORT_MARKDOWN) + + _, kwargs = mock_jira.create_task.call_args + labels = kwargs.get("labels") or [] + assert "forge:weekly-report" in labels + assert "forge:generated" in labels + + @pytest.mark.asyncio + async def test_closes_client_on_success(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + await create_report_ticket(PROJECT, WEEK_START, REPORT_MARKDOWN) + + mock_jira.close.assert_called_once() + + @pytest.mark.asyncio + async def test_closes_client_on_error(self): + mock_jira = _make_jira_mock() + mock_jira.create_task = AsyncMock(side_effect=RuntimeError("API error")) + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + with pytest.raises(RuntimeError): + await create_report_ticket(PROJECT, WEEK_START, REPORT_MARKDOWN) + + mock_jira.close.assert_called_once() + + +# --------------------------------------------------------------------------- +# update_report_ticket +# --------------------------------------------------------------------------- + + +class TestUpdateReportTicket: + @pytest.mark.asyncio + async def test_calls_update_description_with_correct_key(self): + mock_jira = _make_jira_mock() + with 
patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + await update_report_ticket(TICKET_KEY, REPORT_MARKDOWN) + + mock_jira.update_description.assert_called_once_with(TICKET_KEY, REPORT_MARKDOWN) + + @pytest.mark.asyncio + async def test_calls_update_description_with_correct_content(self): + new_content = "## Updated Report\n\nNew data." + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + await update_report_ticket(TICKET_KEY, new_content) + + mock_jira.update_description.assert_called_once_with(TICKET_KEY, new_content) + + @pytest.mark.asyncio + async def test_does_not_call_create_task(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + await update_report_ticket(TICKET_KEY, REPORT_MARKDOWN) + + mock_jira.create_task.assert_not_called() + + @pytest.mark.asyncio + async def test_returns_none(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + result = await update_report_ticket(TICKET_KEY, REPORT_MARKDOWN) + + assert result is None + + @pytest.mark.asyncio + async def test_closes_client_on_success(self): + mock_jira = _make_jira_mock() + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + await update_report_ticket(TICKET_KEY, REPORT_MARKDOWN) + + mock_jira.close.assert_called_once() + + @pytest.mark.asyncio + async def test_closes_client_on_error(self): + mock_jira = _make_jira_mock() + mock_jira.update_description = AsyncMock(side_effect=RuntimeError("fail")) + with patch("forge.workflow.stats.report_ticket.JiraClient", return_value=mock_jira): + with pytest.raises(RuntimeError): + await update_report_ticket(TICKET_KEY, REPORT_MARKDOWN) + + mock_jira.close.assert_called_once() + + +# --------------------------------------------------------------------------- +# ensure_report_ticket +# 
class TestEnsureReportTicket:
    """``ensure_report_ticket`` with resolve/create/update replaced by async mocks."""

    @pytest.mark.asyncio
    async def test_creates_ticket_when_none_exists(self):
        """When resolve returns None, create_report_ticket should be called."""
        resolve_mock = AsyncMock(return_value=None)
        create_mock = AsyncMock(return_value=TICKET_KEY)
        update_mock = AsyncMock()
        with (
            patch("forge.workflow.stats.report_ticket.resolve_report_ticket", new=resolve_mock),
            patch("forge.workflow.stats.report_ticket.create_report_ticket", new=create_mock),
            patch("forge.workflow.stats.report_ticket.update_report_ticket", new=update_mock),
        ):
            returned_key = await ensure_report_ticket(PROJECT, WEEK_START, REPORT_MARKDOWN)

        assert returned_key == TICKET_KEY
        resolve_mock.assert_called_once_with(PROJECT, WEEK_START)
        create_mock.assert_called_once_with(PROJECT, WEEK_START, REPORT_MARKDOWN)
        # update is NOT called when creating (create already sets description)
        update_mock.assert_not_called()

    @pytest.mark.asyncio
    async def test_updates_existing_ticket(self):
        """When resolve returns a key, update_report_ticket should be called."""
        resolve_mock = AsyncMock(return_value=TICKET_KEY)
        create_mock = AsyncMock(return_value="PROJ-99")
        update_mock = AsyncMock()
        with (
            patch("forge.workflow.stats.report_ticket.resolve_report_ticket", new=resolve_mock),
            patch("forge.workflow.stats.report_ticket.create_report_ticket", new=create_mock),
            patch("forge.workflow.stats.report_ticket.update_report_ticket", new=update_mock),
        ):
            returned_key = await ensure_report_ticket(PROJECT, WEEK_START, REPORT_MARKDOWN)

        assert returned_key == TICKET_KEY
        resolve_mock.assert_called_once_with(PROJECT, WEEK_START)
        create_mock.assert_not_called()
        update_mock.assert_called_once_with(TICKET_KEY, REPORT_MARKDOWN)

    @pytest.mark.asyncio
    async def test_idempotent_on_existing_ticket(self):
        """Calling ensure_report_ticket twice should yield the same key (no duplicate)."""
        resolve_mock = AsyncMock(return_value=TICKET_KEY)
        create_mock = AsyncMock(return_value="PROJ-99")
        update_mock = AsyncMock()
        with (
            patch("forge.workflow.stats.report_ticket.resolve_report_ticket", new=resolve_mock),
            patch("forge.workflow.stats.report_ticket.create_report_ticket", new=create_mock),
            patch("forge.workflow.stats.report_ticket.update_report_ticket", new=update_mock),
        ):
            first_key = await ensure_report_ticket(PROJECT, WEEK_START, REPORT_MARKDOWN)
            second_key = await ensure_report_ticket(PROJECT, WEEK_START, REPORT_MARKDOWN)

        assert first_key == second_key == TICKET_KEY
        create_mock.assert_not_called()

    @pytest.mark.asyncio
    async def test_returns_created_key(self):
        """The key produced by create_report_ticket is returned to the caller."""
        new_key = "PROJ-100"
        resolve_mock = AsyncMock(return_value=None)
        create_mock = AsyncMock(return_value=new_key)
        update_mock = AsyncMock()
        with (
            patch("forge.workflow.stats.report_ticket.resolve_report_ticket", new=resolve_mock),
            patch("forge.workflow.stats.report_ticket.create_report_ticket", new=create_mock),
            patch("forge.workflow.stats.report_ticket.update_report_ticket", new=update_mock),
        ):
            returned_key = await ensure_report_ticket(PROJECT, WEEK_START, REPORT_MARKDOWN)

        assert returned_key == new_key
class FakeRedis:
    """Dict-backed stand-in for the async Redis client; only exists() and setex()."""

    def __init__(self):
        # key -> value; expiry is not modelled.
        self._store: dict[str, str] = {}

    async def exists(self, key: str) -> int:
        """Return 1 when *key* has been stored, otherwise 0."""
        return int(key in self._store)

    async def setex(self, key: str, _ttl: int, value: str) -> None:
        """Store *value* under *key*; the TTL argument is accepted but ignored."""
        self._store[key] = value
@pytest.mark.asyncio
async def test_first_call_posts_comment_and_marks_redis():
    """First invocation posts the comment and records the marker in Redis."""
    from forge.workflow.stats.poster import post_stats_comment

    redis_stub = FakeRedis()
    jira = _make_jira_mock()
    get_client = AsyncMock(return_value=redis_stub)

    with (
        patch("forge.workflow.stats.poster.JiraClient", return_value=jira),
        patch("forge.workflow.stats.idempotency.get_redis_client", new=get_client),
    ):
        posted = await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

    assert posted is True
    jira.add_comment.assert_called_once()

    # The marker key checked here is forge:stats:posted:<ticket_key>:<run_id>;
    # it must now be present in the fake Redis.
    marker_key = f"forge:stats:posted:{TICKET_KEY}:{RUN_ID}"
    assert await redis_stub.exists(marker_key) == 1
@pytest.mark.asyncio
async def test_comment_body_contains_unique_marker():
    """The posted comment embeds the hidden HTML marker for the run_id."""
    from forge.workflow.stats.poster import post_stats_comment

    fake_redis = FakeRedis()
    mock_jira = _make_jira_mock()

    with (
        patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira),
        patch(
            "forge.workflow.stats.idempotency.get_redis_client",
            new=AsyncMock(return_value=fake_redis),
        ),
    ):
        await post_stats_comment(TICKET_KEY, _minimal_stats(), OUTCOME)

    # Fail with a clear message if nothing was posted, rather than on the
    # tuple unpacking of a missing call_args below.
    mock_jira.add_comment.assert_called_once()
    args, _ = mock_jira.add_comment.call_args
    comment_body = args[1]  # second positional argument of add_comment
    # NOTE(review): the marker literal in the previous assertion was empty, and
    # an empty string is a substring of every string, so the check could never
    # fail. The exact marker format is not visible from this test; asserting on
    # the run id is the strongest check that can be grounded here. Restore the
    # full marker literal once confirmed against poster.py.
    assert RUN_ID in comment_body
same ticket — should also post + with patch("forge.workflow.stats.poster.JiraClient", return_value=mock_jira_2): + r2 = await post_stats_comment(TICKET_KEY, _minimal_stats(run_id_2), OUTCOME) + + assert r1 is True + assert r2 is True + mock_jira_1.add_comment.assert_called_once() + mock_jira_2.add_comment.assert_called_once() diff --git a/tests/unit/workflow/stats/test_weekly_formatter.py b/tests/unit/workflow/stats/test_weekly_formatter.py new file mode 100644 index 00000000..6116af62 --- /dev/null +++ b/tests/unit/workflow/stats/test_weekly_formatter.py @@ -0,0 +1,794 @@ +"""Unit tests for forge.workflow.stats.weekly_formatter. + +Coverage: +- _format_duration: edge cases (0s, minutes, hours, > 24h) +- _format_token_count: abbreviation thresholds (raw, k, M) +- _format_bottleneck_section: all fields present/absent +- format_weekly_report_cli: structure, sections, empty lists, feature rollups +- format_weekly_report_markdown: valid markdown structure, tables, rollups +- format_weekly_report_json: valid parseable JSON, all fields, rollups +""" + +from __future__ import annotations + +import json + +import pytest + +from forge.workflow.stats.weekly_formatter import ( + _format_bottleneck_section, + _format_duration, + _format_token_count, + format_weekly_report_cli, + format_weekly_report_json, + format_weekly_report_markdown, +) +from forge.workflow.stats.weekly_report import ( + UNASSIGNED_FEATURE_KEY, + BottleneckAnalysis, + FeatureRollup, + TicketSummary, + WeeklyReportData, +) + +# --------------------------------------------------------------------------- +# Fixtures / helpers +# --------------------------------------------------------------------------- + + +def _make_ticket( + ticket_key: str = "AISOS-1", + ticket_type: str = "Feature", + status: str = "completed", + duration_seconds: float | None = 3600.0, + input_tokens: int = 1000, + output_tokens: int = 500, + ci_cycles: int = 0, + outcome: str | None = "Completed", +) -> TicketSummary: + return 
def _make_report(
    project: str = "AISOS",
    period_days: int = 7,
    completed: list[TicketSummary] | None = None,
    in_progress: list[TicketSummary] | None = None,
    blocked: list[TicketSummary] | None = None,
    tokens_by_stage: dict | None = None,
    avg_cycle_time: float | None = None,
    bottlenecks: BottleneckAnalysis | None = None,
    feature_rollups: dict | None = None,
) -> WeeklyReportData:
    """Build a ``WeeklyReportData`` for formatter tests.

    Omitted ticket lists and dicts default to empty, and ``bottlenecks``
    defaults to a fresh ``BottleneckAnalysis()``. The reporting window is
    fixed at 2024-06-08 .. 2024-06-15 (UTC). Token totals are summed over
    the completed, in-progress and blocked tickets combined.
    """
    done = completed or []
    active = in_progress or []
    stuck = blocked or []
    combined = done + active + stuck
    return WeeklyReportData(
        project=project,
        period_days=period_days,
        report_start="2024-06-08T00:00:00+00:00",
        report_end="2024-06-15T00:00:00+00:00",
        completed_tickets=done,
        in_progress_tickets=active,
        blocked_tickets=stuck,
        total_input_tokens=sum(ticket.input_tokens for ticket in combined),
        total_output_tokens=sum(ticket.output_tokens for ticket in combined),
        tokens_by_stage=tokens_by_stage or {},
        avg_cycle_time=avg_cycle_time,
        bottlenecks=bottlenecks or BottleneckAnalysis(),
        all_tickets=combined,
        feature_rollups=feature_rollups or {},
    )
+ + def test_minutes_and_seconds(self) -> None: + assert _format_duration(90) == "1m 30s" + + def test_minutes_only(self) -> None: + assert _format_duration(120) == "2m 0s" + + def test_exactly_one_hour(self) -> None: + assert _format_duration(3600) == "1h 0m" + + def test_hours_and_minutes(self) -> None: + assert _format_duration(3662) == "1h 1m" + + def test_large_hours_and_minutes(self) -> None: + assert _format_duration(13320) == "3h 42m" + + def test_over_24_hours(self) -> None: + # 25 hours + 1 minute + assert _format_duration(90061) == "25h 1m" + + def test_fractional_seconds_truncated(self) -> None: + # float with fractional part — truncated (not rounded) + assert _format_duration(61.9) == "1m 1s" + + def test_exactly_one_hour_one_minute(self) -> None: + assert _format_duration(3660) == "1h 1m" + + def test_seconds_only_large(self) -> None: + assert _format_duration(59) == "59s" + + +# --------------------------------------------------------------------------- +# Tests: _format_token_count +# --------------------------------------------------------------------------- + + +class TestFormatTokenCount: + def test_zero(self) -> None: + assert _format_token_count(0) == "0" + + def test_below_1k(self) -> None: + assert _format_token_count(999) == "999" + + def test_exactly_1k(self) -> None: + assert _format_token_count(1000) == "1k" + + def test_1500_is_1_point_5k(self) -> None: + assert _format_token_count(1500) == "1.5k" + + def test_31k(self) -> None: + assert _format_token_count(31000) == "31k" + + def test_999k(self) -> None: + assert _format_token_count(999000) == "999k" + + def test_exactly_1m(self) -> None: + assert _format_token_count(1_000_000) == "1M" + + def test_1_5m(self) -> None: + assert _format_token_count(1_500_000) == "1.5M" + + def test_10m(self) -> None: + assert _format_token_count(10_000_000) == "10M" + + def test_2500_is_2_point_5k(self) -> None: + assert _format_token_count(2500) == "2.5k" + + def test_500(self) -> None: + assert 
_format_token_count(500) == "500" + + def test_round_thousands(self) -> None: + assert _format_token_count(5000) == "5k" + + def test_2m_exact(self) -> None: + assert _format_token_count(2_000_000) == "2M" + + +# --------------------------------------------------------------------------- +# Tests: _format_bottleneck_section +# --------------------------------------------------------------------------- + + +class TestFormatBottleneckSection: + def test_empty_bottlenecks(self) -> None: + b = BottleneckAnalysis() + result = _format_bottleneck_section(b) + assert "Tickets Analysed : 0" in result + assert "Slowest Stage" in result + assert "CI Fix Rate : 0%" in result + assert "Most Revised" in result + + def test_with_slowest_stage(self) -> None: + b = BottleneckAnalysis( + avg_stage_durations={"prd": 3600.0}, + slowest_stage="prd", + total_tickets_analyzed=5, + ) + result = _format_bottleneck_section(b) + assert "PRD" in result + assert "1h 0m" in result + assert "Tickets Analysed : 5" in result + + def test_ci_fix_rate_percentage(self) -> None: + b = BottleneckAnalysis(ci_fix_rate=0.4, total_tickets_analyzed=10) + result = _format_bottleneck_section(b) + assert "CI Fix Rate : 40%" in result + + def test_ci_fix_rate_zero_percent(self) -> None: + b = BottleneckAnalysis(ci_fix_rate=0.0) + result = _format_bottleneck_section(b) + assert "CI Fix Rate : 0%" in result + + def test_ci_fix_rate_100_percent(self) -> None: + b = BottleneckAnalysis(ci_fix_rate=1.0, total_tickets_analyzed=3) + result = _format_bottleneck_section(b) + assert "CI Fix Rate : 100%" in result + + def test_most_revised_stages_top_3(self) -> None: + b = BottleneckAnalysis( + most_revised_stages=["prd", "spec", "implementation", "ci"], + ) + result = _format_bottleneck_section(b) + assert "PRD" in result + assert "Spec" in result + assert "Implementation" in result + # 4th stage should NOT appear (top 3 only) + assert "CI" not in result.split("Most Revised")[1].split("\n")[0] + + def 
test_most_revised_empty(self) -> None: + b = BottleneckAnalysis(most_revised_stages=[]) + result = _format_bottleneck_section(b) + assert "Most Revised" in result + + def test_avg_stage_durations_shown(self) -> None: + b = BottleneckAnalysis( + avg_stage_durations={"prd": 120.0, "spec": 240.0}, + ) + result = _format_bottleneck_section(b) + assert "Stage Avg Durations" in result + assert "PRD" in result + assert "Spec" in result + + def test_no_avg_durations_no_subsection(self) -> None: + b = BottleneckAnalysis(avg_stage_durations={}) + result = _format_bottleneck_section(b) + assert "Stage Avg Durations" not in result + + def test_unknown_stage_key_title_cased(self) -> None: + b = BottleneckAnalysis( + avg_stage_durations={"custom_stage": 60.0}, + slowest_stage="custom_stage", + ) + result = _format_bottleneck_section(b) + assert "Custom_Stage" in result + + +# --------------------------------------------------------------------------- +# Tests: format_weekly_report_cli +# --------------------------------------------------------------------------- + + +class TestFormatWeeklyReportCli: + def test_returns_string(self) -> None: + report = _make_report() + result = format_weekly_report_cli(report) + assert isinstance(result, str) + assert len(result) > 0 + + def test_header_contains_project(self) -> None: + report = _make_report(project="MYPROJ") + result = format_weekly_report_cli(report) + assert "MYPROJ" in result + + def test_period_in_header(self) -> None: + report = _make_report(period_days=14) + result = format_weekly_report_cli(report) + assert "14" in result + + def test_date_range_in_header(self) -> None: + report = _make_report() + result = format_weekly_report_cli(report) + assert "2024-06-08" in result + assert "2024-06-15" in result + + def test_summary_section_present(self) -> None: + report = _make_report() + result = format_weekly_report_cli(report) + assert "Summary" in result + assert "Total Tickets" in result + assert "Avg Cycle Time" in result + + 
def test_ticket_counts_match(self) -> None: + t1 = _make_ticket("AISOS-1", status="completed") + t2 = _make_ticket("AISOS-2", status="in_progress") + t3 = _make_ticket("AISOS-3", status="blocked") + report = _make_report(completed=[t1], in_progress=[t2], blocked=[t3]) + result = format_weekly_report_cli(report) + assert "Completed : 1" in result + assert "In Progress : 1" in result + assert "Blocked : 1" in result + assert "Total Tickets : 3" in result + + def test_avg_cycle_time_shown(self) -> None: + report = _make_report(avg_cycle_time=3600.0) + result = format_weekly_report_cli(report) + assert "1h 0m" in result + + def test_avg_cycle_time_none_shows_dash(self) -> None: + report = _make_report(avg_cycle_time=None) + result = format_weekly_report_cli(report) + assert "Avg Cycle Time" in result + assert "\u2014" in result # em-dash + + def test_token_counts_shown(self) -> None: + t1 = _make_ticket(input_tokens=31000, output_tokens=5000) + report = _make_report(completed=[t1]) + result = format_weekly_report_cli(report) + assert "Total Tokens" in result + + def test_completed_tickets_section(self) -> None: + t1 = _make_ticket("AISOS-100") + report = _make_report(completed=[t1]) + result = format_weekly_report_cli(report) + assert "Completed Tickets" in result + assert "AISOS-100" in result + + def test_empty_completed_shows_none(self) -> None: + report = _make_report(completed=[]) + result = format_weekly_report_cli(report) + assert "(none)" in result + + def test_in_progress_section(self) -> None: + t = _make_ticket("AISOS-200", status="in_progress") + report = _make_report(in_progress=[t]) + result = format_weekly_report_cli(report) + assert "In-Progress Tickets" in result + assert "AISOS-200" in result + + def test_blocked_section(self) -> None: + t = _make_ticket("AISOS-300", status="blocked") + report = _make_report(blocked=[t]) + result = format_weekly_report_cli(report) + assert "Blocked Tickets" in result + assert "AISOS-300" in result + + def 
def test_total_tokens_abbreviated(self) -> None:
    """Token totals are rendered abbreviated, never raw or as ``1000k``."""
    t = _make_ticket(input_tokens=500_000, output_tokens=500_000)
    report = _make_report(completed=[t])
    result = format_weekly_report_cli(report)
    # The previous single assertion (`"1M" in result or "1000k" not in result`)
    # held whenever either side was true, so completely unabbreviated output
    # such as "1000000" still passed. Each expectation is now checked on its own.
    assert "1000k" not in result
    assert "1000000" not in result
    # NOTE(review): accepts either the combined total ("1M") or the per-direction
    # figures ("500k"), since the CLI layout is not visible from this test;
    # tighten to the exact form once confirmed against the formatter.
    assert "1M" in result or "500k" in result
_make_report(feature_rollups={UNASSIGNED_FEATURE_KEY: rollup}) + result = format_weekly_report_cli(report) + assert UNASSIGNED_FEATURE_KEY in result + + +# --------------------------------------------------------------------------- +# Tests: format_weekly_report_markdown +# --------------------------------------------------------------------------- + + +class TestFormatWeeklyReportMarkdown: + def test_returns_string(self) -> None: + report = _make_report() + result = format_weekly_report_markdown(report) + assert isinstance(result, str) + + def test_h1_header_contains_project(self) -> None: + report = _make_report(project="TESTPROJ") + result = format_weekly_report_markdown(report) + assert "# Weekly Report" in result + assert "TESTPROJ" in result + + def test_h2_sections_present(self) -> None: + report = _make_report() + result = format_weekly_report_markdown(report) + assert "## Summary" in result + assert "## Completed Tickets" in result + assert "## In-Progress Tickets" in result + assert "## Blocked Tickets" in result + assert "## Token Usage by Stage" in result + assert "## Bottleneck Analysis" in result + + def test_summary_table_has_rows(self) -> None: + t = _make_ticket() + report = _make_report(completed=[t]) + result = format_weekly_report_markdown(report) + assert "| Total Tickets |" in result + assert "| Completed |" in result + + def test_completed_tickets_table(self) -> None: + t = _make_ticket("AISOS-1") + report = _make_report(completed=[t]) + result = format_weekly_report_markdown(report) + assert "| Ticket | Type | Duration | Tokens |" in result + assert "| AISOS-1 |" in result + + def test_empty_completed_shows_italic_none(self) -> None: + report = _make_report(completed=[]) + result = format_weekly_report_markdown(report) + assert "_No completed tickets this period._" in result + + def test_empty_in_progress_shows_italic_none(self) -> None: + report = _make_report(in_progress=[]) + result = format_weekly_report_markdown(report) + assert "_No 
in-progress tickets this period._" in result + + def test_empty_blocked_shows_italic_none(self) -> None: + report = _make_report(blocked=[]) + result = format_weekly_report_markdown(report) + assert "_No blocked tickets this period._" in result + + def test_token_usage_table_with_data(self) -> None: + report = _make_report(tokens_by_stage={"prd": (1000, 500), "spec": (2000, 800)}) + result = format_weekly_report_markdown(report) + assert "| Stage | Input | Output | Total |" in result + assert "| PRD |" in result + assert "| Spec |" in result + + def test_no_token_data_shows_message(self) -> None: + report = _make_report(tokens_by_stage={}) + result = format_weekly_report_markdown(report) + assert "_No stage token data available._" in result + + def test_bottleneck_table_present(self) -> None: + report = _make_report( + bottlenecks=BottleneckAnalysis( + total_tickets_analyzed=5, + ci_fix_rate=0.4, + slowest_stage="prd", + avg_stage_durations={"prd": 3600.0}, + ) + ) + result = format_weekly_report_markdown(report) + assert "| Tickets Analysed |" in result + assert "| CI Fix Rate |" in result + assert "40%" in result + + def test_feature_rollup_section_included(self) -> None: + rollup = FeatureRollup( + feature_key="AISOS-10", + feature_summary="My Feature", + linked_tickets=[_make_ticket("AISOS-11")], + total_input_tokens=5000, + total_output_tokens=2000, + tickets_completed=1, + tickets_in_progress=0, + completion_percentage=100.0, + ) + report = _make_report(feature_rollups={"AISOS-10": rollup}) + result = format_weekly_report_markdown(report) + assert "## Feature Rollup" in result + assert "| AISOS-10 |" in result + assert "My Feature" in result + + def test_no_feature_rollup_section_when_empty(self) -> None: + report = _make_report(feature_rollups={}) + result = format_weekly_report_markdown(report) + assert "## Feature Rollup" not in result + + def test_markdown_table_separator_present(self) -> None: + report = _make_report() + result = 
format_weekly_report_markdown(report) + # All tables should have separator rows with |---| + assert "|--------|-------|" in result + + def test_avg_cycle_time_in_summary(self) -> None: + report = _make_report(avg_cycle_time=7200.0) + result = format_weekly_report_markdown(report) + assert "2h 0m" in result + + def test_stage_avg_durations_subsection(self) -> None: + b = BottleneckAnalysis( + avg_stage_durations={"prd": 3600.0}, + slowest_stage="prd", + ) + report = _make_report(bottlenecks=b) + result = format_weekly_report_markdown(report) + assert "### Stage Average Durations" in result + assert "| PRD |" in result + + def test_period_days_in_header(self) -> None: + report = _make_report(period_days=30) + result = format_weekly_report_markdown(report) + assert "Last 30 Days" in result + + def test_date_range_present(self) -> None: + report = _make_report() + result = format_weekly_report_markdown(report) + assert "2024-06-08" in result + assert "2024-06-15" in result + + def test_completion_percentage_in_rollup(self) -> None: + rollup = FeatureRollup( + feature_key="F-1", + linked_tickets=[_make_ticket("T-1")], + tickets_completed=1, + tickets_in_progress=0, + completion_percentage=66.7, + ) + report = _make_report(feature_rollups={"F-1": rollup}) + result = format_weekly_report_markdown(report) + assert "67%" in result + + def test_ticket_type_in_table(self) -> None: + t = _make_ticket("BUG-1", ticket_type="Bug", status="completed") + report = _make_report(completed=[t]) + result = format_weekly_report_markdown(report) + assert "Bug" in result + + +# --------------------------------------------------------------------------- +# Tests: format_weekly_report_json +# --------------------------------------------------------------------------- + + +class TestFormatWeeklyReportJson: + def test_returns_valid_json(self) -> None: + report = _make_report() + result = format_weekly_report_json(report) + parsed = json.loads(result) + assert isinstance(parsed, dict) + + def 
test_top_level_keys_present(self) -> None: + report = _make_report() + parsed = json.loads(format_weekly_report_json(report)) + required_keys = { + "project", + "period_days", + "report_start", + "report_end", + "summary", + "tokens_by_stage", + "bottlenecks", + "completed_tickets", + "in_progress_tickets", + "blocked_tickets", + "feature_rollups", + } + assert required_keys.issubset(parsed.keys()) + + def test_project_name_in_json(self) -> None: + report = _make_report(project="MYPROJ") + parsed = json.loads(format_weekly_report_json(report)) + assert parsed["project"] == "MYPROJ" + + def test_period_days_in_json(self) -> None: + report = _make_report(period_days=14) + parsed = json.loads(format_weekly_report_json(report)) + assert parsed["period_days"] == 14 + + def test_summary_section_structure(self) -> None: + t = _make_ticket() + report = _make_report(completed=[t]) + parsed = json.loads(format_weekly_report_json(report)) + summary = parsed["summary"] + assert "total_tickets" in summary + assert "completed" in summary + assert "in_progress" in summary + assert "blocked" in summary + assert "avg_cycle_time_seconds" in summary + assert "total_input_tokens" in summary + assert "total_output_tokens" in summary + + def test_completed_count_in_summary(self) -> None: + t1 = _make_ticket("T1") + t2 = _make_ticket("T2") + report = _make_report(completed=[t1, t2]) + parsed = json.loads(format_weekly_report_json(report)) + assert parsed["summary"]["completed"] == 2 + assert parsed["summary"]["total_tickets"] == 2 + + def test_ticket_dict_fields(self) -> None: + t = _make_ticket("AISOS-5", input_tokens=2000, output_tokens=800) + report = _make_report(completed=[t]) + parsed = json.loads(format_weekly_report_json(report)) + ticket = parsed["completed_tickets"][0] + assert ticket["ticket_key"] == "AISOS-5" + assert ticket["input_tokens"] == 2000 + assert ticket["output_tokens"] == 800 + assert "status" in ticket + assert "duration_seconds" in ticket + assert "ci_cycles" in 
ticket + assert "outcome" in ticket + assert "tokens_by_stage" in ticket + assert "revision_counts" in ticket + assert "stage_durations" in ticket + + def test_tokens_by_stage_in_json(self) -> None: + report = _make_report(tokens_by_stage={"prd": (1000, 500)}) + parsed = json.loads(format_weekly_report_json(report)) + assert "prd" in parsed["tokens_by_stage"] + assert parsed["tokens_by_stage"]["prd"]["input"] == 1000 + assert parsed["tokens_by_stage"]["prd"]["output"] == 500 + + def test_bottlenecks_section(self) -> None: + b = BottleneckAnalysis( + avg_stage_durations={"prd": 120.0}, + most_revised_stages=["prd", "spec"], + ci_fix_rate=0.5, + slowest_stage="prd", + total_tickets_analyzed=10, + ) + report = _make_report(bottlenecks=b) + parsed = json.loads(format_weekly_report_json(report)) + bn = parsed["bottlenecks"] + assert bn["total_tickets_analyzed"] == 10 + assert bn["slowest_stage"] == "prd" + assert bn["ci_fix_rate"] == pytest.approx(0.5) + assert bn["most_revised_stages"] == ["prd", "spec"] + assert bn["avg_stage_durations"]["prd"] == pytest.approx(120.0) + + def test_feature_rollup_in_json(self) -> None: + rollup = FeatureRollup( + feature_key="AISOS-10", + feature_summary="Feature Summary", + linked_tickets=[_make_ticket("AISOS-11")], + total_input_tokens=5000, + total_output_tokens=2000, + tickets_completed=1, + tickets_in_progress=0, + completion_percentage=100.0, + ) + report = _make_report(feature_rollups={"AISOS-10": rollup}) + parsed = json.loads(format_weekly_report_json(report)) + assert "AISOS-10" in parsed["feature_rollups"] + fr = parsed["feature_rollups"]["AISOS-10"] + assert fr["feature_key"] == "AISOS-10" + assert fr["feature_summary"] == "Feature Summary" + assert fr["total_input_tokens"] == 5000 + assert fr["total_output_tokens"] == 2000 + assert fr["tickets_completed"] == 1 + assert fr["completion_percentage"] == pytest.approx(100.0) + assert "AISOS-11" in fr["linked_tickets"] + + def test_empty_feature_rollups_is_empty_dict(self) -> 
None: + report = _make_report(feature_rollups={}) + parsed = json.loads(format_weekly_report_json(report)) + assert parsed["feature_rollups"] == {} + + def test_avg_cycle_time_none_serialized(self) -> None: + report = _make_report(avg_cycle_time=None) + parsed = json.loads(format_weekly_report_json(report)) + assert parsed["summary"]["avg_cycle_time_seconds"] is None + + def test_avg_cycle_time_value_serialized(self) -> None: + report = _make_report(avg_cycle_time=7200.0) + parsed = json.loads(format_weekly_report_json(report)) + assert parsed["summary"]["avg_cycle_time_seconds"] == pytest.approx(7200.0) + + def test_output_is_sorted_keys(self) -> None: + report = _make_report() + result = format_weekly_report_json(report) + parsed = json.loads(result) + keys = list(parsed.keys()) + assert keys == sorted(keys) + + def test_in_progress_tickets_list(self) -> None: + t = _make_ticket("AISOS-99", status="in_progress") + report = _make_report(in_progress=[t]) + parsed = json.loads(format_weekly_report_json(report)) + assert len(parsed["in_progress_tickets"]) == 1 + assert parsed["in_progress_tickets"][0]["ticket_key"] == "AISOS-99" + + def test_blocked_tickets_list(self) -> None: + t = _make_ticket("AISOS-88", status="blocked") + report = _make_report(blocked=[t]) + parsed = json.loads(format_weekly_report_json(report)) + assert len(parsed["blocked_tickets"]) == 1 + assert parsed["blocked_tickets"][0]["ticket_key"] == "AISOS-88" + + def test_multiple_tickets_in_json(self) -> None: + c1 = _make_ticket("AISOS-1") + c2 = _make_ticket("AISOS-2") + ip = _make_ticket("AISOS-3", status="in_progress") + report = _make_report(completed=[c1, c2], in_progress=[ip]) + parsed = json.loads(format_weekly_report_json(report)) + assert parsed["summary"]["total_tickets"] == 3 + assert len(parsed["completed_tickets"]) == 2 + assert len(parsed["in_progress_tickets"]) == 1 + + def test_token_raw_integers_not_abbreviated(self) -> None: + """JSON should contain raw int values, not abbreviated 
strings like '1k'.""" + t = _make_ticket(input_tokens=31_000, output_tokens=5_000) + report = _make_report(completed=[t]) + parsed = json.loads(format_weekly_report_json(report)) + assert parsed["completed_tickets"][0]["input_tokens"] == 31_000 + assert parsed["completed_tickets"][0]["output_tokens"] == 5_000 + + def test_report_dates_preserved(self) -> None: + report = _make_report() + parsed = json.loads(format_weekly_report_json(report)) + assert parsed["report_start"] == "2024-06-08T00:00:00+00:00" + assert parsed["report_end"] == "2024-06-15T00:00:00+00:00" + + +# --------------------------------------------------------------------------- +# Tests: import paths +# --------------------------------------------------------------------------- + + +class TestImportPaths: + def test_format_duration_importable(self) -> None: + from forge.workflow.stats.weekly_formatter import _format_duration + + assert callable(_format_duration) + + def test_format_token_count_importable(self) -> None: + from forge.workflow.stats.weekly_formatter import _format_token_count + + assert callable(_format_token_count) + + def test_format_bottleneck_section_importable(self) -> None: + from forge.workflow.stats.weekly_formatter import _format_bottleneck_section + + assert callable(_format_bottleneck_section) + + def test_cli_formatter_importable(self) -> None: + from forge.workflow.stats.weekly_formatter import format_weekly_report_cli + + assert callable(format_weekly_report_cli) + + def test_markdown_formatter_importable(self) -> None: + from forge.workflow.stats.weekly_formatter import format_weekly_report_markdown + + assert callable(format_weekly_report_markdown) + + def test_json_formatter_importable(self) -> None: + from forge.workflow.stats.weekly_formatter import format_weekly_report_json + + assert callable(format_weekly_report_json) diff --git a/tests/unit/workflow/stats/test_weekly_report.py b/tests/unit/workflow/stats/test_weekly_report.py new file mode 100644 index 
00000000..d2d284b4 --- /dev/null +++ b/tests/unit/workflow/stats/test_weekly_report.py @@ -0,0 +1,1003 @@ +"""Unit tests for forge.workflow.stats.weekly_report. + +All Redis and external I/O is mocked. Tests cover: + +- WeeklyReportData dataclass construction and fields +- TicketSummary dataclass construction and fields +- BottleneckAnalysis dataclass construction and fields +- _parse_checkpoint_stats: extraction from various checkpoint shapes +- _calculate_bottlenecks: averages, ordering, CI fix rate +- _is_within_window: time-window filtering +- _aggregate_tokens: cross-ticket aggregation +- _avg_cycle_time: average cycle time computation +- collect_weekly_data: Redis scan integration with mocked client +""" + +from __future__ import annotations + +import json +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.workflow.stats.weekly_report import ( + BottleneckAnalysis, + TicketSummary, + WeeklyReportData, + _aggregate_tokens, + _avg_cycle_time, + _calculate_bottlenecks, + _is_within_window, + _parse_checkpoint_stats, + collect_weekly_data, +) + +# --------------------------------------------------------------------------- +# Shared helpers / fixtures +# --------------------------------------------------------------------------- + +_NOW = datetime(2024, 6, 15, 12, 0, 0, tzinfo=UTC) + + +@pytest.fixture(autouse=True) +def _patch_get_checkpoint_state(): + async def mock_get_state(ticket_key: str): + from forge.workflow.stats.weekly_report import get_redis_client + + try: + redis_client = await get_redis_client() + key = f"checkpoint:{ticket_key}" + val = await redis_client.get(key) + if val is not None: + import json + + return json.loads(val) + except Exception: + pass + return None + + with patch( + "forge.workflow.stats.weekly_report.get_checkpoint_state", side_effect=mock_get_state + ): + yield + + +_ONE_DAY_AGO = (_NOW - timedelta(days=1)).isoformat() +_TWO_WEEKS_AGO = (_NOW - 
timedelta(weeks=2)).isoformat() +_TICKET = "AISOS-100" + + +def _make_stage_data( + *, + stage_name: str = "prd", + iteration_count: int = 1, + machine_time_seconds: float = 120.0, + human_time_seconds: float = 0.0, + input_tokens: int = 500, + output_tokens: int = 250, + started_at: str | None = None, + ended_at: str | None = None, +) -> dict: + if started_at is None: + started_at = _ONE_DAY_AGO + return { + "stage_name": stage_name, + "iteration_count": iteration_count, + "machine_time_seconds": machine_time_seconds, + "human_time_seconds": human_time_seconds, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "started_at": started_at, + "ended_at": ended_at, + } + + +def _make_state( + *, + ticket_key: str = _TICKET, + ticket_type: str = "Feature", + workflow_outcome: str | None = "Completed", + is_blocked: bool = False, + stage_timestamps: dict | None = None, + stats_ci_cycles: int = 0, + updated_at: str | None = None, + **extra, +) -> dict: + if stage_timestamps is None: + stage_timestamps = { + "prd": _make_stage_data( + stage_name="prd", + started_at=_ONE_DAY_AGO, + ended_at=_ONE_DAY_AGO, + ) + } + if updated_at is None: + updated_at = _ONE_DAY_AGO + return { + "ticket_key": ticket_key, + "ticket_type": ticket_type, + "workflow_outcome": workflow_outcome, + "is_blocked": is_blocked, + "stage_timestamps": stage_timestamps, + "stats_ci_cycles": stats_ci_cycles, + "updated_at": updated_at, + **extra, + } + + +# --------------------------------------------------------------------------- +# WeeklyReportData dataclass +# --------------------------------------------------------------------------- + + +class TestWeeklyReportData: + def test_construction_defaults(self) -> None: + report = WeeklyReportData(project="AISOS") + assert report.project == "AISOS" + assert report.period_days == 7 + assert report.completed_tickets == [] + assert report.in_progress_tickets == [] + assert report.blocked_tickets == [] + assert report.total_input_tokens == 0 + 
assert report.total_output_tokens == 0 + assert report.tokens_by_stage == {} + assert report.avg_cycle_time is None + assert isinstance(report.bottlenecks, BottleneckAnalysis) + assert report.all_tickets == [] + + def test_construction_with_values(self) -> None: + ticket = TicketSummary(ticket_key="AISOS-1", status="completed") + report = WeeklyReportData( + project="AISOS", + period_days=14, + completed_tickets=[ticket], + total_input_tokens=1000, + total_output_tokens=500, + avg_cycle_time=3600.0, + ) + assert report.period_days == 14 + assert len(report.completed_tickets) == 1 + assert report.total_input_tokens == 1000 + assert report.total_output_tokens == 500 + assert report.avg_cycle_time == 3600.0 + + def test_report_start_end_fields(self) -> None: + report = WeeklyReportData( + project="AISOS", + report_start="2024-06-08T00:00:00+00:00", + report_end="2024-06-15T00:00:00+00:00", + ) + assert report.report_start == "2024-06-08T00:00:00+00:00" + assert report.report_end == "2024-06-15T00:00:00+00:00" + + def test_mutable_defaults_are_independent(self) -> None: + r1 = WeeklyReportData(project="A") + r2 = WeeklyReportData(project="B") + r1.completed_tickets.append(TicketSummary(ticket_key="A-1")) + assert r2.completed_tickets == [] + + +# --------------------------------------------------------------------------- +# TicketSummary dataclass +# --------------------------------------------------------------------------- + + +class TestTicketSummary: + def test_defaults(self) -> None: + t = TicketSummary(ticket_key="AISOS-1") + assert t.ticket_type == "Feature" + assert t.status == "in_progress" + assert t.duration_seconds is None + assert t.input_tokens == 0 + assert t.output_tokens == 0 + assert t.tokens_by_stage == {} + assert t.revision_counts == {} + assert t.ci_cycles == 0 + assert t.outcome is None + assert t.stage_durations == {} + + def test_all_fields(self) -> None: + t = TicketSummary( + ticket_key="AISOS-2", + ticket_type="Bug", + status="completed", + 
duration_seconds=3600.0, + input_tokens=1000, + output_tokens=500, + tokens_by_stage={"prd": (1000, 500)}, + revision_counts={"prd": 2}, + ci_cycles=3, + outcome="Completed", + stage_durations={"prd": 120.0}, + ) + assert t.ticket_type == "Bug" + assert t.status == "completed" + assert t.duration_seconds == 3600.0 + assert t.ci_cycles == 3 + + +# --------------------------------------------------------------------------- +# BottleneckAnalysis dataclass +# --------------------------------------------------------------------------- + + +class TestBottleneckAnalysis: + def test_defaults(self) -> None: + b = BottleneckAnalysis() + assert b.avg_stage_durations == {} + assert b.most_revised_stages == [] + assert b.ci_fix_rate == 0.0 + assert b.slowest_stage is None + assert b.total_tickets_analyzed == 0 + + def test_with_values(self) -> None: + b = BottleneckAnalysis( + avg_stage_durations={"prd": 60.0, "spec": 120.0}, + most_revised_stages=["spec", "prd"], + ci_fix_rate=0.5, + slowest_stage="spec", + total_tickets_analyzed=4, + ) + assert b.ci_fix_rate == 0.5 + assert b.slowest_stage == "spec" + assert b.total_tickets_analyzed == 4 + + +# --------------------------------------------------------------------------- +# _parse_checkpoint_stats +# --------------------------------------------------------------------------- + + +class TestParseCheckpointStats: + def test_missing_ticket_key_returns_none(self) -> None: + result = _parse_checkpoint_stats({"stage_timestamps": {}}) + assert result is None + + def test_missing_stage_timestamps_returns_none(self) -> None: + result = _parse_checkpoint_stats({"ticket_key": "AISOS-1"}) + assert result is None + + def test_minimal_valid_state(self) -> None: + state = {"ticket_key": "AISOS-1", "stage_timestamps": {}} + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.ticket_key == "AISOS-1" + assert result.input_tokens == 0 + assert result.output_tokens == 0 + + def test_token_aggregation(self) -> None: 
+ state = { + "ticket_key": "AISOS-1", + "stage_timestamps": { + "prd": _make_stage_data(input_tokens=300, output_tokens=150), + "spec": _make_stage_data(stage_name="spec", input_tokens=200, output_tokens=100), + }, + "workflow_outcome": "Completed", + } + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.input_tokens == 500 + assert result.output_tokens == 250 + assert result.tokens_by_stage["prd"] == (300, 150) + assert result.tokens_by_stage["spec"] == (200, 100) + + def test_status_completed(self) -> None: + state = _make_state(workflow_outcome="Completed") + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.status == "completed" + + def test_status_blocked_from_is_blocked(self) -> None: + state = _make_state(workflow_outcome=None, is_blocked=True) + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.status == "blocked" + + def test_status_blocked_from_outcome(self) -> None: + state = _make_state(workflow_outcome="Blocked: waiting for approval") + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.status == "blocked" + + def test_status_in_progress(self) -> None: + state = _make_state(workflow_outcome=None, is_blocked=False) + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.status == "in_progress" + + def test_ticket_type_extraction(self) -> None: + state = _make_state(ticket_type="Bug") + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.ticket_type == "Bug" + + def test_ticket_type_defaults_to_feature(self) -> None: + state = {"ticket_key": "AISOS-1", "stage_timestamps": {}} + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.ticket_type == "Feature" + + def test_ci_cycles_extracted(self) -> None: + state = _make_state(stats_ci_cycles=3) + result = _parse_checkpoint_stats(state) + assert result is not None + assert 
result.ci_cycles == 3 + + def test_revision_counts_extracted(self) -> None: + state = { + "ticket_key": "AISOS-1", + "stage_timestamps": { + "prd": _make_stage_data(iteration_count=3), + "spec": _make_stage_data(stage_name="spec", iteration_count=1), + }, + "workflow_outcome": "Completed", + } + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.revision_counts["prd"] == 3 + assert result.revision_counts["spec"] == 1 + + def test_stage_durations_extracted(self) -> None: + state = { + "ticket_key": "AISOS-1", + "stage_timestamps": { + "prd": _make_stage_data(machine_time_seconds=60.0), + "spec": _make_stage_data(stage_name="spec", machine_time_seconds=90.0), + }, + "workflow_outcome": "Completed", + } + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.stage_durations["prd"] == 60.0 + assert result.stage_durations["spec"] == 90.0 + + def test_duration_seconds_for_completed_ticket(self) -> None: + started = "2024-06-14T10:00:00+00:00" + ended = "2024-06-14T11:00:00+00:00" + state = { + "ticket_key": "AISOS-1", + "stage_timestamps": { + "prd": _make_stage_data(started_at=started, ended_at=ended), + }, + "workflow_outcome": "Completed", + } + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.duration_seconds == 3600.0 + + def test_duration_seconds_none_when_no_timestamps(self) -> None: + state = { + "ticket_key": "AISOS-1", + "stage_timestamps": { + "prd": { + "stage_name": "prd", + "input_tokens": 0, + "output_tokens": 0, + "iteration_count": 1, + "machine_time_seconds": 0.0, + "started_at": None, + "ended_at": None, + } + }, + "workflow_outcome": "Completed", + } + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.duration_seconds is None + + def test_in_progress_duration_measured_from_start_to_now(self) -> None: + # The start is 1 hour ago; outcome is None (in_progress) + one_hour_ago = (datetime.now(UTC) - 
timedelta(hours=1)).isoformat() + state = { + "ticket_key": "AISOS-1", + "stage_timestamps": { + "prd": _make_stage_data(started_at=one_hour_ago, ended_at=None), + }, + "workflow_outcome": None, + } + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.status == "in_progress" + # Allow generous delta for test execution time + assert result.duration_seconds is not None + assert 3500 < result.duration_seconds < 3700 + + def test_malformed_stage_timestamps_treated_as_empty(self) -> None: + state = {"ticket_key": "AISOS-1", "stage_timestamps": "not-a-dict"} + result = _parse_checkpoint_stats(state) + assert result is not None + assert result.input_tokens == 0 + + +# --------------------------------------------------------------------------- +# _calculate_bottlenecks +# --------------------------------------------------------------------------- + + +class TestCalculateBottlenecks: + def test_empty_list(self) -> None: + result = _calculate_bottlenecks([]) + assert result.total_tickets_analyzed == 0 + assert result.avg_stage_durations == {} + assert result.most_revised_stages == [] + assert result.ci_fix_rate == 0.0 + assert result.slowest_stage is None + + def test_single_ticket_no_ci(self) -> None: + ticket = TicketSummary( + ticket_key="AISOS-1", + stage_durations={"prd": 60.0, "spec": 120.0}, + revision_counts={"prd": 2, "spec": 1}, + ci_cycles=0, + ) + result = _calculate_bottlenecks([ticket]) + assert result.total_tickets_analyzed == 1 + assert result.avg_stage_durations["prd"] == 60.0 + assert result.avg_stage_durations["spec"] == 120.0 + assert result.slowest_stage == "spec" + assert result.ci_fix_rate == 0.0 + + def test_ci_fix_rate_all_triggered(self) -> None: + tickets = [ + TicketSummary(ticket_key="A-1", ci_cycles=2), + TicketSummary(ticket_key="A-2", ci_cycles=1), + ] + result = _calculate_bottlenecks(tickets) + assert result.ci_fix_rate == 1.0 + + def test_ci_fix_rate_partial(self) -> None: + tickets = [ + 
TicketSummary(ticket_key="A-1", ci_cycles=1), + TicketSummary(ticket_key="A-2", ci_cycles=0), + TicketSummary(ticket_key="A-3", ci_cycles=0), + TicketSummary(ticket_key="A-4", ci_cycles=0), + ] + result = _calculate_bottlenecks(tickets) + assert result.ci_fix_rate == pytest.approx(0.25) + + def test_avg_stage_durations_across_tickets(self) -> None: + tickets = [ + TicketSummary(ticket_key="A-1", stage_durations={"prd": 60.0}), + TicketSummary(ticket_key="A-2", stage_durations={"prd": 120.0}), + ] + result = _calculate_bottlenecks(tickets) + assert result.avg_stage_durations["prd"] == pytest.approx(90.0) + + def test_most_revised_stages_ordering(self) -> None: + tickets = [ + TicketSummary( + ticket_key="A-1", + revision_counts={"spec": 5, "prd": 1, "ci": 3}, + ), + ] + result = _calculate_bottlenecks(tickets) + assert result.most_revised_stages[0] == "spec" + assert result.most_revised_stages[1] == "ci" + assert result.most_revised_stages[2] == "prd" + + def test_slowest_stage(self) -> None: + tickets = [ + TicketSummary( + ticket_key="A-1", + stage_durations={"prd": 60.0, "implementation": 3600.0, "ci": 300.0}, + ), + ] + result = _calculate_bottlenecks(tickets) + assert result.slowest_stage == "implementation" + + def test_stages_only_in_some_tickets(self) -> None: + tickets = [ + TicketSummary(ticket_key="A-1", stage_durations={"prd": 60.0, "spec": 90.0}), + TicketSummary(ticket_key="A-2", stage_durations={"prd": 120.0}), + ] + result = _calculate_bottlenecks(tickets) + # prd averaged across both; spec only from A-1 + assert result.avg_stage_durations["prd"] == pytest.approx(90.0) + assert result.avg_stage_durations["spec"] == pytest.approx(90.0) + + +# --------------------------------------------------------------------------- +# _is_within_window +# --------------------------------------------------------------------------- + + +class TestIsWithinWindow: + def _cutoff(self) -> datetime: + return _NOW - timedelta(days=7) + + def 
test_updated_at_within_window(self) -> None: + state = {"updated_at": _ONE_DAY_AGO} + assert _is_within_window(state, self._cutoff()) is True + + def test_updated_at_outside_window(self) -> None: + state = {"updated_at": _TWO_WEEKS_AGO} + assert _is_within_window(state, self._cutoff()) is False + + def test_stage_started_at_within_window(self) -> None: + state = { + "updated_at": _TWO_WEEKS_AGO, + "stage_timestamps": {"prd": {"started_at": _ONE_DAY_AGO, "ended_at": None}}, + } + assert _is_within_window(state, self._cutoff()) is True + + def test_stage_ended_at_within_window(self) -> None: + state = { + "updated_at": _TWO_WEEKS_AGO, + "stage_timestamps": {"prd": {"started_at": _TWO_WEEKS_AGO, "ended_at": _ONE_DAY_AGO}}, + } + assert _is_within_window(state, self._cutoff()) is True + + def test_all_timestamps_outside_window(self) -> None: + state = { + "updated_at": _TWO_WEEKS_AGO, + "stage_timestamps": {"prd": {"started_at": _TWO_WEEKS_AGO, "ended_at": _TWO_WEEKS_AGO}}, + } + assert _is_within_window(state, self._cutoff()) is False + + def test_no_timestamps(self) -> None: + state = {"stage_timestamps": {}} + assert _is_within_window(state, self._cutoff()) is False + + def test_missing_stage_timestamps(self) -> None: + state = {"updated_at": _TWO_WEEKS_AGO} + assert _is_within_window(state, self._cutoff()) is False + + def test_malformed_stage_timestamps(self) -> None: + state = {"stage_timestamps": "bad", "updated_at": _TWO_WEEKS_AGO} + assert _is_within_window(state, self._cutoff()) is False + + +# --------------------------------------------------------------------------- +# _aggregate_tokens +# --------------------------------------------------------------------------- + + +class TestAggregateTokens: + def test_empty_list(self) -> None: + total_in, total_out, by_stage = _aggregate_tokens([]) + assert total_in == 0 + assert total_out == 0 + assert by_stage == {} + + def test_single_ticket(self) -> None: + ticket = TicketSummary( + ticket_key="A-1", + 
input_tokens=1000, + output_tokens=500, + tokens_by_stage={"prd": (1000, 500)}, + ) + total_in, total_out, by_stage = _aggregate_tokens([ticket]) + assert total_in == 1000 + assert total_out == 500 + assert by_stage["prd"] == (1000, 500) + + def test_multiple_tickets_same_stage(self) -> None: + t1 = TicketSummary( + ticket_key="A-1", + input_tokens=300, + output_tokens=100, + tokens_by_stage={"prd": (300, 100)}, + ) + t2 = TicketSummary( + ticket_key="A-2", + input_tokens=200, + output_tokens=150, + tokens_by_stage={"prd": (200, 150)}, + ) + total_in, total_out, by_stage = _aggregate_tokens([t1, t2]) + assert total_in == 500 + assert total_out == 250 + assert by_stage["prd"] == (500, 250) + + def test_multiple_stages(self) -> None: + ticket = TicketSummary( + ticket_key="A-1", + input_tokens=700, + output_tokens=350, + tokens_by_stage={"prd": (300, 150), "spec": (400, 200)}, + ) + total_in, total_out, by_stage = _aggregate_tokens([ticket]) + assert total_in == 700 + assert total_out == 350 + assert by_stage["prd"] == (300, 150) + assert by_stage["spec"] == (400, 200) + + +# --------------------------------------------------------------------------- +# _avg_cycle_time +# --------------------------------------------------------------------------- + + +class TestAvgCycleTime: + def test_empty_list(self) -> None: + assert _avg_cycle_time([]) is None + + def test_no_completed_tickets(self) -> None: + tickets = [TicketSummary(ticket_key="A-1", status="in_progress", duration_seconds=100.0)] + assert _avg_cycle_time(tickets) is None + + def test_single_completed_ticket(self) -> None: + tickets = [TicketSummary(ticket_key="A-1", status="completed", duration_seconds=3600.0)] + assert _avg_cycle_time(tickets) == pytest.approx(3600.0) + + def test_multiple_completed_tickets(self) -> None: + tickets = [ + TicketSummary(ticket_key="A-1", status="completed", duration_seconds=3600.0), + TicketSummary(ticket_key="A-2", status="completed", duration_seconds=7200.0), + ] + assert 
_avg_cycle_time(tickets) == pytest.approx(5400.0) + + def test_completed_ticket_without_duration(self) -> None: + tickets = [ + TicketSummary(ticket_key="A-1", status="completed", duration_seconds=None), + TicketSummary(ticket_key="A-2", status="completed", duration_seconds=3600.0), + ] + assert _avg_cycle_time(tickets) == pytest.approx(3600.0) + + def test_mixed_statuses_only_completed_counted(self) -> None: + tickets = [ + TicketSummary(ticket_key="A-1", status="completed", duration_seconds=3600.0), + TicketSummary(ticket_key="A-2", status="in_progress", duration_seconds=1800.0), + TicketSummary(ticket_key="A-3", status="blocked", duration_seconds=7200.0), + ] + assert _avg_cycle_time(tickets) == pytest.approx(3600.0) + + +# --------------------------------------------------------------------------- +# collect_weekly_data — integration with mocked Redis +# --------------------------------------------------------------------------- + + +def _make_redis_mock(keys: list[str], states: dict[str, dict]) -> MagicMock: + """Build a fake async Redis client that returns the given keys and states.""" + mock = MagicMock() + + # scan returns (cursor, keys_list); call it once and return 0 to stop loop + async def scan_side_effect(cursor, match, count): + _ = count + if cursor == 0: + # Filter keys by match pattern (simple prefix check) + prefix = match.rstrip("*") + filtered = [k for k in keys if k.startswith(prefix)] + return (0, filtered) + return (0, []) + + mock.scan = AsyncMock(side_effect=scan_side_effect) + + async def get_side_effect(key): + state = states.get(key) + if state is None: + return None + return json.dumps(state) + + mock.get = AsyncMock(side_effect=get_side_effect) + return mock + + +@pytest.fixture +def _redis_mock_with_data(): + """Fixture providing a Redis mock with two checkpoints in the window.""" + ticket1 = "AISOS-1" + ticket2 = "AISOS-2" + key1 = f"checkpoint:{ticket1}" + key2 = f"checkpoint:{ticket2}" + state1 = _make_state( + ticket_key=ticket1, 
+ workflow_outcome="Completed", + stage_timestamps={ + "prd": _make_stage_data( + stage_name="prd", + input_tokens=300, + output_tokens=150, + started_at=_ONE_DAY_AGO, + ended_at=_ONE_DAY_AGO, + machine_time_seconds=60.0, + iteration_count=1, + ) + }, + stats_ci_cycles=0, + ) + state2 = _make_state( + ticket_key=ticket2, + workflow_outcome=None, + is_blocked=False, + stage_timestamps={ + "prd": _make_stage_data( + stage_name="prd", + input_tokens=200, + output_tokens=100, + started_at=_ONE_DAY_AGO, + ended_at=None, + machine_time_seconds=30.0, + iteration_count=2, + ) + }, + stats_ci_cycles=1, + ) + redis_mock = _make_redis_mock( + keys=[key1, key2], + states={key1: state1, key2: state2}, + ) + return redis_mock + + +def _patch_now(fixed_now: datetime): + """Context manager that patches datetime.now(UTC) in the weekly_report module. + + Replaces the ``datetime`` name in the weekly_report module with a subclass + whose ``now()`` classmethod always returns *fixed_now*. All other + ``datetime`` functionality (fromisoformat, arithmetic, etc.) is inherited + unchanged. 
+ """ + + class _FakeDatetime(datetime): + @classmethod + def now(cls, _tz=None): # type: ignore[override] + return fixed_now + + return patch("forge.workflow.stats.weekly_report.datetime", _FakeDatetime) + + +class TestCollectWeeklyData: + @pytest.mark.asyncio + async def test_returns_weekly_report_data(self, _redis_mock_with_data) -> None: + with ( + patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=_redis_mock_with_data), + ), + _patch_now(_NOW), + ): + report = await collect_weekly_data("AISOS", days=7) + assert isinstance(report, WeeklyReportData) + + @pytest.mark.asyncio + async def test_project_and_period_fields(self, _redis_mock_with_data) -> None: + with ( + patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=_redis_mock_with_data), + ), + _patch_now(_NOW), + ): + report = await collect_weekly_data("AISOS", days=14) + assert report.project == "AISOS" + assert report.period_days == 14 + + @pytest.mark.asyncio + async def test_completed_and_in_progress_split(self, _redis_mock_with_data) -> None: + with ( + patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=_redis_mock_with_data), + ), + _patch_now(_NOW), + ): + report = await collect_weekly_data("AISOS") + assert len(report.completed_tickets) == 1 + assert len(report.in_progress_tickets) == 1 + assert len(report.blocked_tickets) == 0 + assert report.completed_tickets[0].ticket_key == "AISOS-1" + assert report.in_progress_tickets[0].ticket_key == "AISOS-2" + + @pytest.mark.asyncio + async def test_token_aggregation(self, _redis_mock_with_data) -> None: + with ( + patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=_redis_mock_with_data), + ), + _patch_now(_NOW), + ): + report = await collect_weekly_data("AISOS") + assert report.total_input_tokens == 500 # 300 + 200 + assert report.total_output_tokens == 250 # 150 + 100 + + @pytest.mark.asyncio + 
async def test_bottlenecks_populated(self, _redis_mock_with_data) -> None: + with ( + patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=_redis_mock_with_data), + ), + _patch_now(_NOW), + ): + report = await collect_weekly_data("AISOS") + assert report.bottlenecks.total_tickets_analyzed == 2 + assert "prd" in report.bottlenecks.avg_stage_durations + + @pytest.mark.asyncio + async def test_avg_cycle_time_computed(self, _redis_mock_with_data) -> None: + with ( + patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=_redis_mock_with_data), + ), + _patch_now(_NOW), + ): + report = await collect_weekly_data("AISOS") + # Only the completed ticket has an ended_at timestamp; avg_cycle_time + # should be non-None for the completed one. + assert report.avg_cycle_time is not None + + @pytest.mark.asyncio + async def test_empty_project_returns_zero_report(self) -> None: + redis_mock = _make_redis_mock(keys=[], states={}) + with patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=redis_mock), + ): + report = await collect_weekly_data("EMPTY") + assert report.completed_tickets == [] + assert report.in_progress_tickets == [] + assert report.blocked_tickets == [] + assert report.total_input_tokens == 0 + assert report.avg_cycle_time is None + + @pytest.mark.asyncio + async def test_tickets_outside_window_excluded(self) -> None: + ticket_key = "AISOS-99" + redis_key = f"checkpoint:{ticket_key}" + # All timestamps are two weeks ago — outside a 7-day window + old_state = _make_state( + ticket_key=ticket_key, + workflow_outcome="Completed", + updated_at=_TWO_WEEKS_AGO, + stage_timestamps={ + "prd": _make_stage_data(started_at=_TWO_WEEKS_AGO, ended_at=_TWO_WEEKS_AGO) + }, + ) + redis_mock = _make_redis_mock(keys=[redis_key], states={redis_key: old_state}) + with ( + patch( + "forge.workflow.stats.weekly_report.get_redis_client", + 
new=AsyncMock(return_value=redis_mock), + ), + _patch_now(_NOW), + ): + report = await collect_weekly_data("AISOS", days=7) + assert report.all_tickets == [] + + @pytest.mark.asyncio + async def test_blocked_ticket_categorised(self) -> None: + ticket_key = "AISOS-77" + redis_key = f"checkpoint:{ticket_key}" + state = _make_state( + ticket_key=ticket_key, + workflow_outcome=None, + is_blocked=True, + updated_at=_ONE_DAY_AGO, + ) + redis_mock = _make_redis_mock(keys=[redis_key], states={redis_key: state}) + with ( + patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=redis_mock), + ), + _patch_now(_NOW), + ): + report = await collect_weekly_data("AISOS") + assert len(report.blocked_tickets) == 1 + assert report.blocked_tickets[0].ticket_key == ticket_key + + @pytest.mark.asyncio + async def test_malformed_json_skipped(self) -> None: + redis_key = "checkpoint:AISOS-BAD" + mock = MagicMock() + + async def scan_side_effect(cursor, match, count): + _ = (match, count) + if cursor == 0: + return (0, [redis_key]) + return (0, []) + + mock.scan = AsyncMock(side_effect=scan_side_effect) + mock.get = AsyncMock(return_value="not-valid-json{{{{") + + with patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=mock), + ): + report = await collect_weekly_data("AISOS") + # Should not raise; simply skips the malformed key + assert report.all_tickets == [] + + @pytest.mark.asyncio + async def test_redis_scan_failure_returns_empty_report(self) -> None: + mock = MagicMock() + mock.scan = AsyncMock(side_effect=ConnectionError("Redis down")) + + with patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=mock), + ): + report = await collect_weekly_data("AISOS") + assert report.all_tickets == [] + + @pytest.mark.asyncio + async def test_report_start_end_populated(self) -> None: + redis_mock = _make_redis_mock(keys=[], states={}) + with patch( + 
"forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=redis_mock), + ): + report = await collect_weekly_data("AISOS", days=7) + assert report.report_start != "" + assert report.report_end != "" + # Both should be parseable ISO-8601 + start = datetime.fromisoformat(report.report_start) + end = datetime.fromisoformat(report.report_end) + assert (end - start).days == 7 + + @pytest.mark.asyncio + async def test_all_tickets_field_populated(self, _redis_mock_with_data) -> None: + with ( + patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=_redis_mock_with_data), + ), + _patch_now(_NOW), + ): + report = await collect_weekly_data("AISOS") + assert len(report.all_tickets) == 2 + + @pytest.mark.asyncio + async def test_tokens_by_stage_populated(self, _redis_mock_with_data) -> None: + with ( + patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=_redis_mock_with_data), + ), + _patch_now(_NOW), + ): + report = await collect_weekly_data("AISOS") + assert "prd" in report.tokens_by_stage + total_in, total_out = report.tokens_by_stage["prd"] + assert total_in == 500 # 300 + 200 + assert total_out == 250 # 150 + 100 + + @pytest.mark.asyncio + async def test_null_value_from_redis_skipped(self) -> None: + redis_key = "checkpoint:AISOS-NULL" + mock = MagicMock() + + async def scan_side_effect(cursor, match, count): + _ = (match, count) + if cursor == 0: + return (0, [redis_key]) + return (0, []) + + mock.scan = AsyncMock(side_effect=scan_side_effect) + mock.get = AsyncMock(return_value=None) + + with patch( + "forge.workflow.stats.weekly_report.get_redis_client", + new=AsyncMock(return_value=mock), + ): + report = await collect_weekly_data("AISOS") + assert report.all_tickets == [] + + +# --------------------------------------------------------------------------- +# Import path checks +# --------------------------------------------------------------------------- + + +class 
TestImports: + def test_public_symbols_importable(self) -> None: + from forge.workflow.stats.weekly_report import ( # noqa: F401 + BottleneckAnalysis, + TicketSummary, + WeeklyReportData, + collect_weekly_data, + ) + + def test_internal_helpers_importable(self) -> None: + from forge.workflow.stats.weekly_report import ( # noqa: F401 + _aggregate_tokens, + _avg_cycle_time, + _calculate_bottlenecks, + _is_within_window, + _parse_checkpoint_stats, + ) diff --git a/tests/unit/workflow/test_stats.py b/tests/unit/workflow/test_stats.py new file mode 100644 index 00000000..67204e1d --- /dev/null +++ b/tests/unit/workflow/test_stats.py @@ -0,0 +1,418 @@ +"""Unit tests for StageStats, StatsState TypedDicts, and stage constants.""" + +from typing import get_type_hints + +import pytest + + +class TestStageStats: + """Tests for StageStats TypedDict.""" + + def test_stage_stats_has_all_required_fields(self): + """StageStats defines every field required by SC-001.""" + from forge.workflow.stats import StageStats + + hints = get_type_hints(StageStats) + + assert "stage_name" in hints + assert "iteration_count" in hints + assert "machine_time_seconds" in hints + assert "input_tokens" in hints + assert "output_tokens" in hints + assert "started_at" in hints + assert "ended_at" in hints + + def test_stage_stats_field_types(self): + """StageStats fields carry the correct type annotations.""" + from forge.workflow.stats import StageStats + + hints = get_type_hints(StageStats) + + assert hints["stage_name"] is str + assert hints["iteration_count"] is int + assert hints["machine_time_seconds"] is float + assert hints["input_tokens"] is int + assert hints["output_tokens"] is int + + def test_stage_stats_nullable_timestamps(self): + """started_at and ended_at accept None (X | None convention).""" + from forge.workflow.stats import StageStats + + hints = get_type_hints(StageStats, include_extras=False) + + # Under Python 3.11+ X | None becomes types.UnionType. 
+ # str(str | None) is 'str | None' on 3.10+ union syntax. + started_hint = str(hints["started_at"]) + ended_hint = str(hints["ended_at"]) + + assert "str" in started_hint + assert "None" in started_hint + assert "str" in ended_hint + assert "None" in ended_hint + + def test_stage_stats_is_total_false(self): + """StageStats allows partial initialisation.""" + from forge.workflow.stats import StageStats + + # Should not raise — total=False makes all keys optional + partial: StageStats = {"stage_name": "implement", "iteration_count": 1} + assert partial["stage_name"] == "implement" + assert partial["iteration_count"] == 1 + + def test_stage_stats_full_construction(self): + """StageStats can be constructed with all fields populated.""" + from forge.workflow.stats import StageStats + + stats: StageStats = { + "stage_name": "implement", + "iteration_count": 3, + "machine_time_seconds": 120.5, + "human_time_seconds": 300.0, + "input_tokens": 4096, + "output_tokens": 2048, + "started_at": "2024-01-01T00:00:00Z", + "ended_at": "2024-01-01T00:07:00Z", + } + + assert stats["stage_name"] == "implement" + assert stats["iteration_count"] == 3 + assert stats["machine_time_seconds"] == 120.5 + assert stats["human_time_seconds"] == 300.0 + assert stats["input_tokens"] == 4096 + assert stats["output_tokens"] == 2048 + assert stats["started_at"] == "2024-01-01T00:00:00Z" + assert stats["ended_at"] == "2024-01-01T00:07:00Z" + + def test_stage_stats_nullable_timestamps_accept_none(self): + """started_at and ended_at can be explicitly set to None.""" + from forge.workflow.stats import StageStats + + stats: StageStats = { + "stage_name": "triage", + "started_at": None, + "ended_at": None, + } + assert stats["started_at"] is None + assert stats["ended_at"] is None + + +class TestStatsState: + """Tests for StatsState TypedDict mixin.""" + + def test_stats_state_has_all_required_fields(self): + """StatsState defines all workflow-level statistics fields.""" + from forge.workflow.stats 
import StatsState + + hints = get_type_hints(StatsState) + + assert "stage_timestamps" in hints + assert "stats_pr_urls" in hints + assert "stats_ci_cycles" in hints + assert "workflow_outcome" in hints + assert "stats_outcome_reason" in hints + assert "stats_comment_posted" in hints + + def test_stats_state_is_total_false(self): + """StatsState allows partial initialisation.""" + from forge.workflow.stats import StatsState + + partial: StatsState = {"stats_ci_cycles": 0} + assert partial["stats_ci_cycles"] == 0 + + def test_stats_state_nullable_outcome_fields(self): + """workflow_outcome and stats_outcome_reason accept None.""" + from forge.workflow.stats import StatsState + + hints = get_type_hints(StatsState, include_extras=False) + + outcome_hint = str(hints["workflow_outcome"]) + reason_hint = str(hints["stats_outcome_reason"]) + + assert "str" in outcome_hint + assert "None" in outcome_hint + assert "str" in reason_hint + assert "None" in reason_hint + + def test_stats_state_full_construction(self): + """StatsState can be constructed with all fields populated.""" + from forge.workflow.stats import StageStats, StatsState + + stage: StageStats = { + "stage_name": "implement", + "iteration_count": 2, + "machine_time_seconds": 60.0, + "human_time_seconds": 0.0, + "input_tokens": 1000, + "output_tokens": 500, + "started_at": "2024-01-01T00:00:00Z", + "ended_at": "2024-01-01T00:01:00Z", + } + + state: StatsState = { + "stage_timestamps": {"implement": stage}, + "stats_pr_urls": ["https://github.com/org/repo/pull/42"], + "stats_ci_cycles": 1, + "workflow_outcome": "Completed", + "stats_outcome_reason": None, + "stats_comment_posted": True, + } + + assert state["stage_timestamps"]["implement"]["stage_name"] == "implement" + assert state["stats_pr_urls"] == ["https://github.com/org/repo/pull/42"] + assert state["stats_ci_cycles"] == 1 + assert state["workflow_outcome"] == "Completed" + assert state["stats_outcome_reason"] is None + assert state["stats_comment_posted"] 
is True + + @pytest.mark.parametrize( + "outcome", + [ + "Completed", + "Blocked: waiting for human approval", + "Failed: unrecoverable CI failure", + ], + ) + def test_stats_state_valid_outcome_values(self, outcome: str): + """workflow_outcome accepts the three documented outcome patterns.""" + from forge.workflow.stats import StatsState + + state: StatsState = {"workflow_outcome": outcome} + assert state["workflow_outcome"] == outcome + + def test_stats_state_comment_posted_defaults_pattern(self): + """stats_comment_posted is a bool field.""" + from forge.workflow.stats import StatsState + + hints = get_type_hints(StatsState) + assert hints["stats_comment_posted"] is bool + + def test_stage_timestamps_is_dict_of_stage_stats(self): + """stage_timestamps maps string keys to StageStats dicts.""" + from forge.workflow.stats import StageStats, StatsState + + s1: StageStats = {"stage_name": "triage", "iteration_count": 1} + s2: StageStats = {"stage_name": "implement", "iteration_count": 3} + + state: StatsState = {"stage_timestamps": {"triage": s1, "implement": s2}} + assert len(state["stage_timestamps"]) == 2 + assert state["stage_timestamps"]["triage"]["stage_name"] == "triage" + assert state["stage_timestamps"]["implement"]["iteration_count"] == 3 + + +class TestStatsStateExportedFromPackage: + """Verify the new types are accessible via the workflow package.""" + + def test_stage_stats_importable_from_workflow(self): + """StageStats is exported from forge.workflow.""" + from forge.workflow import StageStats # noqa: F401 + + def test_stats_state_importable_from_workflow(self): + """StatsState is exported from forge.workflow.""" + from forge.workflow import StatsState # noqa: F401 + + def test_stats_state_importable_from_base(self): + """StatsState is importable via forge.workflow.base (re-exported).""" + from forge.workflow.base import StatsState # noqa: F401 + + +class TestStageConstants: + """Tests for workflow stage name constants and ordered stage lists.""" + + # 
------------------------------------------------------------------ + # Individual constant values + # ------------------------------------------------------------------ + + def test_stage_prd_value(self): + from forge.workflow.stats import STAGE_PRD + + assert STAGE_PRD == "prd" + + def test_stage_spec_value(self): + from forge.workflow.stats import STAGE_SPEC + + assert STAGE_SPEC == "spec" + + def test_stage_epics_value(self): + from forge.workflow.stats import STAGE_EPICS + + assert STAGE_EPICS == "epics" + + def test_stage_tasks_value(self): + from forge.workflow.stats import STAGE_TASKS + + assert STAGE_TASKS == "tasks" + + def test_stage_implementation_value(self): + from forge.workflow.stats import STAGE_IMPLEMENTATION + + assert STAGE_IMPLEMENTATION == "implementation" + + def test_stage_ci_value(self): + from forge.workflow.stats import STAGE_CI + + assert STAGE_CI == "ci" + + def test_stage_review_value(self): + from forge.workflow.stats import STAGE_REVIEW + + assert STAGE_REVIEW == "review" + + def test_stage_rca_value(self): + from forge.workflow.stats import STAGE_RCA + + assert STAGE_RCA == "rca" + + def test_stage_triage_value(self): + from forge.workflow.stats import STAGE_TRIAGE + + assert STAGE_TRIAGE == "triage" + + def test_stage_planning_value(self): + from forge.workflow.stats import STAGE_PLANNING + + assert STAGE_PLANNING == "planning" + + # ------------------------------------------------------------------ + # ALL_FEATURE_STAGES list + # ------------------------------------------------------------------ + + def test_all_feature_stages_is_list(self): + """ALL_FEATURE_STAGES is a list of strings.""" + from forge.workflow.stats import ALL_FEATURE_STAGES + + assert isinstance(ALL_FEATURE_STAGES, list) + assert all(isinstance(s, str) for s in ALL_FEATURE_STAGES) + + def test_all_feature_stages_length(self): + """ALL_FEATURE_STAGES contains exactly 7 stages.""" + from forge.workflow.stats import ALL_FEATURE_STAGES + + assert 
len(ALL_FEATURE_STAGES) == 7 + + def test_all_feature_stages_order(self): + """ALL_FEATURE_STAGES lists stages in the canonical display order.""" + from forge.workflow.stats import ( + ALL_FEATURE_STAGES, + STAGE_CI, + STAGE_EPICS, + STAGE_IMPLEMENTATION, + STAGE_PRD, + STAGE_REVIEW, + STAGE_SPEC, + STAGE_TASKS, + ) + + assert ALL_FEATURE_STAGES == [ + STAGE_PRD, + STAGE_SPEC, + STAGE_EPICS, + STAGE_TASKS, + STAGE_IMPLEMENTATION, + STAGE_CI, + STAGE_REVIEW, + ] + + def test_all_feature_stages_completeness(self): + """ALL_FEATURE_STAGES contains every expected Feature stage.""" + from forge.workflow.stats import ( + ALL_FEATURE_STAGES, + STAGE_CI, + STAGE_EPICS, + STAGE_IMPLEMENTATION, + STAGE_PRD, + STAGE_REVIEW, + STAGE_SPEC, + STAGE_TASKS, + ) + + expected = { + STAGE_PRD, + STAGE_SPEC, + STAGE_EPICS, + STAGE_TASKS, + STAGE_IMPLEMENTATION, + STAGE_CI, + STAGE_REVIEW, + } + assert set(ALL_FEATURE_STAGES) == expected + + # ------------------------------------------------------------------ + # ALL_BUG_STAGES list + # ------------------------------------------------------------------ + + def test_all_bug_stages_is_list(self): + """ALL_BUG_STAGES is a list of strings.""" + from forge.workflow.stats import ALL_BUG_STAGES + + assert isinstance(ALL_BUG_STAGES, list) + assert all(isinstance(s, str) for s in ALL_BUG_STAGES) + + def test_all_bug_stages_length(self): + """ALL_BUG_STAGES contains exactly 6 stages.""" + from forge.workflow.stats import ALL_BUG_STAGES + + assert len(ALL_BUG_STAGES) == 6 + + def test_all_bug_stages_order(self): + """ALL_BUG_STAGES lists stages in the canonical display order.""" + from forge.workflow.stats import ( + ALL_BUG_STAGES, + STAGE_CI, + STAGE_IMPLEMENTATION, + STAGE_PLANNING, + STAGE_RCA, + STAGE_REVIEW, + STAGE_TRIAGE, + ) + + assert ALL_BUG_STAGES == [ + STAGE_TRIAGE, + STAGE_RCA, + STAGE_PLANNING, + STAGE_IMPLEMENTATION, + STAGE_CI, + STAGE_REVIEW, + ] + + def test_all_bug_stages_completeness(self): + """ALL_BUG_STAGES contains every 
expected Bug stage.""" + from forge.workflow.stats import ( + ALL_BUG_STAGES, + STAGE_CI, + STAGE_IMPLEMENTATION, + STAGE_PLANNING, + STAGE_RCA, + STAGE_REVIEW, + STAGE_TRIAGE, + ) + + expected = { + STAGE_TRIAGE, + STAGE_RCA, + STAGE_PLANNING, + STAGE_IMPLEMENTATION, + STAGE_CI, + STAGE_REVIEW, + } + assert set(ALL_BUG_STAGES) == expected + + # ------------------------------------------------------------------ + # Export verification + # ------------------------------------------------------------------ + + def test_constants_importable_from_stats_module(self): + """All stage constants and lists are importable from forge.workflow.stats.""" + from forge.workflow.stats import ( # noqa: F401 + ALL_BUG_STAGES, + ALL_FEATURE_STAGES, + STAGE_CI, + STAGE_EPICS, + STAGE_IMPLEMENTATION, + STAGE_PLANNING, + STAGE_PRD, + STAGE_RCA, + STAGE_REVIEW, + STAGE_SPEC, + STAGE_TASKS, + STAGE_TRIAGE, + ) diff --git a/tests/unit/workflow/test_stats_utils.py b/tests/unit/workflow/test_stats_utils.py new file mode 100644 index 00000000..9fc873f9 --- /dev/null +++ b/tests/unit/workflow/test_stats_utils.py @@ -0,0 +1,364 @@ +"""Unit tests for forge.workflow.stats_utils.""" + +import pytest + +from forge.workflow.stats_utils import ( + add_pr_url, + increment_ci_cycle, + increment_revision, + record_stage_end, + record_stage_start, + record_tokens, + set_outcome, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _empty_state() -> dict: + """Return a minimal state with stats fields unset (simulates fresh run).""" + return {} + + +def _state_with_stage(stage_name: str, **overrides) -> dict: + """Return a state that already has one stage entry.""" + stage = { + "stage_name": stage_name, + "iteration_count": 0, + "machine_time_seconds": 0.0, + "input_tokens": 0, + "output_tokens": 0, + "started_at": "2024-01-01T00:00:00+00:00", + "ended_at": None, + } + 
stage.update(overrides) + return {"stage_timestamps": {stage_name: stage}} + + +# --------------------------------------------------------------------------- +# record_stage_start +# --------------------------------------------------------------------------- + + +class TestRecordStageStart: + def test_initialises_stage_with_timestamp(self): + result = record_stage_start(_empty_state(), "implement") + + assert "stage_timestamps" in result + stage = result["stage_timestamps"]["implement"] + assert stage["started_at"] is not None + assert "T" in stage["started_at"] # ISO-8601 + + def test_zeroed_numeric_metrics(self): + result = record_stage_start(_empty_state(), "implement") + stage = result["stage_timestamps"]["implement"] + + assert stage["iteration_count"] == 0 + assert stage["machine_time_seconds"] == 0.0 + assert stage["input_tokens"] == 0 + assert stage["output_tokens"] == 0 + + def test_ended_at_is_none_on_init(self): + result = record_stage_start(_empty_state(), "implement") + assert result["stage_timestamps"]["implement"]["ended_at"] is None + + def test_stage_name_recorded(self): + result = record_stage_start(_empty_state(), "triage") + assert result["stage_timestamps"]["triage"]["stage_name"] == "triage" + + def test_resets_ended_at_on_re_entry(self): + """Re-entering a stage clears ended_at (marks it in-progress again).""" + state = _state_with_stage("implement", ended_at="2024-01-01T01:00:00+00:00") + result = record_stage_start(state, "implement") + assert result["stage_timestamps"]["implement"]["ended_at"] is None + + def test_preserves_accumulated_metrics_on_re_entry(self): + """Re-entering should not zero out previously accumulated tokens.""" + state = _state_with_stage( + "implement", + input_tokens=500, + output_tokens=250, + machine_time_seconds=30.0, + ) + result = record_stage_start(state, "implement") + stage = result["stage_timestamps"]["implement"] + + assert stage["input_tokens"] == 500 + assert stage["output_tokens"] == 250 + assert 
stage["machine_time_seconds"] == 30.0 + + def test_handles_missing_stage_timestamps_key(self): + """Works when state has no stage_timestamps key at all.""" + result = record_stage_start({}, "plan") + assert "plan" in result["stage_timestamps"] + + def test_does_not_mutate_existing_stages(self): + """Other stages in stage_timestamps are preserved.""" + state = _state_with_stage("triage") + result = record_stage_start(state, "implement") + + assert "triage" in result["stage_timestamps"] + assert "implement" in result["stage_timestamps"] + + def test_returns_only_stage_timestamps_key(self): + result = record_stage_start(_empty_state(), "implement") + assert list(result.keys()) == ["stage_timestamps"] + + def test_model_name_recorded_when_provided(self): + result = record_stage_start(_empty_state(), "prd", model_name="claude-sonnet-4-5") + stage = result["stage_timestamps"]["prd"] + assert stage["model_name"] == "claude-sonnet-4-5" + + def test_model_name_defaults_to_none(self): + result = record_stage_start(_empty_state(), "implement") + stage = result["stage_timestamps"]["implement"] + assert stage["model_name"] is None + + def test_model_name_none_explicitly(self): + result = record_stage_start(_empty_state(), "ci", model_name=None) + stage = result["stage_timestamps"]["ci"] + assert stage["model_name"] is None + + def test_model_name_set_on_re_entry(self): + """Model name should be updated when re-entering an existing stage.""" + state = _state_with_stage("implement") + result = record_stage_start(state, "implement", model_name="gemini-2.5-flash") + stage = result["stage_timestamps"]["implement"] + assert stage["model_name"] == "gemini-2.5-flash" + + +# --------------------------------------------------------------------------- +# record_stage_end +# --------------------------------------------------------------------------- + + +class TestRecordStageEnd: + def test_sets_ended_at_timestamp(self): + state = _state_with_stage("implement") + result = 
record_stage_end(state, "implement", machine_time=60.0) + + assert result["stage_timestamps"]["implement"]["ended_at"] is not None + + def test_accumulates_machine_time(self): + state = _state_with_stage("implement", machine_time_seconds=10.0) + result = record_stage_end(state, "implement", machine_time=25.5) + + assert result["stage_timestamps"]["implement"]["machine_time_seconds"] == pytest.approx( + 35.5 + ) + + def test_handles_non_existent_stage(self): + """Calling on a stage that was never started should not raise.""" + result = record_stage_end(_empty_state(), "ghost_stage", machine_time=5.0) + + stage = result["stage_timestamps"]["ghost_stage"] + assert stage["machine_time_seconds"] == pytest.approx(5.0) + assert stage["ended_at"] is not None + + def test_returns_only_stage_timestamps_key(self): + state = _state_with_stage("implement") + result = record_stage_end(state, "implement", machine_time=1.0) + assert list(result.keys()) == ["stage_timestamps"] + + +# --------------------------------------------------------------------------- +# record_tokens +# --------------------------------------------------------------------------- + + +class TestRecordTokens: + def test_accumulates_input_tokens(self): + state = _state_with_stage("implement", input_tokens=100) + result = record_tokens(state, "implement", input_tokens=200, output_tokens=0) + + assert result["stage_timestamps"]["implement"]["input_tokens"] == 300 + + def test_accumulates_output_tokens(self): + state = _state_with_stage("implement", output_tokens=50) + result = record_tokens(state, "implement", input_tokens=0, output_tokens=75) + + assert result["stage_timestamps"]["implement"]["output_tokens"] == 125 + + def test_accumulates_both_simultaneously(self): + state = _state_with_stage("implement", input_tokens=10, output_tokens=5) + result = record_tokens(state, "implement", input_tokens=20, output_tokens=10) + + stage = result["stage_timestamps"]["implement"] + assert stage["input_tokens"] == 30 + 
assert stage["output_tokens"] == 15 + + def test_handles_non_existent_stage(self): + """Should initialise a new stage entry if it does not exist.""" + result = record_tokens(_empty_state(), "new_stage", input_tokens=50, output_tokens=25) + + stage = result["stage_timestamps"]["new_stage"] + assert stage["input_tokens"] == 50 + assert stage["output_tokens"] == 25 + + def test_does_not_replace_tokens(self): + """Calling twice should add, not replace.""" + state = _state_with_stage("implement") + first = record_tokens(state, "implement", input_tokens=100, output_tokens=50) + second = record_tokens(first, "implement", input_tokens=100, output_tokens=50) + + assert second["stage_timestamps"]["implement"]["input_tokens"] == 200 + assert second["stage_timestamps"]["implement"]["output_tokens"] == 100 + + def test_returns_stage_timestamps_and_token_usage_keys(self): + result = record_tokens(_empty_state(), "impl", input_tokens=1, output_tokens=1) + assert "stage_timestamps" in result + assert "stage_token_usage" in result + assert "token_usage" in result + + +# --------------------------------------------------------------------------- +# increment_revision +# --------------------------------------------------------------------------- + + +class TestIncrementRevision: + def test_increments_iteration_count_by_one(self): + state = _state_with_stage("implement", iteration_count=2) + result = increment_revision(state, "implement") + + assert result["stage_timestamps"]["implement"]["iteration_count"] == 3 + + def test_starts_at_one_for_new_stage(self): + result = increment_revision(_empty_state(), "plan") + + assert result["stage_timestamps"]["plan"]["iteration_count"] == 1 + + def test_multiple_increments_accumulate(self): + state = _empty_state() + for _ in range(5): + state = {**state, **increment_revision(state, "implement")} + + assert state["stage_timestamps"]["implement"]["iteration_count"] == 5 + + def test_returns_stage_timestamps_and_revision_counts_keys(self): + 
result = increment_revision(_empty_state(), "triage") + assert "stage_timestamps" in result + assert "revision_counts" in result + + +# --------------------------------------------------------------------------- +# increment_ci_cycle +# --------------------------------------------------------------------------- + + +class TestIncrementCiCycle: + def test_increments_counter_from_zero(self): + result = increment_ci_cycle(_empty_state()) + assert result["stats_ci_cycles"] == 1 + + def test_increments_existing_counter(self): + state = {"stats_ci_cycles": 3} + result = increment_ci_cycle(state) + assert result["stats_ci_cycles"] == 4 + + def test_handles_none_counter(self): + state = {"stats_ci_cycles": None} + result = increment_ci_cycle(state) + assert result["stats_ci_cycles"] == 1 + + def test_multiple_increments(self): + state = _empty_state() + for _ in range(7): + state = {**state, **increment_ci_cycle(state)} + + assert state["stats_ci_cycles"] == 7 + + def test_returns_only_stats_ci_cycles_key(self): + result = increment_ci_cycle(_empty_state()) + assert list(result.keys()) == ["stats_ci_cycles"] + + +# --------------------------------------------------------------------------- +# add_pr_url +# --------------------------------------------------------------------------- + + +class TestAddPrUrl: + def test_appends_url_to_empty_list(self): + result = add_pr_url(_empty_state(), "https://github.com/org/repo/pull/1") + assert result["stats_pr_urls"] == ["https://github.com/org/repo/pull/1"] + + def test_appends_to_existing_list(self): + state = {"stats_pr_urls": ["https://github.com/org/repo/pull/1"]} + result = add_pr_url(state, "https://github.com/org/repo/pull/2") + + assert result["stats_pr_urls"] == [ + "https://github.com/org/repo/pull/1", + "https://github.com/org/repo/pull/2", + ] + + def test_idempotent_no_duplicates(self): + url = "https://github.com/org/repo/pull/1" + state = {"stats_pr_urls": [url]} + result = add_pr_url(state, url) + + assert 
result["stats_pr_urls"] == [url] + assert len(result["stats_pr_urls"]) == 1 + + def test_calling_twice_does_not_duplicate(self): + url = "https://github.com/org/repo/pull/42" + state = _empty_state() + state = {**state, **add_pr_url(state, url)} + state = {**state, **add_pr_url(state, url)} + + assert state["stats_pr_urls"].count(url) == 1 + + def test_handles_none_pr_urls(self): + state = {"stats_pr_urls": None} + result = add_pr_url(state, "https://example.com/pr/1") + assert result["stats_pr_urls"] == ["https://example.com/pr/1"] + + def test_returns_only_stats_pr_urls_key(self): + result = add_pr_url(_empty_state(), "https://example.com/pr/1") + assert list(result.keys()) == ["stats_pr_urls"] + + def test_preserves_order(self): + urls = [f"https://example.com/pr/{i}" for i in range(5)] + state = _empty_state() + for url in urls: + state = {**state, **add_pr_url(state, url)} + + assert state["stats_pr_urls"] == urls + + +# --------------------------------------------------------------------------- +# set_outcome +# --------------------------------------------------------------------------- + + +class TestSetOutcome: + def test_sets_outcome(self): + result = set_outcome(_empty_state(), "Completed") + assert result["workflow_outcome"] == "Completed" + + def test_sets_reason_when_provided(self): + result = set_outcome(_empty_state(), "Blocked: awaiting review", "PR still open") + assert result["workflow_outcome"] == "Blocked: awaiting review" + assert result["stats_outcome_reason"] == "PR still open" + + def test_reason_defaults_to_none(self): + result = set_outcome(_empty_state(), "Completed") + assert result["stats_outcome_reason"] is None + + def test_overwrites_previous_outcome(self): + state = {"workflow_outcome": "Blocked", "stats_outcome_reason": "old reason"} + result = set_outcome(state, "Completed", None) + + assert result["workflow_outcome"] == "Completed" + assert result["stats_outcome_reason"] is None + + def test_returns_both_keys(self): + result = 
set_outcome(_empty_state(), "Failed: timeout") + assert set(result.keys()) == {"workflow_outcome", "stats_outcome_reason"} + + @pytest.mark.parametrize("outcome", ["Completed", "Blocked: foo", "Failed: bar"]) + def test_conventional_outcome_values(self, outcome: str): + result = set_outcome(_empty_state(), outcome) + assert result["workflow_outcome"] == outcome diff --git a/zensical.toml b/zensical.toml index 8fcb748f..f6128760 100644 --- a/zensical.toml +++ b/zensical.toml @@ -11,6 +11,7 @@ nav = [ {"Getting Started" = "getting-started.md"}, {"User Guide" = [ {"Feature Workflow" = "guide/feature-workflow.md"}, + {"Weekly Reporting" = "guide/weekly-reporting.md"}, {"Bug Workflow" = "guide/bug-workflow.md"}, {"Jira Labels" = "guide/labels.md"}, {"PR Commands" = "guide/pr-commands.md"}, @@ -28,6 +29,7 @@ nav = [ ]}, {"Reference" = [ {"API Endpoints" = "reference/api.md"}, + {"CLI Reference" = "reference/cli.md"}, {"Configuration" = "reference/config.md"}, {"Proposals" = "reference/proposals.md"}, ]},