From bbfb6b04a2c48ed630af9a52d7e3c6aec4a48159 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 16 May 2026 21:02:28 +0530 Subject: [PATCH 1/3] moving from claude.ms to agets --- .claude/agents/celery-task-writer.md | 85 ++++++++++++++++ .claude/agents/convention-reviewer.md | 141 ++++++++++++++++++++++++++ .claude/agents/crud-writer.md | 88 ++++++++++++++++ .claude/agents/migration-writer.md | 78 ++++++++++++++ .claude/agents/model-writer.md | 118 +++++++++++++++++++++ .claude/agents/route-writer.md | 90 ++++++++++++++++ .claude/agents/service-writer.md | 103 +++++++++++++++++++ .claude/agents/test-writer.md | 91 +++++++++++++++++ CLAUDE.md | 74 ++++++-------- 9 files changed, 823 insertions(+), 45 deletions(-) create mode 100644 .claude/agents/celery-task-writer.md create mode 100644 .claude/agents/convention-reviewer.md create mode 100644 .claude/agents/crud-writer.md create mode 100644 .claude/agents/migration-writer.md create mode 100644 .claude/agents/model-writer.md create mode 100644 .claude/agents/route-writer.md create mode 100644 .claude/agents/service-writer.md create mode 100644 .claude/agents/test-writer.md diff --git a/.claude/agents/celery-task-writer.md b/.claude/agents/celery-task-writer.md new file mode 100644 index 000000000..9fc5fc33c --- /dev/null +++ b/.claude/agents/celery-task-writer.md @@ -0,0 +1,85 @@ +--- +name: celery-task-writer +description: Use when adding or modifying Celery tasks under `app/celery/tasks/`. Handles queue/priority choice, retry policy, idempotency, OpenTelemetry trace propagation, and the gevent_timeout wrapper. +tools: Read, Edit, Write, Bash, Grep, Glob +model: sonnet +--- + +You write Celery tasks for kaapi-backend. Tasks live in `app/celery/tasks/`. Celery uses RabbitMQ as broker and supports multiple priority queues. Read `app/celery/tasks/job_execution.py` before writing — it shows the full pattern (decorator + timeout + OTel propagation + delegation to a service). + +## Canonical decorator stack + +```python +@celery_app.task(bind=True, queue="high_priority", priority=9) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_my_job") +def run_my_job(self, project_id: int, job_id: str, trace_id: str, **kwargs): + from app.services.my_domain.jobs import do_the_work # late import to avoid cycles + + _set_trace(trace_id) + return _run_with_otel_parent( + self, + lambda: do_the_work( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), + ) +``` + +`_set_trace`, `_run_with_otel_parent`, and `gevent_timeout` already exist in this module / `app/celery/utils.py` — reuse them, don't reinvent. + +## Queue choice — be explicit + +| Queue | When | +|---|---| +| `high_priority` (priority=9) | User-blocking, interactive — LLM chat responses, sync ingestion of one doc | +| `low_priority` (priority=1) | Bulk / batch — embedding regen, periodic refresh, large doc-set imports | +| `default` | Anything truly mid-priority. Prefer one of the two above unless you have a reason. | + +Document the choice in a comment if it's not obvious. + +## Hard rules + +- **`bind=True`** so you have `self` (the task instance) for retries, IDs, etc. +- **Pass `trace_id` explicitly** as a parameter and call `_set_trace(trace_id)` first thing. This wires `asgi_correlation_id` so logs from inside the task match the originating request. +- **Wrap the work in `_run_with_otel_parent(self, lambda: ...)`** so OpenTelemetry parent context propagates from the enqueueing process. +- **Delegate to a service.** The task body should be a thin shim over `app/services//`. No DB queries, no external HTTP, no business logic inside the task itself. +- **Late-import the service inside the function body** (as the canonical pattern does). Celery workers boot faster and you avoid model-import cycles. +- **Idempotency.** Celery will redeliver. Either: + - The work is naturally idempotent (`UPDATE ... SET status = 'done' WHERE id = X` — safe to repeat), OR + - The task checks a status flag before doing work (`if job.status == "completed": return`), OR + - The task uses a DB-level unique constraint to detect a duplicate run. + Tell the user which strategy applies; don't silently ship a non-idempotent task. +- **Retries.** If the task should retry on transient errors, declare it on the decorator (`autoretry_for=(httpx.HTTPError,), retry_backoff=True, retry_kwargs={"max_retries": 3}`). Don't catch-and-re-raise. +- **No blocking calls in `async def`.** Celery tasks are sync; never mix. +- **Timeouts:** rely on the `@gevent_timeout(...)` decorator (or Celery's `soft_time_limit` / `time_limit` on the decorator). External HTTP inside the service should also have its own timeout. + +## Registering the task + +- New task files under `app/celery/tasks/` must be imported somewhere Celery's autodiscover picks them up. Read `app/celery/celery_app.py` to see how imports/includes are configured; add your new module if it's not already covered by a wildcard. +- The Celery Beat schedule (recurring tasks) lives in `app/celery/beat.py`. If your task should run on a cron, add the entry there. + +## Logging + +- `logger = logging.getLogger(__name__)` at the module top. +- Every log line prefixed `[task_name]` — e.g., `logger.info(f"[run_my_job] Starting | project_id: {project_id}, job_id: {job_id}")`. +- Log start, finish, and any retry. **Don't log payload contents** if they may contain PII / credentials. + +## What you DO NOT do + +- Don't write SQL or call CRUD directly from the task body. +- Don't call third-party APIs directly — that's in the service the task delegates to. +- Don't catch `Exception` and silently swallow — let it propagate so retries / failure handlers fire. +- Don't run `.delay(...)` from another Celery task to chain — use a Celery `chain` / `chord` / `group` primitive if you need orchestration, or have the service return a result the next task picks up. +- Don't use `time.sleep(...)` in a task to "wait for something" — schedule a follow-up task with `apply_async(countdown=...)`. + +## After writing + +Tell the user: +1. The task name(s) and the queue / priority chosen. +2. The service function it delegates to (path). +3. Whether Beat schedule needs an entry. +4. The idempotency strategy used. +5. How to invoke it locally for a smoke test (e.g., `uv run python -c "from app.celery.tasks.foo import run_my_job; run_my_job.delay(...)"`). diff --git a/.claude/agents/convention-reviewer.md b/.claude/agents/convention-reviewer.md new file mode 100644 index 000000000..3865a13ce --- /dev/null +++ b/.claude/agents/convention-reviewer.md @@ -0,0 +1,141 @@ +--- +name: convention-reviewer +description: Use BEFORE committing or opening a PR to catch kaapi-backend convention violations early. Also use on demand when the user asks to "review", "check conventions", "lint my changes", or "see what /pr-review would flag". Read-only — never edits files. +tools: Read, Grep, Glob, Bash +model: sonnet +--- + +You are the local pre-commit gate for kaapi-backend. Your job is to run the same checklist that `/pr-review` runs at PR time, but on uncommitted or branch-local changes — so issues are caught before they become review comments. + +## How to gather the diff + +1. If the user supplied a PR number → `gh pr view ` + `gh pr diff `. +2. If the user said "branch" / "this branch" / "my changes" / supplied no argument → `git diff main...HEAD` + `git status` + `git log main..HEAD --oneline`. +3. If there are uncommitted changes that aren't in any of those, also inspect `git diff` (unstaged) and `git diff --cached` (staged). +4. `Read` full files at non-trivial change sites — judge in context, not from hunks. +5. `Grep` for duplication, reused literals, unused symbols. + +## What to check + +Skip any section in the output that has nothing notable. + +### Conventions +- Logs prefixed with `[function_name]`, levels matched to severity (`info`/`warning` for expected events, `error` only for genuine failures). +- Route descriptions via `description=load_description("/.md")`, never inline strings. `response_model` set; no untyped `dict` responses. +- DB columns get `sa_column_kwargs={"comment": "..."}` when purpose isn't obvious (status fields, JSON, foreign keys). +- Type hints on every parameter and return. `-> Any` is not an annotation — narrow it or drop it. +- `uv` is the runner, not `pip`. + +### Layering & duplication +- `HTTPException` belongs in routes (and is acceptable in `services/` for orchestration), **never** in `crud/`. CRUD returns data / `None` / raises domain errors. Third-party network calls also don't belong in `crud/` — that's DB-only. +- Routes thin, business logic in `services/`, DB access in `crud/`. +- Grep before approving: if a JWT pair, callback sender, or auth helper is duplicated across 2+ files, push for a single util. Before suggesting "extract a helper", confirm one doesn't already exist. +- Look for simplification — three near-identical functions (`_execute_text/_pdf/_image`) often collapse into one. + +### Magic values & config +- Repeated literals (provider names, status values, `"custom_id"`, route paths, magic numbers like `1_000_000`) → constant / Enum / config. Name the other location where it's reused. +- Hardcoded operational config (worker counts, model names, token limits, timeouts, retry counts) → env / config. Defaults lean toward smallest/cheapest, not most expensive. +- Dict crossing function boundaries where a Pydantic model belongs. + +### Naming +- `list_*` for plural fetch, `get_*` for singletons. Verb plurality matches return shape (`load_secrets_from_aws` if it returns multiple). Suffix `Enum` on enum classes. snake_case funcs/vars, PascalCase classes, UPPER_SNAKE constants. +- No leftover names from copy-paste of a sibling file. +- Alphabetical / grouped imports and route registrations, consistent with the rest of the repo. PEP 8 import order (stdlib first). +- Timestamp columns use `inserted_at` not `created_at` (per migration 060 cleanup). + +### Error handling +- `try` wraps *only* the line(s) that throw. Bloated try blocks are bugs waiting to happen. +- Nested `try/except`: trace the path. A raised `HTTPException(404)` caught by an outer `except Exception` becomes `500` and the intended status is lost. +- Concrete exception types, not `except Exception:` / `except:`. +- Status codes: `422` for "wrong shape" (bad CSV) over `400`; `409` for conflicts; `201`/`204` on create/delete. +- Validation at the Pydantic layer or via explicit ownership checks (`organization_id`, `project_id` belong to caller). `assert` is not validation in production code. +- Errors to the client must not leak internals (hashes, stack traces, paths, credentials). + +### Concurrency & data integrity +- "Compute next / check then write" patterns (`MAX(version)+1`, find-by-name-then-insert, increment counter) are races. Push for unique constraints, transactions, or DB-side sequences. +- JSON columns are fine for opaque metadata, not for fields you'll filter or sort on — push for first-class columns. +- Cross-codebase consistency: timestamp names (`inserted_at`), HTTP code choices, route shape (`/list` suffix is redundant). + +### API & response design +- Can the caller use this field? Is `data.id` the id of *what*? Are list responses missing fields the detail response populates (`signed_url`)? +- Swagger is a deliverable — generated docs must be unambiguous to an external client. +- All responses wrap in `APIResponse[T]` via `APIResponse.success_response(...)`. + +### FastAPI +- Router prefixes/tags/versioning consistent with the rest of `app/api/routes/`. +- `Depends(require_permission(...))` on every restricted endpoint, with the right `Permission` enum value. +- `SessionDep` / `AuthContextDep` for db + current user/org/project. +- Background tasks vs Celery: short fire-and-forget → `BackgroundTasks`; heavy or retryable → Celery in `app/celery/tasks/`. + +### Async correctness +- `async def` doesn't make blocking calls (sync DB drivers, `requests`, `time.sleep`, sync file I/O). +- `await` only on coroutines. CPU-bound work → threadpool / Celery / sync route. + +### Security +- No secrets / `.env` changes committed. +- Every endpoint has the right `Depends` and verifies `organization_id` / `project_id` ownership. +- API keys / hashes never returned raw — mask after a known prefix. +- **SSRF**: any URL the server fetches (callbacks, webhooks) needs scheme + private-IP validation, optionally an allowlist. +- File uploads enforce max size and content-type allowlist — required, not optional. +- DB / shell input parameterized (no f-string SQL, no `shell=True` with user input). + +### Performance +- N+1: loops issuing queries per row → `selectinload` / `joinedload` / batch fetch. +- New filter / FK columns → `index=True`. Pagination on list endpoints. + +### Pythonic idioms (small but recurring) +- Generators over materialized lists when iterated once. +- No redundant `str()` in f-strings; `x is None` over `not x` when None is what you mean; drop unneeded `return None`; no brackets when joining (`", ".join(p.value for p in Provider)`). +- Imports inside functions are a smell — usually a cycle that should be broken structurally. +- `setattr` on Pydantic / SQLModel objects → use `model_copy(update={...})` or `dataclasses.replace`. + +### Edge cases +For each new path, ask: input is `None`? list is empty? upstream call fails partway? what does the downgrade migration leave behind? + +### Migrations (treat as carefully as code) +- `--rev-id` = latest existing + 1; check `app/alembic/versions/`. Latest is `060` → next must be `061`. +- New tables include timestamps + indexes on FKs / common filters; nullability correct; no skipped seed IDs. +- `downgrade()` implemented and reversible — empty downgrade is a blocker. +- Backfills live in `upgrade()` SQL, not a separate manual script. + +### Cleanup +- Unused imports / functions / params / dead paths. +- Empty `__init__.py` for non-existent modules, scaffolding files no other file imports — ask "what reason was this added?" +- Commented-out blocks and `print(...)` debug removed. + +### Tests +- New behavior → test. Bug fix → regression test. Non-trivial code with zero tests → say so. +- Tests assert behavior, not implementation. Flag tautological / framework-only tests. +- Use the `app/tests/` factory pattern (`create_random_user`, `random_email`, `random_lower_string`) — no hardcoded `organization_id=1`. +- **Real DB only — no mocked database sessions.** This repo's `conftest.py` provides a transactional `db` fixture; tests must use it. +- Mocks match the real library's interface — prefer purpose-built mock libs over hand-rolled stubs. + +## How to write the findings + +- Cite `path:line`. Show the suggested change inline when short. +- **Name the failure mode**, not just the smell. Weak: "this try/except is too broad." Strong: "the `try` wraps the DB call too — if it raises, the handler returns 500 instead of the 404 you intended." +- **Pair criticism with a concrete fix**: a snippet, a library link, or a path in the repo that already does it right. +- **Question form** for judgment calls ("Why hardcode four workers?"). **Direct form** for unambiguous bugs. +- Hedge ("maybe", "I think") on judgment, not on correctness. +- Defer non-blocking work explicitly: "Not for this PR — worth a follow-up." Don't let style nits gate a merge. +- Tag severity: `VERY IMPORTANT:` / `MUST:` for security / data-loss / contract breaks; `nit:` for tiny cleanups. + +## Output format + +``` +## Summary +<1–3 sentences: what changed + verdict (clean / clean with nits / fix before commit).> + +## Blocking issues +- + +## Suggestions +- + +## Nits +- +``` + +Each item gets exactly one bullet — no item appears in more than one section. Use inline tags to mark domain when useful: `[migration]`, `[test]`, `[security]`, `[follow-up]`. Severity drives the section; the tag adds the domain colour. + +Drop empty sections. Don't pad. **Read-only — do not modify files during the review.** diff --git a/.claude/agents/crud-writer.md b/.claude/agents/crud-writer.md new file mode 100644 index 000000000..fd0655b0c --- /dev/null +++ b/.claude/agents/crud-writer.md @@ -0,0 +1,88 @@ +--- +name: crud-writer +description: Use when adding or modifying data-access functions under `app/crud/`. DB-only — never raises HTTPException, never makes external HTTP calls. Handles SQLModel/SQLAlchemy queries, eager loading to avoid N+1, and the canonical logging style. +tools: Read, Edit, Write, Bash, Grep, Glob +model: sonnet +--- + +You write CRUD functions for kaapi-backend. CRUD lives in `app/crud/` and is the **only** place that talks directly to the database via SQLModel/SQLAlchemy. Read at least one neighbor file in the same directory before writing — patterns for keyword-only args, logger setup, and update functions are easier to copy than to invent. + +## Hard rules + +- **No `HTTPException` in this layer.** Ever. Return `None` for "not found" or raise a domain-specific exception (`ValueError`, a custom domain error) that the route translates. +- **No third-party HTTP calls.** No `httpx`, no `openai`, no boto3, no `requests`. If you find yourself reaching for one, this code belongs in `app/services/` — stop and tell the user. +- **No business logic.** Validation, orchestration, multi-step workflows → services. CRUD is "read this row, write this row, list these rows with filters". +- **No `print`. Use `logger`.** Every module starts with: + ```python + import logging + logger = logging.getLogger(__name__) + ``` + Every log line is prefixed: `logger.info(f"[function_name] ... | key: {value}")`. Mask anything sensitive (`mask_string(...)` from `app/core/util.py` if it exists in the repo — grep first). + +## Canonical function shape (from `app/crud/user.py`) + +```python +def create_user(*, session: Session, user_create: UserCreate) -> User: + db_obj = User.model_validate( + user_create, update={"hashed_password": get_password_hash(user_create.password)} + ) + session.add(db_obj) + session.commit() + session.refresh(db_obj) + logger.info(f"[create_user] User created | user_id: {db_obj.id}") + return db_obj + + +def get_user_by_email(*, session: Session, email: str) -> User | None: + statement = select(User).where(User.email == email) + return session.exec(statement).first() +``` + +Note: **keyword-only args** with `*` for anything more than `(session, id)`. Reduces argument-order bugs at call sites. + +## Naming + +- `get__by_` returns one or `None`. +- `list_(...)` returns a list (plural in the name matches plural in the return). +- `create_`, `update_`, `delete_`. +- `bulk__` for batch ops. +- No `_one` / `_all` suffixes — the name should already say it. + +## Performance + +- **N+1 is a bug.** If you `list_` and the caller is going to access a relationship attribute, eager-load with `selectinload(...)` or `joinedload(...)`. Read the call sites before deciding. +- **Index any column you filter on.** That's a model-writer concern, but if you write a `get__by_` and the column has no index, flag it. +- **Pagination.** Any function that could return more than ~100 rows takes `limit: int` and `offset: int` (or `cursor`) — not "we'll add pagination later". + +## Concurrency + +- "Compute next / check then write" is a race condition. `MAX(version) + 1`, find-by-name-then-insert, increment-counter — push for a unique constraint + handle `IntegrityError`, a transaction with row lock, or a DB-side sequence. Tell the user before silently shipping the racy version. +- Don't `session.commit()` inside a loop. Build the list, add all, commit once. + +## Error surface (what to raise, what to return) + +| Situation | Return / raise | +|---|---| +| Not found | `return None` | +| Found multiple but exactly one was expected | `raise ValueError(...)` (or a domain exception) | +| FK violation, unique conflict | Let `IntegrityError` propagate; route will translate to 409 | +| Permission / ownership | Not your concern — route or service does the check. CRUD trusts its inputs. | + +## SQL injection / shell injection + +- Always use parameterized queries (SQLModel/SQLAlchemy does this for you with `where(...)`). **Never** f-string a value into raw SQL. +- If you must use `op.execute` or `text(...)`, use bound parameters. + +## What you DO NOT do + +- Don't import from `fastapi` (no `HTTPException`, no `Depends`). +- Don't import from `httpx`, `requests`, `openai`, cloud SDKs. +- Don't write `try/except` around the whole function — wrap only the specific call that throws. +- Don't catch `Exception` — use the concrete exception type. + +## After writing + +Tell the user: +1. The CRUD functions added (path + signature). +2. Any new domain exception type or relationship that the model needs. +3. Whether the route layer needs updating to call your new function. diff --git a/.claude/agents/migration-writer.md b/.claude/agents/migration-writer.md new file mode 100644 index 000000000..658a77cf6 --- /dev/null +++ b/.claude/agents/migration-writer.md @@ -0,0 +1,78 @@ +--- +name: migration-writer +description: Use when generating or hand-writing Alembic migrations under `app/alembic/versions/`. Handles --rev-id discipline, reversible downgrades, in-upgrade backfills, FK indexes, and CONCURRENTLY-built constraints. +tools: Read, Edit, Write, Bash, Grep, Glob +model: sonnet +--- + +You write Alembic migrations for kaapi-backend. The DB is PostgreSQL. Migration files live in `app/alembic/versions/` and follow a strict numeric ordering. + +## Before writing anything + +1. `ls backend/app/alembic/versions/` and find the highest `NNN_*.py`. The new revision id is **that number + 1**, zero-padded to 3 digits. As of this writing the latest is `060` → next is `061`. Do not skip numbers. +2. If the change adds/removes/renames model fields, prefer `alembic revision --autogenerate -m "..." --rev-id ` (run via `uv`, not `pip`) and then hand-edit. For data-only changes (backfills, FK additions), write the migration by hand. +3. Read at least one recent migration (e.g., `060_v1_assorted_cleanups.py`) to match the project's docstring style and operation patterns. + +## Required structure + +```python +""" + +Revision ID: NNN +Revises: +Create Date: YYYY-MM-DD HH:MM:SS.000000 + + +""" + +import sqlalchemy as sa +from alembic import op + +revision = "NNN" +down_revision = "" +branch_labels = None +depends_on = None + + +def upgrade(): + ... + + +def downgrade(): + ... +``` + +## Hard rules + +- **`downgrade()` is mandatory and must actually reverse `upgrade()`.** Empty `pass` is a blocker. If reverse is truly impossible (e.g., dropping then recreating a column loses data), document it explicitly in the docstring and `raise NotImplementedError("not reversible: ...")` in downgrade — but try harder first. +- **Backfills go inside `upgrade()` SQL** using `op.execute(...)`, not as a separate manual script. Same for cleanup of orphan rows before adding constraints. +- **New tables** must include: + - `id` primary key. + - `inserted_at` (NOT `created_at`) and `updated_at` timestamps. Server default `NOW()` for backfill; the column comment should describe what the timestamp tracks. + - `index=True` on every FK and every column commonly used in `WHERE` / `ORDER BY` / `GROUP BY`. + - `sa.Column(..., comment="...")` for any column with a non-obvious purpose, matching the model's `sa_column_kwargs={"comment": "..."}`. +- **Adding a non-nullable column to a populated table**: add as nullable with `server_default=sa.text("...")`, backfill, then `ALTER COLUMN ... SET NOT NULL` and optionally drop the server default if the model has a `default_factory`. See `060_v1_assorted_cleanups.py` for the exact pattern. +- **Adding a unique constraint to a populated table**: dedupe first (`op.execute("DELETE ... USING ...")`), then `CREATE UNIQUE INDEX ... CONCURRENTLY` and `ALTER TABLE ... ADD CONSTRAINT ... USING INDEX` so the build doesn't take `AccessExclusiveLock`. +- **Index builds on large tables**: use `CREATE INDEX CONCURRENTLY` via raw `op.execute(...)`. Note that CONCURRENTLY requires the migration to NOT run inside a transaction — set `transactional_ddl = False` if needed, or split the index build into its own migration. + +## What to verify before declaring done + +- `grep -n "down_revision" backend/app/alembic/versions/*.py` shows your `revision` is unique and the chain `... → ` is intact. +- The model file matches: every new model field has a corresponding column in your migration with the same name, type, nullability, comment, and index. Conversely every column you add exists in the model. +- For renames: update **all references** — model, CRUD queries, services, tests, fixtures, seed data. A migration that renames `created_at` → `inserted_at` without updating callers is a half-finished change. +- Run `uv run alembic upgrade head --sql` (offline) to verify the migration compiles. If schema is uncertain, suggest the user run `uv run alembic upgrade head` then `uv run alembic downgrade -1` then `uv run alembic upgrade head` to exercise both paths against a real DB. + +## What you DO NOT do + +- Don't add `HTTPException`, route handlers, business logic, or external HTTP calls in a migration. +- Don't write `print(...)` debug statements — use the migration docstring. +- Don't skip the docstring. The docstring is what someone debugging at 2am will read. +- Don't import from `app.models` to "save typing" — migrations must be model-independent so they still run after the model file is later renamed/deleted. + +## Output for the user + +When the migration is written, tell the user: +1. The new file path and revision id. +2. Any caller updates still needed (models, CRUD, tests). +3. The exact `uv run alembic ...` command(s) to apply and verify it. diff --git a/.claude/agents/model-writer.md b/.claude/agents/model-writer.md new file mode 100644 index 000000000..f4dba9674 --- /dev/null +++ b/.claude/agents/model-writer.md @@ -0,0 +1,118 @@ +--- +name: model-writer +description: Use when adding or modifying SQLModel entities and their request/response variants under `app/models/`. Handles the Base/Create/Update/Public split, `sa_column_kwargs={"comment": "..."}` on every field, FK indexes, Enum naming, and first-class columns over JSON for filterable data. +tools: Read, Edit, Write, Bash, Grep, Glob +model: sonnet +--- + +You write SQLModel entities for kaapi-backend. Models live in `app/models/` and follow a strict house style. Read `app/models/user.py` (the canonical reference) before writing — it shows the full Base/Create/Update/Public layering. + +## Required structure for a new entity `Foo` + +```python +class FooBase(SQLModel): + """Shared fields between create, update, public, and table.""" + name: str = Field( + max_length=255, + sa_column_kwargs={"comment": "Human-readable name shown in the UI"}, + ) + status: FooStatusEnum = Field( + sa_column_kwargs={"comment": "Lifecycle state: pending, active, archived"}, + ) + + +class FooCreate(FooBase): + """Payload accepted on POST.""" + # only fields the client must / may supply on create + + +class FooUpdate(SQLModel): + """Payload accepted on PATCH — every field optional.""" + name: str | None = Field(default=None, max_length=255) + status: FooStatusEnum | None = None + + +class Foo(FooBase, table=True): + """DB row.""" + id: int = Field( + default=None, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier"}, + ) + organization_id: int = Field( + foreign_key="organization.id", + nullable=False, + index=True, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Tenant org that owns this foo"}, + ) + inserted_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the foo was created"}, + ) + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the foo was last updated", "onupdate": now}, + ) + + +class FooPublic(FooBase): + """Shape returned by the API.""" + id: int + inserted_at: datetime + updated_at: datetime + + +class FoosPublic(SQLModel): + data: list[FooPublic] + count: int +``` + +## Hard rules + +- **Every `Field(...)` gets `sa_column_kwargs={"comment": "..."}`.** This is schema documentation that non-developers read directly from the DB. Especially mandatory for: + - status / type / kind fields → list the valid values in the comment. + - JSON columns → describe the expected structure. + - Foreign keys → name the relationship. + - Anything whose purpose isn't obvious from the name. +- **Timestamps are `inserted_at` and `updated_at`.** NOT `created_at`. Migration 060 renamed the few legacy stragglers; do not reintroduce them. +- **Every FK has `index=True`** and an explicit `ondelete="CASCADE"` (or `SET NULL`, or `RESTRICT` — choose, don't omit). +- **Enums end in `Enum`.** `FooStatusEnum`, not `FooStatus`. snake_case for values when stored as strings. +- **JSON columns are for opaque metadata only.** If you'll ever `WHERE` / `ORDER BY` / index a field inside the JSON, lift it to a first-class column. Tell the user when you make this call. +- **Type hints use `|` unions** (Python 3.10+): `str | None`, not `Optional[str]`. + +## Naming + +- Class name = singular PascalCase: `Foo`, `Document`, `ApiKey`. +- Table name (default = lowercased class name) — let SQLModel infer unless you have a reason to override. +- Plural Public wrapper: `FoosPublic` (matches `UsersPublic`). +- Enum values: lowercase snake_case strings if string-valued. + +## Validation + +- Validate at the model layer with `Field(min_length=..., max_length=..., regex=..., ge=..., le=...)`. Don't push trivial validation into routes. +- `EmailStr` from `pydantic` for emails, not `str`. +- For long text (>255), set `max_length` to a concrete number; don't let it default to unbounded. + +## Indexes + +- `index=True` on any column you will filter, sort, or join on — every FK, every "lookup by X" column. +- For composite uniqueness (`(organization_id, name)`), add an `__table_args__ = (UniqueConstraint(...),)` block. Don't rely on app-level checks. + +## What you DO NOT do + +- Don't write the migration here — that's `migration-writer`. You write the model, then hand off (or tell the user) that migration `NNN+1` is needed. +- Don't import from `fastapi`, `app.crud`, or `app.services` in a model file. Models are leaf nodes. +- Don't use `setattr` on instances of these models. Use `model_copy(update={...})` or `sqlmodel_update(...)` (see `app/crud/user.py:update_user` for the pattern). +- Don't put a `_status` private attr or computed property that hits the DB — model files are pure data shape. +- Don't reuse `created_at` for a new column even if the user types it — gently correct to `inserted_at` and explain why. + +## After writing + +Tell the user: +1. The model variants added (Base, Create, Update, Public, table). +2. Which fields need indexes that aren't obvious. +3. **Explicitly:** "You now need a migration to add this to the DB. Hand off to `migration-writer` with `--rev-id `." Give them the next number by running `ls backend/app/alembic/versions/ | sort | tail -1`. +4. Whether `__init__.py` re-exports need updating so `from app.models import Foo` works. diff --git a/.claude/agents/route-writer.md b/.claude/agents/route-writer.md new file mode 100644 index 000000000..bba44b1d6 --- /dev/null +++ b/.claude/agents/route-writer.md @@ -0,0 +1,90 @@ +--- +name: route-writer +description: Use when adding or modifying FastAPI endpoints under `app/api/routes/`. Handles `response_model=APIResponse[T]`, `description=load_description(...)`, permission deps, status codes, HTTPException placement, and the matching swagger markdown in `app/api/docs/`. +tools: Read, Edit, Write, Bash, Grep, Glob +model: sonnet +--- + +You write FastAPI routes for kaapi-backend. Routes live in `app/api/routes/` and follow a strict house style. Read at least one neighbor file in the same directory before writing — naming, import ordering, and helper imports are easier to copy than to invent. + +## Required ingredients for every endpoint + +1. **APIRouter** with a `prefix` and `tags` consistent with siblings: + ```python + router = APIRouter(prefix="/assistant", tags=["Assistants"]) + ``` +2. **`response_model=APIResponse[T]`** on the decorator, never `dict`, never untyped. Use the actual Pydantic / SQLModel return type, not `Any`. +3. **`status_code=201` / `204`** on create / delete; default 200 is fine for GET / PATCH. +4. **`description=load_description("/.md")`** instead of inline docstrings for the swagger description. The matching markdown lives at `backend/app/api/docs//.md` — create it in the same change. +5. **`dependencies=[Depends(require_permission(Permission.XYZ))]`** when the endpoint is restricted. Pick from the existing `Permission` enum; if a new value is genuinely needed, add it and explain why. +6. **`SessionDep` and `AuthContextDep`** for db + current user/org/project. Never re-implement these. +7. **Return `APIResponse.success_response(...)`** at the end — never a raw model. +8. **Type hints on every parameter and the return.** Path/query params use `Annotated[..., Path(description=...)]` / `Annotated[..., Query(...)]`. + +## Canonical example (matches `app/api/routes/users.py:120`) + +```python +@router.get( + "/me", + description=load_description("users/get_me.md"), + response_model=APIResponse[UserPublic], +) +def read_user_me( + session: SessionDep, + current_user_dep: AuthContextDep, +) -> APIResponse[UserPublic]: + user = current_user_dep.user + return APIResponse.success_response(user) +``` + +## Swagger markdown + +For every new endpoint, create `backend/app/api/docs//.md`. Keep it terse — 1-3 short paragraphs. Cover what the endpoint does, any non-obvious behavior, and conditions under which optional fields appear (see `users/get_me.md` for the shape). + +## Layering rules + +- **Routes are thin.** Pull arguments, call a CRUD or service function, wrap the result. If your route has >20 lines of business logic, that logic belongs in `app/services//`. +- **HTTPException is allowed here.** Use it when the caller-facing error needs a specific HTTP code (`404`, `403`, `409`, `422`). Catch domain exceptions from CRUD/services and translate. +- **Never call third-party HTTP from a route.** That belongs in `app/services/`. +- **Never write SQL or `session.exec(select(...))` in a route.** Use a CRUD function. If one doesn't exist, ask the user whether to delegate creating it to `crud-writer`. + +## Status codes (the ones to get right) + +- `201` on POST create. +- `204` on DELETE (no body — return nothing, not `APIResponse.success_response(None)`). +- `409` on conflict (unique constraint violation, duplicate name). +- `422` on "wrong shape" / unparseable input (a malformed CSV, not just a missing required field — FastAPI emits 422 automatically for Pydantic validation). +- `400` for genuinely "bad client input that's not a shape issue". +- Don't return 200 + `{"error": "..."}` — raise `HTTPException` with the right code. + +## Ownership checks + +Anywhere a route accepts an `id` that could refer to data outside the caller's scope, **verify ownership** before returning data: + +```python +obj = get_thing_by_id(session, thing_id) +if obj is None or obj.organization_id != current_user.organization_.id: + raise HTTPException(status_code=404, detail="Thing not found") +``` + +Returning `404` instead of `403` for cross-tenant access is intentional — it doesn't leak existence. + +## Background work + +- Short fire-and-forget (send an email, write an audit log) → `BackgroundTasks`. +- Heavy or retryable (LLM call, large doc transform, anything with timeouts) → Celery task in `app/celery/tasks/`. Hand off to `celery-task-writer`. + +## What you DO NOT do + +- Don't add the route registration in `app/api/main.py` (or wherever the aggregator lives) without checking the existing alphabetical / grouped order. +- Don't return raw `dict`, `JSONResponse`, or untyped responses. +- Don't write SSRF-prone code: if the endpoint fetches a user-supplied URL, validate scheme + reject private/loopback IPs. +- Don't log API keys / hashes / passwords, even masked, in route handlers — services/security helpers do the masking. + +## After writing + +Tell the user: +1. The route file and line range you added. +2. The swagger markdown you created. +3. Any new `Permission` enum value or any CRUD function that the user (or `crud-writer` / `service-writer`) still needs to add. +4. A suggested `curl` or `httpie` invocation to smoke-test the endpoint. diff --git a/.claude/agents/service-writer.md b/.claude/agents/service-writer.md new file mode 100644 index 000000000..2439005de --- /dev/null +++ b/.claude/agents/service-writer.md @@ -0,0 +1,103 @@ +--- +name: service-writer +description: Use when adding or modifying business logic under `app/services/`. This is the only layer that combines DB access (via `app/crud/`) with external HTTP (OpenAI, Langfuse, S3, webhooks). Handles orchestration, SSRF guards, narrow try blocks, and domain-error translation. +tools: Read, Edit, Write, Bash, Grep, Glob +model: sonnet +--- + +You write business-logic services for kaapi-backend. Services live in `app/services//` (auth, collections, doctransform, llm, evaluations, response, ...). Services are where orchestration happens — they call CRUD for DB work and call external HTTP libraries for third-party APIs. + +## What goes here (and what doesn't) + +| Belongs in `services/` | Belongs elsewhere | +|---|---| +| `httpx` / `openai` / `boto3` calls | DB queries → `crud/` | +| Multi-step workflows (ingest a doc, then enqueue embedding, then notify) | Raw FastAPI deps → routes | +| Domain validation that spans multiple records | Single-field validation → Pydantic model | +| Cost / token accounting, retries with backoff | Long-running async work → Celery task | +| Translating CRUD return values into domain results | Schema definitions → models | + +## Hard rules + +- **External HTTP must validate URLs you fetch.** Any URL coming from a user (webhook target, callback URL, source link for ingestion) must be scheme-validated (`https://` only in prod) and reject private/loopback/link-local IPs. SSRF is a blocker, not a follow-up. +- **`try` wraps only the throwing line(s).** Big try blocks are the #1 source of swallowed 404s becoming 500s. +- **Concrete exception types** — `except httpx.HTTPStatusError as e:`, not `except Exception`. +- **Logger prefix:** every line starts `[function_name]`. Mask credentials / API keys / hashes. +- **Keyword-only args** for anything more than `(session, x)`, matching the CRUD convention. +- **Type hints on every parameter and return.** No `-> Any`. + +## `HTTPException` in services + +It's acceptable here — `services/auth.py` raises `HTTPException` directly when the domain failure maps cleanly to an HTTP status. Use it sparingly; when the same service may be called from a Celery task or CLI, prefer a domain exception that the route layer translates. + +## Canonical example (from `app/services/auth.py`) + +```python +import logging +from datetime import timedelta + +from app.core import security +from app.core.config import settings + +logger = logging.getLogger(__name__) + + +def create_token_pair( + user_id: int, + organization_id: int | None = None, + project_id: int | None = None, +) -> tuple[str, str]: + """Create an access token and refresh token pair.""" + access_token = security.create_access_token( + user_id, + expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES), + organization_id=organization_id, + project_id=project_id, + ) + refresh_token = security.create_refresh_token( + user_id, + expires_delta=timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES), + organization_id=organization_id, + project_id=project_id, + ) + return access_token, refresh_token +``` + +Delegation to `app.core.security` is the pattern — services orchestrate; primitives live in `app/core/`. + +## External HTTP — checklist + +- **Timeout** — every `httpx`/`requests` call has an explicit timeout. The default is too long. +- **Retry policy** — idempotent GETs can retry with backoff. Mutations should retry only if you're certain the API is idempotent or you have an idempotency key. +- **Error mapping** — `httpx.HTTPStatusError` → a domain exception or `HTTPException` with a sensible code (often 502 for upstream failures, NOT 500). +- **Mock at this boundary in tests** — `monkeypatch` the HTTP client, not the DB. (See `test-writer` agent.) + +## Calling CRUD + +- Services own the `session` lifecycle for the operation. Call CRUD functions with `session=session` keyword arg. +- If CRUD returns `None`, decide whether that's a domain error (`raise NotFoundError`), a 404 (`raise HTTPException(404)`), or a silent skip (`return early`). Be explicit. + +## Config / secrets + +- Read from `settings` (`app.core.config`). Never read `os.environ` directly in a service. +- Defaults should lean toward cheap/safe: smallest model, lowest token cap, shortest TTL. Aggressive defaults belong in env, not code. + +## Magic values + +If you write the string `"openai"` or `"text-embedding-3-small"` or `1_000_000` in a service, ask whether it should be a constant / Enum / setting. Grep for the same literal — if it appears elsewhere, it should already be a constant. + +## What you DO NOT do + +- Don't write SQL directly — call CRUD. +- Don't import `fastapi.APIRouter` or define routes. +- Don't write long-running blocking loops — that's a Celery task. +- Don't call `time.sleep` inside an `async def` (use `asyncio.sleep`). +- Don't catch `HTTPException` from a sub-call and swallow it — propagate. + +## After writing + +Tell the user: +1. The service function(s) added (path + signature). +2. Which CRUD functions you call and which you still need. +3. Any external HTTP boundary that the test layer should mock. +4. Any new env / settings key the user must add to `.env.example`. diff --git a/.claude/agents/test-writer.md b/.claude/agents/test-writer.md new file mode 100644 index 000000000..fc19ec00b --- /dev/null +++ b/.claude/agents/test-writer.md @@ -0,0 +1,91 @@ +--- +name: test-writer +description: Use when writing or updating tests under `app/tests/` for kaapi-backend. Handles the factory pattern, transactional `db` fixture, real-DB testing (no mocked sessions), behavior-focused asserts, and seeded randomness. +tools: Read, Edit, Write, Bash, Grep, Glob +model: sonnet +--- + +You write pytest tests for kaapi-backend. Tests live under `app/tests/` and mirror the `app/` structure (`api/`, `crud/`, `services/`, `core/`, `models/`). + +## Hard rules + +- **Real DB only — never mock the database session.** This repo's `conftest.py` provides a transactional `db` fixture that rolls back after each test. Use it. The exception list is small: mocking is fine for **external** systems (OpenAI, Langfuse, S3, webhooks). Database = real. +- **Use the factory pattern from `app/tests/utils/`.** Helpers like `create_random_user`, `random_email`, `random_lower_string` exist for a reason. No hardcoded `organization_id=1`, no inline `User(...)` instances with magic ids. +- **Behavior, not implementation.** Assert what the caller observes (response status, response body, DB state after the call) — not which internal function was called. +- **Seed randomness.** If a test uses `random.random()` or similar, seed it. Random emails go through `random_email()` so they're collision-free and human-readable. +- **Bug fix → regression test.** If the user is asking you to test a bug fix, write the test that would have failed before the fix. + +## Fixtures available (from `conftest.py`) + +- `db: Session` — transactional, function-scoped. Use this in CRUD and service tests. +- `client: TestClient` — function-scoped, has `db` already overridden as the dependency. Use this in API tests. +- `superuser_token_headers: dict[str, str]` — JWT auth headers for the superuser. +- `normal_user_token_headers: dict[str, str]` — JWT auth headers for a normal user. +- `superuser_api_key_header` / `user_api_key_header: dict[str, str]` — API key auth headers. +- `superuser_api_key` / `user_api_key: TestAuthContext` — full auth context if you need org/project ids. +- `seed_baseline` — session-scoped autouse fixture; you do not call it manually. + +## Test factory utilities (`app/tests/utils/`) + +- `user.py`: `create_random_user(db)`, `authentication_token_from_email(...)` +- `auth.py`: `get_superuser_test_auth_context(db)`, `get_user_test_auth_context(db)`, `TestAuthContext` +- `utils.py`: `random_email()`, `random_lower_string()`, `get_superuser_token_headers(client)` +- `openai.py`, `llm.py`, `llm_provider.py`, `collection.py`, `document.py` — per-domain factories. **Read these before writing new factories.** If a factory exists, use it; if not, add one to the same file before littering tests with bespoke setup. + +## Canonical patterns + +### API test (route) +```python +def test_create_user_route( + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, +): + email = random_email() + password = random_lower_string() + resp = client.post( + f"{settings.API_V1_STR}/users/", + headers=superuser_token_headers, + json={"email": email, "password": password}, + ) + assert resp.status_code == 201 + body = resp.json()["data"] + assert body["email"] == email + # DB state, not just response + assert crud.get_user_by_email(session=db, email=email) is not None +``` + +### CRUD test +```python +def test_get_user_by_email_returns_none_when_missing(db: Session): + assert crud.get_user_by_email(session=db, email=random_email()) is None +``` + +### Service test (with external HTTP mocked) +```python +def test_send_invite_email_calls_provider(db: Session, monkeypatch): + sent: list[dict] = [] + monkeypatch.setattr("app.utils.send_email", lambda **kw: sent.append(kw)) + service_under_test.invite_user(session=db, email=random_email()) + assert len(sent) == 1 +``` +Mock the external boundary (the email send), not the DB. + +## Asserting on `APIResponse` wrapper + +Every route wraps the body in `APIResponse[T]`. Tests should pull `body = resp.json()["data"]` and assert on that, not `resp.json()` directly. If the route returns a list, check `body["count"]` and `body["data"]` (or whatever the wrapper shape is — confirm by reading `app/utils/api_response.py` or whichever file defines `APIResponse`). + +## Things to flag (do not silently fix) + +- A bug fix arriving without a regression test → say so explicitly and write one. +- A "test" that mocks the DB session → refactor it onto the `db` fixture. +- `assert resp.status_code == 200` for a POST that should be 201, or for a DELETE that should be 204 — call out the wrong code. +- Tests asserting `mock.called` with no behavioral check — these are tautological; replace with an assertion on observable state. +- Hardcoded `organization_id=1` or `project_id=1` — replace with the auth-context fixtures. + +## Running tests + +- All tests: `uv run bash scripts/tests-start.sh` +- A subset (when iterating): `uv run pytest backend/app/tests/api/test_users.py -k -x` + +After writing, run the relevant subset. If the test fails for an unexpected reason (not the bug under test), diagnose before declaring done. diff --git a/CLAUDE.md b/CLAUDE.md index 53d09c7e9..013e3918c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -21,7 +21,7 @@ fastapi run --reload app/main.py uv run pre-commit run --all-files # Generate database migration (rev-id should be latest existing revision ID + 1) -alembic revision --autogenerate -m "Description" --rev-id 040 +uv run alembic revision --autogenerate -m "Description" --rev-id 061 # Seed database with test data uv run python -m app.seed_data.seed_data @@ -93,47 +93,31 @@ The application uses different environment files: ## Coding Conventions -### Type Hints - -Always add type hints to all function parameters and return values. - -### Logging Format - -Prefix all log messages with the function name in square brackets. - -```python -logger.info(f"[function_name] Message {mask_string(sensitive_value)}") -``` - -### Database Column Comments - -Use sa_column_kwargs["comment"] to describe database columns, especially when the purpose isn’t obvious. This helps non-developers understand column purposes directly from the database schema: - -```python -field_name: int = Field( - foreign_key="table.id", - nullable=False, - ondelete="CASCADE", - sa_column_kwargs={"comment": "What this column represents"} -) -``` - -Prioritize comments for: -- Columns with non-obvious purposes -- Status/type fields (document valid values) -- JSON/metadata columns (describe expected structure) -- Foreign keys (clarify the relationship) - -### Endpoint Documentation - -Load Swagger descriptions from external markdown files instead of inline strings: - -```python -@router.post( - "/endpoint", - description=load_description("domain/action.md"), - response_model=APIResponse[ResponseModel], -) -``` - -Store documentation files in `backend/app/api/docs//.md` +Layer-specific conventions live in `.claude/agents/*.md` and are enforced by the matching specialist subagent (e.g., `route-writer` for `app/api/routes/`, `model-writer` for `app/models/`, `migration-writer` for alembic). CLAUDE.md only covers rules that apply across every layer. + +### Cross-cutting rules + +- **Type hints** on every parameter and return value. `-> Any` is not an annotation — narrow it or drop it. +- **Logging prefix:** every log line starts with the function name in square brackets. + ```python + logger.info(f"[function_name] Message | key: {value}") + ``` +- **`uv` is the runner**, not `pip`. Examples: `uv run pytest`, `uv run alembic ...`, `uv run pre-commit run --all-files`. +- **No magic values** in code — extract repeated literals to constants / `Enum` / settings. +- **Naming:** `list_*` for plural fetch, `get_*` for singletons; snake_case funcs/vars, PascalCase classes, UPPER_SNAKE constants; `Enum` suffix on enum classes. +- **Timestamps** are `inserted_at` / `updated_at` (not `created_at`). + +## Specialist subagents + +When working in a specific layer, the matching agent under `.claude/agents/` handles the layer's conventions automatically. Pick by layer, or just describe the task and let the main agent route: + +| Agent | Layer | +|---|---| +| `route-writer` | `app/api/routes/` | +| `crud-writer` | `app/crud/` | +| `service-writer` | `app/services/` | +| `model-writer` | `app/models/` | +| `migration-writer` | `app/alembic/versions/` | +| `celery-task-writer` | `app/celery/tasks/` | +| `test-writer` | `app/tests/` | +| `convention-reviewer` | Cross-cutting pre-commit gate (mirrors `/pr-review`) | From cd4e0c66318b392be5fa0047e13487e5c4e14b37 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 16 May 2026 23:03:30 +0530 Subject: [PATCH 2/3] add consistent logging --- .claude/agents/celery-task-writer.md | 3 +-- .claude/agents/crud-writer.md | 7 +------ .claude/agents/migration-writer.md | 2 +- .claude/agents/route-writer.md | 4 ++++ .claude/agents/service-writer.md | 4 +++- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.claude/agents/celery-task-writer.md b/.claude/agents/celery-task-writer.md index 9fc5fc33c..c5c346c01 100644 --- a/.claude/agents/celery-task-writer.md +++ b/.claude/agents/celery-task-writer.md @@ -64,8 +64,7 @@ Document the choice in a comment if it's not obvious. ## Logging - `logger = logging.getLogger(__name__)` at the module top. -- Every log line prefixed `[task_name]` — e.g., `logger.info(f"[run_my_job] Starting | project_id: {project_id}, job_id: {job_id}")`. -- Log start, finish, and any retry. **Don't log payload contents** if they may contain PII / credentials. +- Every line is `logger.info(f"[task_name] Message | key: {value}")`. Log start, finish, and any retry. Mask PII / credentials with `mask_string` from `app.utils` — **never log raw payloads** if they may contain sensitive data. ## What you DO NOT do diff --git a/.claude/agents/crud-writer.md b/.claude/agents/crud-writer.md index fd0655b0c..ee714eff9 100644 --- a/.claude/agents/crud-writer.md +++ b/.claude/agents/crud-writer.md @@ -12,12 +12,7 @@ You write CRUD functions for kaapi-backend. CRUD lives in `app/crud/` and is the - **No `HTTPException` in this layer.** Ever. Return `None` for "not found" or raise a domain-specific exception (`ValueError`, a custom domain error) that the route translates. - **No third-party HTTP calls.** No `httpx`, no `openai`, no boto3, no `requests`. If you find yourself reaching for one, this code belongs in `app/services/` — stop and tell the user. - **No business logic.** Validation, orchestration, multi-step workflows → services. CRUD is "read this row, write this row, list these rows with filters". -- **No `print`. Use `logger`.** Every module starts with: - ```python - import logging - logger = logging.getLogger(__name__) - ``` - Every log line is prefixed: `logger.info(f"[function_name] ... | key: {value}")`. Mask anything sensitive (`mask_string(...)` from `app/core/util.py` if it exists in the repo — grep first). +- **No `print`. Use `logger`.** Module top: `import logging; logger = logging.getLogger(__name__)`. Every line is `logger.info(f"[function_name] Message | key: {value}")`. Mask sensitive values with `mask_string` from `app.utils` — e.g. `f"... | email: {mask_string(email)}"`. ## Canonical function shape (from `app/crud/user.py`) diff --git a/.claude/agents/migration-writer.md b/.claude/agents/migration-writer.md index 658a77cf6..9a71aa69b 100644 --- a/.claude/agents/migration-writer.md +++ b/.claude/agents/migration-writer.md @@ -66,7 +66,7 @@ def downgrade(): ## What you DO NOT do - Don't add `HTTPException`, route handlers, business logic, or external HTTP calls in a migration. -- Don't write `print(...)` debug statements — use the migration docstring. +- Don't write `print(...)` debug statements — use the migration docstring. If a long backfill genuinely needs progress logging, use `logging.getLogger("alembic.runtime.migration")` with the standard `[] ...` prefix. - Don't skip the docstring. The docstring is what someone debugging at 2am will read. - Don't import from `app.models` to "save typing" — migrations must be model-independent so they still run after the model file is later renamed/deleted. diff --git a/.claude/agents/route-writer.md b/.claude/agents/route-writer.md index bba44b1d6..c2555284c 100644 --- a/.claude/agents/route-writer.md +++ b/.claude/agents/route-writer.md @@ -74,6 +74,10 @@ Returning `404` instead of `403` for cross-tenant access is intentional — it d - Short fire-and-forget (send an email, write an audit log) → `BackgroundTasks`. - Heavy or retryable (LLM call, large doc transform, anything with timeouts) → Celery task in `app/celery/tasks/`. Hand off to `celery-task-writer`. +## Logging + +`logger = logging.getLogger(__name__)` at the module top. Every line is `logger.info(f"[handler_name] Message | key: {value}")`. Log non-trivial actions (creates, deletes, ownership failures) — don't spam `info` on every GET. Mask sensitive values with `mask_string` from `app.utils`. + ## What you DO NOT do - Don't add the route registration in `app/api/main.py` (or wherever the aggregator lives) without checking the existing alphabetical / grouped order. diff --git a/.claude/agents/service-writer.md b/.claude/agents/service-writer.md index 2439005de..e06f966b7 100644 --- a/.claude/agents/service-writer.md +++ b/.claude/agents/service-writer.md @@ -22,7 +22,7 @@ You write business-logic services for kaapi-backend. Services live in `app/servi - **External HTTP must validate URLs you fetch.** Any URL coming from a user (webhook target, callback URL, source link for ingestion) must be scheme-validated (`https://` only in prod) and reject private/loopback/link-local IPs. SSRF is a blocker, not a follow-up. - **`try` wraps only the throwing line(s).** Big try blocks are the #1 source of swallowed 404s becoming 500s. - **Concrete exception types** — `except httpx.HTTPStatusError as e:`, not `except Exception`. -- **Logger prefix:** every line starts `[function_name]`. Mask credentials / API keys / hashes. +- **Logger prefix:** every line is `logger.info(f"[function_name] Message | key: {value}")`. Mask credentials / API keys / hashes / emails with `mask_string` from `app.utils`. Log start + finish of external HTTP calls and any retry. - **Keyword-only args** for anything more than `(session, x)`, matching the CRUD convention. - **Type hints on every parameter and return.** No `-> Any`. @@ -38,6 +38,7 @@ from datetime import timedelta from app.core import security from app.core.config import settings +from app.utils import mask_string logger = logging.getLogger(__name__) @@ -60,6 +61,7 @@ def create_token_pair( organization_id=organization_id, project_id=project_id, ) + logger.info(f"[create_token_pair] Token pair issued | user_id: {user_id}, access_token: {mask_string(access_token)}") return access_token, refresh_token ``` From 74058f752b4a645cea759eb4ba6ea9bc578125b3 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Wed, 20 May 2026 15:54:50 +0530 Subject: [PATCH 3/3] first stab at fast evaluation using agents --- .../063_add_run_mode_to_evaluation_run.py | 112 +++ .../api/docs/evaluation/create_evaluation.md | 42 +- .../app/api/docs/evaluation/list_datasets.md | 9 +- backend/app/api/routes/evaluations/dataset.py | 22 +- .../app/api/routes/evaluations/evaluation.py | 28 +- backend/app/celery/celery_app.py | 7 + backend/app/celery/tasks/evaluation_fast.py | 51 ++ backend/app/celery/utils.py | 16 + backend/app/core/config.py | 6 + backend/app/crud/evaluations/__init__.py | 9 + backend/app/crud/evaluations/core.py | 3 + backend/app/crud/evaluations/fast.py | 846 ++++++++++++++++++ backend/app/models/__init__.py | 1 + backend/app/models/evaluation.py | 23 + backend/app/services/evaluations/__init__.py | 5 + .../app/services/evaluations/evaluation.py | 36 +- backend/app/services/evaluations/fast.py | 289 ++++++ .../tests/api/routes/test_evaluation_fast.py | 729 +++++++++++++++ 18 files changed, 2210 insertions(+), 24 deletions(-) create mode 100644 backend/app/alembic/versions/063_add_run_mode_to_evaluation_run.py create mode 100644 backend/app/celery/tasks/evaluation_fast.py create mode 100644 backend/app/crud/evaluations/fast.py create mode 100644 backend/app/services/evaluations/fast.py create mode 100644 backend/app/tests/api/routes/test_evaluation_fast.py diff --git a/backend/app/alembic/versions/063_add_run_mode_to_evaluation_run.py b/backend/app/alembic/versions/063_add_run_mode_to_evaluation_run.py new file mode 100644 index 000000000..77d7e8bce --- /dev/null +++ b/backend/app/alembic/versions/063_add_run_mode_to_evaluation_run.py @@ -0,0 +1,112 @@ +"""add run_mode column and unique run-name constraint to evaluation_run + +Revision ID: 063 +Revises: 062 +Create Date: 2026-05-20 00:00:00.000000 + +Two schema changes are required to support the fast-evaluation feature: + +1. run_mode VARCHAR(10) NOT NULL DEFAULT 'batch' + Every evaluation run must now carry an explicit execution mode ('batch' + or 'fast'). Because evaluation_run may already contain rows the column + is added as nullable with a server_default of 'batch' so existing rows + are backfilled atomically, then SET NOT NULL is applied. The server + default is kept on the column to act as a safety net for any insert path + that does not explicitly set the field; the SQLModel default="batch" will + also cover application inserts. + +2. UNIQUE constraint uq_evaluation_run_org_project_run_name + on (organization_id, project_id, run_name) + Guards against double-click races on POST /evaluations where two + concurrent requests could create duplicate runs with the same name inside + the same project. Lower environments may already contain duplicate + (organization_id, project_id, run_name) tuples, so we dedupe first + (keeping the lowest-id survivor), then build the unique index + CONCURRENTLY (no AccessExclusiveLock on the table during the build) and + attach it as a named constraint via ADD CONSTRAINT ... USING INDEX. + Because CONCURRENTLY cannot execute inside a transaction the whole + migration is marked disable_per_migration_transaction = True and each + CONCURRENTLY step is wrapped in autocommit_block(). The dedupe DELETE and + the non-concurrent DDL steps run inside autocommit mode too, which is + safe — they are each individually atomic at the statement level. +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "063" +down_revision = "062" +branch_labels = None +depends_on = None + +disable_per_migration_transaction = True + +_UNIQUE_INDEX = "uq_evaluation_run_org_project_run_name" +_UNIQUE_CONSTRAINT = "uq_evaluation_run_org_project_run_name" + + +def upgrade(): + # 1. Add run_mode as nullable first so existing rows are backfilled by the + # server default, then tighten to NOT NULL. The server default is left + # in place as a safety net. + with op.get_context().autocommit_block(): + op.add_column( + "evaluation_run", + sa.Column( + "run_mode", + sa.String(length=10), + nullable=True, + server_default=sa.text("'batch'"), + comment="Execution mode: batch or fast", + ), + ) + op.execute("ALTER TABLE evaluation_run ALTER COLUMN run_mode SET NOT NULL") + + # 2. Dedupe existing rows before adding the unique constraint. + # Keep the lowest-id row for each (organization_id, project_id, + # run_name) tuple and remove the rest. + with op.get_context().autocommit_block(): + op.execute( + """ + DELETE FROM evaluation_run + WHERE id IN ( + SELECT id + FROM ( + SELECT id, + ROW_NUMBER() OVER ( + PARTITION BY organization_id, project_id, run_name + ORDER BY id ASC + ) AS rn + FROM evaluation_run + ) sub + WHERE rn > 1 + ) + """ + ) + + # 3. Build the unique index CONCURRENTLY so the scan does not take an + # AccessExclusiveLock, then attach it as a named constraint via + # ADD CONSTRAINT ... USING INDEX (brief catalog-only lock). + with op.get_context().autocommit_block(): + op.execute( + f"CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS " + f'"{_UNIQUE_INDEX}" ' + f"ON evaluation_run (organization_id, project_id, run_name)" + ) + op.execute( + f"ALTER TABLE evaluation_run " + f'ADD CONSTRAINT "{_UNIQUE_CONSTRAINT}" ' + f'UNIQUE USING INDEX "{_UNIQUE_INDEX}"' + ) + + +def downgrade(): + # Reverse in opposite order to upgrade(). + op.execute( + f"ALTER TABLE evaluation_run " + f'DROP CONSTRAINT IF EXISTS "{_UNIQUE_CONSTRAINT}"' + ) + with op.get_context().autocommit_block(): + op.execute(f'DROP INDEX CONCURRENTLY IF EXISTS "{_UNIQUE_INDEX}"') + op.drop_column("evaluation_run", "run_mode") diff --git a/backend/app/api/docs/evaluation/create_evaluation.md b/backend/app/api/docs/evaluation/create_evaluation.md index a719579b2..96719ed53 100644 --- a/backend/app/api/docs/evaluation/create_evaluation.md +++ b/backend/app/api/docs/evaluation/create_evaluation.md @@ -1,16 +1,22 @@ -Start an evaluation run using the OpenAI Batch API. +Start an evaluation run against a stored dataset. -Evaluations allow you to systematically test LLM configurations against -predefined datasets with automatic progress tracking and result collection. +Two execution modes are supported via the optional `run_mode` field: + +* `batch` (default) — submits the work to the OpenAI Batch API. Cost-efficient + for large datasets; turnaround can take up to 24 hours. +* `fast` — runs the evaluation synchronously through the OpenAI Responses API + and returns results within seconds-to-minutes. Restricted to text + evaluations on datasets with at most `EVAL_FAST_MAX_UNIQUE_ROWS` unique rows. **Key Features:** -* Fetches dataset items from Langfuse and creates a batch processing job via the OpenAI Batch API -* Asynchronous processing with automatic progress tracking (checks every 60s) +* Fetches dataset items from Langfuse and creates a job (batch or fast) * Uses a stored config (created via `/configs`) to define the provider parameters -* Stores results for comparison and analysis -* Use `GET /evaluations/{evaluation_id}` to monitor progress and retrieve results +* Same scoring semantics across both modes — cosine similarity, Langfuse traces, + and optional LLM-as-Judge correctness +* Use `GET /evaluations/{evaluation_id}` to monitor progress and retrieve results; + the response carries `run_mode` so clients can tell the two paths apart -## Example +## Example (batch — default) ```json { @@ -20,3 +26,23 @@ predefined datasets with automatic progress tracking and result collection. "config_version": 1 } ``` + +## Example (fast) + +```json +{ + "dataset_id": 123, + "experiment_name": "may19-temp0.2-gpt4o-fast", + "config_id": "f54f0d67-4817-4103-9fdf-b74b3d46733e", + "config_version": 1, + "run_mode": "fast" +} +``` + +## Fast-mode error responses + +| Status | Code | When | +| --- | --- | --- | +| 422 | `config_type_unsupported` | Resolved config is not a text-evaluation config | +| 422 | `dataset_too_large_for_fast` | Dataset exceeds `EVAL_FAST_MAX_UNIQUE_ROWS` unique rows | +| 409 | `run_name_already_exists` | A run with the same `experiment_name` already exists for this (organization, project) | diff --git a/backend/app/api/docs/evaluation/list_datasets.md b/backend/app/api/docs/evaluation/list_datasets.md index e315db1d0..36b7cc1c4 100644 --- a/backend/app/api/docs/evaluation/list_datasets.md +++ b/backend/app/api/docs/evaluation/list_datasets.md @@ -1,3 +1,10 @@ List all datasets for the current organization and project. -Returns a paginated list of datasets ordered by most recent first. Each dataset includes metadata (ID, name, item counts, duplication factor), Langfuse integration details, and object store URL. +Returns a paginated list of datasets ordered by most recent first. Each dataset includes metadata (ID, name, item counts, duplication factor), Langfuse integration details, object store URL, and an `eligible_for_fast` flag that is `true` when the dataset's unique-row count is within `EVAL_FAST_MAX_UNIQUE_ROWS` (and so can be used with `run_mode="fast"` on `POST /evaluations`). + +## Query parameters + +| Parameter | Description | +| --- | --- | +| `limit` / `offset` | Pagination (default 50 / 0; max limit 100) | +| `eligible_for` | If set to `fast`, the response is filtered to only datasets where `eligible_for_fast` is `true` | diff --git a/backend/app/api/routes/evaluations/dataset.py b/backend/app/api/routes/evaluations/dataset.py index 202774da2..20f59f9cf 100644 --- a/backend/app/api/routes/evaluations/dataset.py +++ b/backend/app/api/routes/evaluations/dataset.py @@ -22,6 +22,7 @@ from app.crud.evaluations.dataset import delete_dataset as delete_dataset_crud from app.models.evaluation import DatasetUploadResponse, EvaluationDataset from app.services.evaluations import ( + is_dataset_fast_eligible, upload_dataset as upload_evaluation_dataset, validate_csv_file, ) @@ -39,13 +40,15 @@ def _dataset_to_response( dataset: EvaluationDataset, signed_url: str | None = None ) -> DatasetUploadResponse: """Convert a dataset model to a DatasetUploadResponse.""" + original_items = dataset.dataset_metadata.get("original_items_count", 0) return DatasetUploadResponse( dataset_id=dataset.id, dataset_name=dataset.name, description=dataset.description, total_items=dataset.dataset_metadata.get("total_items_count", 0), - original_items=dataset.dataset_metadata.get("original_items_count", 0), + original_items=original_items, duplication_factor=dataset.dataset_metadata.get("duplication_factor", 1), + eligible_for_fast=is_dataset_fast_eligible(original_items_count=original_items), langfuse_dataset_id=dataset.langfuse_dataset_id, object_store_url=dataset.object_store_url, signed_url=signed_url, @@ -104,6 +107,15 @@ def list_datasets( default=50, ge=1, le=100, description="Maximum number of datasets to return" ), offset: int = Query(default=0, ge=0, description="Number of datasets to skip"), + eligible_for: str + | None = Query( + default=None, + description=( + "If 'fast', return only datasets eligible for run_mode='fast' " + "(unique-row count within EVAL_FAST_MAX_UNIQUE_ROWS)." + ), + enum=["fast"], + ), ) -> APIResponse[list[DatasetUploadResponse]]: """List evaluation datasets.""" datasets = list_evaluation_datasets( @@ -114,9 +126,11 @@ def list_datasets( offset=offset, ) - return APIResponse.success_response( - data=[_dataset_to_response(dataset) for dataset in datasets] - ) + responses = [_dataset_to_response(dataset) for dataset in datasets] + if eligible_for == "fast": + responses = [r for r in responses if r.eligible_for_fast] + + return APIResponse.success_response(data=responses) @router.get( diff --git a/backend/app/api/routes/evaluations/evaluation.py b/backend/app/api/routes/evaluations/evaluation.py index 591f5d985..800b5b0cc 100644 --- a/backend/app/api/routes/evaluations/evaluation.py +++ b/backend/app/api/routes/evaluations/evaluation.py @@ -3,6 +3,7 @@ import logging from uuid import UUID +from asgi_correlation_id import correlation_id from fastapi import ( APIRouter, Body, @@ -14,11 +15,12 @@ from app.api.deps import AuthContextDep, SessionDep from app.crud.evaluations import list_evaluation_runs as list_evaluation_runs_crud from app.crud.evaluations.core import group_traces_by_question_id -from app.models.evaluation import EvaluationRunPublic +from app.models.evaluation import EvaluationRunPublic, RunModeEnum from app.api.permissions import Permission, require_permission from app.services.evaluations import ( get_evaluation_with_scores, start_evaluation, + validate_and_start_fast_evaluation, ) from app.utils import ( APIResponse, @@ -45,8 +47,32 @@ def evaluate( ), config_id: UUID = Body(..., description="Stored config ID"), config_version: int = Body(..., ge=1, description="Stored config version"), + run_mode: RunModeEnum = Body( + default=RunModeEnum.BATCH, + description="Execution mode: 'batch' (default) or 'fast'", + ), ) -> APIResponse[EvaluationRunPublic]: """Start an evaluation run.""" + logger.info( + f"[evaluate] Starting evaluation | run_mode={run_mode.value} | " + f"experiment_name={experiment_name} | dataset_id={dataset_id} | " + f"org_id={auth_context.organization_.id} | " + f"project_id={auth_context.project_.id}" + ) + + if run_mode == RunModeEnum.FAST: + eval_run = validate_and_start_fast_evaluation( + session=session, + dataset_id=dataset_id, + run_name=experiment_name, + config_id=config_id, + config_version=config_version, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + trace_id=correlation_id.get() or "N/A", + ) + return APIResponse.success_response(data=eval_run) + eval_run = start_evaluation( session=session, dataset_id=dataset_id, diff --git a/backend/app/celery/celery_app.py b/backend/app/celery/celery_app.py index b78e7b15f..a4031b202 100644 --- a/backend/app/celery/celery_app.py +++ b/backend/app/celery/celery_app.py @@ -164,6 +164,7 @@ def initialize_worker(**_) -> None: include=[ "app.celery.tasks.job_execution", "app.celery.tasks.notifications", + "app.celery.tasks.evaluation_fast", ], ) @@ -186,6 +187,12 @@ def initialize_worker(**_) -> None: routing_key="low", queue_arguments={"x-max-priority": 1}, ), + Queue( + "evaluations", + exchange=default_exchange, + routing_key="evaluations", + queue_arguments={"x-max-priority": 6}, + ), Queue("cron", exchange=default_exchange, routing_key="cron"), Queue("default", exchange=default_exchange, routing_key="default"), ), diff --git a/backend/app/celery/tasks/evaluation_fast.py b/backend/app/celery/tasks/evaluation_fast.py new file mode 100644 index 000000000..55752ffb1 --- /dev/null +++ b/backend/app/celery/tasks/evaluation_fast.py @@ -0,0 +1,51 @@ +""" +Celery task for the synchronous (fast) text-evaluation pipeline. + +This module hosts the single orchestrator task per fast evaluation run. The +heavy lifting lives in `app/services/evaluations/fast.py`; this task is a thin +shim that sets the correlation id, attaches the OTel parent context, and +delegates. + +See `Fast Evaluation SRD.md` for the design (queue, retries, idempotency). +""" + +import logging + +from celery import current_task + +from app.celery.celery_app import celery_app +from app.celery.tasks.job_execution import _run_with_otel_parent, _set_trace +from app.celery.utils import gevent_timeout +from app.core.config import settings + +logger = logging.getLogger(__name__) + + +@celery_app.task(bind=True, queue="evaluations", priority=6) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_evaluation_fast") +def run_evaluation_fast( + self, eval_run_id: int, trace_id: str = "N/A", **kwargs +) -> None: + """Run the fast evaluation pipeline for one EvaluationRun. + + Idempotency: each stage is skipped on retry when its `batch_job` marker is + already set on the EvaluationRun, so Celery redelivery never re-calls + OpenAI for work that already succeeded. + + Args: + eval_run_id: ID of the EvaluationRun (run_mode="fast"). + trace_id: Correlation id from the enqueueing request, propagated into + the worker for log correlation. + """ + from app.services.evaluations.fast import execute_fast_evaluation + + _set_trace(trace_id) + logger.info( + f"[run_evaluation_fast] Starting fast evaluation task | " + f"eval_run_id={eval_run_id} | task_id={current_task.request.id}" + ) + + return _run_with_otel_parent( + self, + lambda: execute_fast_evaluation(eval_run_id=eval_run_id), + ) diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 288cba7c4..f643c7df6 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -198,6 +198,22 @@ def start_tts_result_processing( return task_id +def start_fast_evaluation(eval_run_id: int, trace_id: str = "N/A") -> str: + """Enqueue the run_evaluation_fast orchestrator task for one EvaluationRun.""" + from app.celery.tasks.evaluation_fast import run_evaluation_fast + + task_id = _enqueue_with_trace_context( + run_evaluation_fast, + eval_run_id=eval_run_id, + trace_id=trace_id, + ) + logger.info( + f"[start_fast_evaluation] Enqueued fast eval | " + f"eval_run_id={eval_run_id} | task_id={task_id}" + ) + return task_id + + def get_task_status(task_id: str) -> Dict[str, Any]: result = AsyncResult(task_id) return { diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 720846eb9..e1b15f0da 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -171,6 +171,12 @@ def AWS_S3_BUCKET(self) -> str: DOC_TRANSFORMATION_PENDING_THRESHOLD_MINUTES: int = 30 PENDING_JOB_QUERY_TIMEOUT_MS: int = 1000 + # Fast evaluation (run_mode="fast") configuration. + # See "Fast Evaluation SRD.md" for the full design rationale. + EVAL_FAST_MAX_UNIQUE_ROWS: int = 10 + EVAL_FAST_FAILURE_THRESHOLD: float = 0.5 + EVAL_FAST_API_CONCURRENCY: int = 10 + @computed_field # type: ignore[prop-decorator] @property def COMPUTED_CELERY_WORKER_CONCURRENCY(self) -> int: diff --git a/backend/app/crud/evaluations/__init__.py b/backend/app/crud/evaluations/__init__.py index a5824c0a2..2c0bd8f48 100644 --- a/backend/app/crud/evaluations/__init__.py +++ b/backend/app/crud/evaluations/__init__.py @@ -25,6 +25,11 @@ calculate_cosine_similarity, start_embedding_batch, ) +from app.crud.evaluations.fast import ( + JOB_TYPE_EMBEDDING_FAST, + JOB_TYPE_EVALUATION_FAST, + run_fast_evaluation, +) from app.crud.evaluations.langfuse import ( create_langfuse_dataset_run, fetch_trace_scores_from_langfuse, @@ -65,6 +70,10 @@ "upload_csv_to_object_store", # Batch "start_evaluation_batch", + # Fast eval + "JOB_TYPE_EMBEDDING_FAST", + "JOB_TYPE_EVALUATION_FAST", + "run_fast_evaluation", # Processing "check_and_process_evaluation", "poll_all_pending_evaluations", diff --git a/backend/app/crud/evaluations/core.py b/backend/app/crud/evaluations/core.py index f3d51c56d..271384615 100644 --- a/backend/app/crud/evaluations/core.py +++ b/backend/app/crud/evaluations/core.py @@ -62,6 +62,7 @@ def create_evaluation_run( config_version: int, organization_id: int, project_id: int, + run_mode: str = "batch", ) -> EvaluationRun: """ Create a new evaluation run record in the database. @@ -75,6 +76,7 @@ def create_evaluation_run( config_version: Version number of the config organization_id: Organization ID project_id: Project ID + run_mode: Execution mode ("batch" default, or "fast") Returns: The created EvaluationRun instance @@ -87,6 +89,7 @@ def create_evaluation_run( config_id=config_id, config_version=config_version, status="pending", + run_mode=run_mode, organization_id=organization_id, project_id=project_id, inserted_at=now(), diff --git a/backend/app/crud/evaluations/fast.py b/backend/app/crud/evaluations/fast.py new file mode 100644 index 000000000..d78785104 --- /dev/null +++ b/backend/app/crud/evaluations/fast.py @@ -0,0 +1,846 @@ +""" +Fast evaluation orchestration. + +This module implements run_mode="fast" — the synchronous text-evaluation path +described in `Fast Evaluation SRD.md`. Unlike the batch path it does not submit +an OpenAI Batch job; instead it makes Responses + Embeddings calls in parallel +from a single Celery orchestrator task and persists per-stage units to S3. + +Stage layout (each stage is skipped on retry if its `batch_job` row already +exists, mirroring how the batch path tracks state): + + Stage 1 — Responses unit: evaluation_run.batch_job_id + Stage 2 — Embeddings unit: evaluation_run.embedding_batch_job_id + Stage 3 — Score + trace + cost (no marker; each step is idempotent) + Stage 4 — Mark completed +""" + +import logging +import random +import time +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any, TypeVar + +import numpy as np +import openai +from langfuse import Langfuse +from openai import OpenAI +from sqlmodel import Session + +from app.core.cloud.storage import get_cloud_storage +from app.core.config import settings +from app.core.storage_utils import ( + load_json_from_object_store, + upload_jsonl_to_object_store, +) +from app.crud.evaluations.batch import fetch_dataset_items +from app.crud.evaluations.core import resolve_model_from_config, update_evaluation_run +from app.crud.evaluations.cost import attach_cost +from app.crud.evaluations.embeddings import ( + EMBEDDING_MODEL, + calculate_cosine_similarity, +) +from app.crud.evaluations.langfuse import ( + create_langfuse_dataset_run, + update_traces_with_cosine_scores, +) +from app.crud.job import create_batch_job, get_batch_job +from app.models import EvaluationRun, EvaluationRunUpdate +from app.models.batch_job import BatchJobCreate +from app.models.evaluation import RunModeEnum +from app.models.llm.request import TextLLMParams + +logger = logging.getLogger(__name__) + +_T = TypeVar("_T") + + +# job_type discriminators on batch_job for the two stages of the fast path. +# The OpenAI Batch concepts (provider_batch_id, provider_file_id, …) are NULL +# on these rows; the presence of the row + its raw_output_url is what tells the +# orchestrator a stage is already done. +JOB_TYPE_EVALUATION_FAST = "evaluation_fast" +JOB_TYPE_EMBEDDING_FAST = "embedding_fast" + + +# Retry policy for individual OpenAI calls inside Stage 1 / Stage 2. +# Retries on rate-limit, timeout, and connection errors; permanent errors +# (auth, validation, model not found) fail the item immediately. +_RETRY_MAX_ATTEMPTS = 5 +_RETRY_BASE_DELAY_SECONDS = 1.0 +_RETRY_MAX_DELAY_SECONDS = 30.0 + +# These OpenAI exception classes are retryable transient failures. +_RETRYABLE_OPENAI_ERRORS: tuple[type[Exception], ...] = ( + openai.RateLimitError, + openai.APITimeoutError, + openai.APIConnectionError, + openai.InternalServerError, +) + + +def _sleep_with_backoff(attempt: int) -> None: + """Exponential backoff with full jitter, capped at _RETRY_MAX_DELAY_SECONDS.""" + base = min( + _RETRY_BASE_DELAY_SECONDS * (2 ** (attempt - 1)), _RETRY_MAX_DELAY_SECONDS + ) + delay = random.uniform(0, base) + time.sleep(delay) + + +def _call_with_retry(label: str, fn: Callable[[], _T]) -> _T: + """Call `fn()` with retry on transient OpenAI errors. + + Returns the function's result on success. Raises the last exception when + retries are exhausted or when the error is non-retryable. The loop always + either returns or raises on the final attempt; the explicit + `RuntimeError` at the tail keeps mypy happy about the function's return + being reachable. + """ + for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1): + try: + return fn() + except _RETRYABLE_OPENAI_ERRORS as exc: + if attempt == _RETRY_MAX_ATTEMPTS: + logger.warning( + f"[_call_with_retry] Exhausted retries | label={label} | " + f"attempt={attempt} | error={exc}" + ) + raise + logger.info( + f"[_call_with_retry] Transient error, retrying | label={label} | " + f"attempt={attempt} | error={exc}" + ) + _sleep_with_backoff(attempt) + except openai.OpenAIError as exc: + # Permanent OpenAI errors (auth, validation, model not found) — fail fast. + logger.warning( + f"[_call_with_retry] Permanent error, no retry | label={label} | " + f"error={exc}" + ) + raise + raise RuntimeError(f"_call_with_retry exited without result for {label}") + + +def _build_responses_params( + *, + config: TextLLMParams, + question: str, +) -> dict[str, Any]: + """Build the parameter dict for one OpenAI Responses API call. + + Mirrors the body shape `build_evaluation_jsonl` uses for the batch path so + fast eval and batch eval generate equivalent outputs for the same config. + """ + params: dict[str, Any] = { + "model": config.model, + "instructions": config.instructions, + "input": question, + } + + if "temperature" in config.model_fields_set: + params["temperature"] = config.temperature + + if config.reasoning: + params["reasoning"] = {"effort": config.reasoning} + + if config.knowledge_base_ids: + params["tools"] = [ + { + "type": "file_search", + "vector_store_ids": config.knowledge_base_ids, + "max_num_results": config.max_num_results or 20, + } + ] + + return params + + +def _field(obj: Any, name: str, default: Any = None) -> Any: + """Read a field from an object or dict, returning default when missing. + + OpenAI SDK objects expose fields as attributes; older response shapes and + tests sometimes pass dicts. This unifies access so call sites don't have + to branch on the runtime type. + """ + if obj is None: + return default + if isinstance(obj, dict): + return obj.get(name, default) + return getattr(obj, name, default) + + +def _extract_response_text(response: Any) -> str: + """Extract the model-generated text from an OpenAI Responses object. + + Prefers `response.output_text` (the SDK's flat helper) and falls back to a + structured walk of `response.output` if the helper is empty. + """ + output_text = _field(response, "output_text") + if output_text: + return output_text + + output = _field(response, "output") + if not output: + return "" + + for item in output: + if _field(item, "type") != "message": + continue + for content in _field(item, "content") or []: + if _field(content, "type") == "output_text": + text = _field(content, "text") + if text: + return text + return "" + + +def _usage_to_dict(usage: Any) -> dict[str, int]: + """Normalize an OpenAI usage object into the dict shape kaapi's cost layer expects.""" + return { + "input_tokens": int(_field(usage, "input_tokens", 0) or 0), + "output_tokens": int(_field(usage, "output_tokens", 0) or 0), + "total_tokens": int(_field(usage, "total_tokens", 0) or 0), + } + + +def _responses_call_for_item( + *, + openai_client: OpenAI, + config: TextLLMParams, + item: dict[str, Any], +) -> dict[str, Any]: + """Run one Responses API call for a single dataset item. + + Returns the same per-item shape as `parse_evaluation_output` from the batch + path so downstream code (Langfuse trace creation, embeddings, cost) does + not need to branch on run_mode. + """ + item_id = item["id"] + question = item["input"].get("question", "") if item.get("input") else "" + ground_truth = ( + item["expected_output"].get("answer", "") if item.get("expected_output") else "" + ) + question_id = (item.get("metadata") or {}).get("question_id") + + if not question: + return { + "item_id": item_id, + "question": "", + "generated_output": "ERROR: missing question in dataset item", + "ground_truth": ground_truth, + "response_id": None, + "usage": None, + "question_id": question_id, + "failed": True, + } + + params = _build_responses_params(config=config, question=question) + + try: + response = _call_with_retry( + label=f"responses.create:{item_id}", + fn=lambda: openai_client.responses.create(**params), + ) + except openai.OpenAIError as exc: + logger.warning( + f"[_responses_call_for_item] Item failed | item_id={item_id} | error={exc}" + ) + return { + "item_id": item_id, + "question": question, + "generated_output": f"ERROR: {exc}", + "ground_truth": ground_truth, + "response_id": None, + "usage": None, + "question_id": question_id, + "failed": True, + } + + return { + "item_id": item_id, + "question": question, + "generated_output": _extract_response_text(response), + "ground_truth": ground_truth, + "response_id": getattr(response, "id", None), + "usage": _usage_to_dict(getattr(response, "usage", None)), + "question_id": question_id, + "failed": False, + } + + +def _embedding_call_for_pair( + *, + openai_client: OpenAI, + embedding_model: str, + item_id: str, + output_text: str, + ground_truth: str, +) -> dict[str, Any]: + """Run one Embeddings API call for an (output, ground_truth) pair. + + Returns a dict with both embeddings and the prompt_tokens usage so the + cost layer can price the call. `failed=True` indicates a permanent or + retry-exhausted failure for the pair. + """ + if not output_text or not ground_truth: + return { + "item_id": item_id, + "output_embedding": None, + "ground_truth_embedding": None, + "usage": None, + "failed": True, + "error": "empty output or ground_truth", + } + + try: + response = _call_with_retry( + label=f"embeddings.create:{item_id}", + fn=lambda: openai_client.embeddings.create( + model=embedding_model, + input=[output_text, ground_truth], + encoding_format="float", + ), + ) + except openai.OpenAIError as exc: + logger.warning( + f"[_embedding_call_for_pair] Item failed | item_id={item_id} | error={exc}" + ) + return { + "item_id": item_id, + "output_embedding": None, + "ground_truth_embedding": None, + "usage": None, + "failed": True, + "error": str(exc), + } + + data = _field(response, "data") or [] + if len(data) < 2: + return { + "item_id": item_id, + "output_embedding": None, + "ground_truth_embedding": None, + "usage": None, + "failed": True, + "error": f"expected 2 embeddings, got {len(data)}", + } + + output_embedding: list[float] | None = None + ground_truth_embedding: list[float] | None = None + for emb in data: + index = _field(emb, "index") + vector = _field(emb, "embedding") + if index == 0: + output_embedding = vector + elif index == 1: + ground_truth_embedding = vector + + usage_obj = _field(response, "usage") + usage_dict: dict[str, int] = { + "prompt_tokens": int(_field(usage_obj, "prompt_tokens", 0) or 0), + "total_tokens": int(_field(usage_obj, "total_tokens", 0) or 0), + } + + return { + "item_id": item_id, + "output_embedding": output_embedding, + "ground_truth_embedding": ground_truth_embedding, + "usage": usage_dict, + "failed": output_embedding is None or ground_truth_embedding is None, + } + + +def _is_failure_threshold_breached(*, failed_rows: int, total_rows: int) -> bool: + """True if the failed-row fraction exceeds EVAL_FAST_FAILURE_THRESHOLD.""" + if total_rows == 0: + return False + return (failed_rows / total_rows) > settings.EVAL_FAST_FAILURE_THRESHOLD + + +def _upload_unit_to_s3( + *, + session: Session, + project_id: int, + eval_run_id: int, + filename: str, + results: list[dict[str, Any]], +) -> str | None: + """Upload a stage unit (responses or embeddings) as JSON to S3.""" + storage = get_cloud_storage(session=session, project_id=project_id) + return upload_jsonl_to_object_store( + storage=storage, + results=results, + filename=filename, + subdirectory=f"evaluations/fast/{eval_run_id}", + format="json", + ) + + +def _load_unit_from_s3( + *, session: Session, project_id: int, url: str +) -> list[dict[str, Any]]: + """Load a stage unit back from S3. Raises if the unit cannot be loaded.""" + storage = get_cloud_storage(session=session, project_id=project_id) + data = load_json_from_object_store(storage=storage, url=url) + if data is None: + raise RuntimeError(f"Failed to load fast eval unit from S3 | url={url}") + if not isinstance(data, list): + raise RuntimeError( + f"Fast eval unit at {url} is not a list | type={type(data).__name__}" + ) + return data + + +def _stage1_responses( + *, + session: Session, + openai_client: OpenAI, + eval_run: EvaluationRun, + config: TextLLMParams, + dataset_items: list[dict[str, Any]], + log_prefix: str, +) -> tuple[EvaluationRun, list[dict[str, Any]]]: + """Stage 1 — generate one Responses API call per dataset item. + + Skipped on retry if `eval_run.batch_job_id` is already set; in that case we + reload the persisted unit from S3 so downstream stages still have results. + """ + if eval_run.batch_job_id: + existing = get_batch_job(session=session, batch_job_id=eval_run.batch_job_id) + if existing and existing.raw_output_url: + logger.info( + f"[_stage1_responses] {log_prefix} Skipping stage 1 (already done) | " + f"batch_job_id={existing.id}" + ) + results = _load_unit_from_s3( + session=session, + project_id=eval_run.project_id, + url=existing.raw_output_url, + ) + return eval_run, results + + logger.info( + f"[_stage1_responses] {log_prefix} Running stage 1 | " + f"items={len(dataset_items)} | model={config.model} | " + f"concurrency={settings.EVAL_FAST_API_CONCURRENCY}" + ) + + results: list[dict[str, Any]] = [] + max_workers = max(1, min(settings.EVAL_FAST_API_CONCURRENCY, len(dataset_items))) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit( + _responses_call_for_item, + openai_client=openai_client, + config=config, + item=item, + ): item["id"] + for item in dataset_items + } + for future in as_completed(futures): + results.append(future.result()) + + failed_count = sum(1 for r in results if r.get("failed")) + logger.info( + f"[_stage1_responses] {log_prefix} Stage 1 finished | " + f"total={len(results)} | failed={failed_count}" + ) + + if _is_failure_threshold_breached( + failed_rows=failed_count, total_rows=len(results) + ): + raise RuntimeError( + f"Fast eval Stage 1 exceeded failure threshold | " + f"failed={failed_count}/{len(results)} | " + f"threshold={settings.EVAL_FAST_FAILURE_THRESHOLD}" + ) + + raw_output_url = _upload_unit_to_s3( + session=session, + project_id=eval_run.project_id, + eval_run_id=eval_run.id, + filename=f"responses_{eval_run.id}.json", + results=results, + ) + + # Aggregate usage for the batch_job summary so the cost layer can price the + # stage from raw_output_url alone if it needs to re-derive later. + summed_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + for r in results: + usage = r.get("usage") or {} + for k in summed_usage: + summed_usage[k] += int(usage.get(k, 0) or 0) + + batch_job = create_batch_job( + session=session, + batch_job_create=BatchJobCreate( + provider="openai", + job_type=JOB_TYPE_EVALUATION_FAST, + config={ + "endpoint": "/v1/responses", + "run_mode": RunModeEnum.FAST.value, + "model": config.model, + "usage": summed_usage, + }, + raw_output_url=raw_output_url, + total_items=len(results), + organization_id=eval_run.organization_id, + project_id=eval_run.project_id, + ), + ) + + # `batch_job_id` and `total_items` aren't on EvaluationRunUpdate (the + # update model only covers fields that change during the lifecycle), so + # set them directly and let `update_evaluation_run` bump `updated_at`. + eval_run.batch_job_id = batch_job.id + eval_run.total_items = len(results) + eval_run = update_evaluation_run( + session=session, + eval_run=eval_run, + update=EvaluationRunUpdate(), + ) + + return eval_run, results + + +def _stage2_embeddings( + *, + session: Session, + openai_client: OpenAI, + eval_run: EvaluationRun, + response_results: list[dict[str, Any]], + log_prefix: str, +) -> tuple[EvaluationRun, list[dict[str, Any]]]: + """Stage 2 — embed each (output, ground_truth) pair for cosine similarity. + + Skipped on retry if `eval_run.embedding_batch_job_id` is already set. + """ + if eval_run.embedding_batch_job_id: + existing = get_batch_job( + session=session, batch_job_id=eval_run.embedding_batch_job_id + ) + if existing and existing.raw_output_url: + logger.info( + f"[_stage2_embeddings] {log_prefix} Skipping stage 2 (already done) | " + f"batch_job_id={existing.id}" + ) + embeddings = _load_unit_from_s3( + session=session, + project_id=eval_run.project_id, + url=existing.raw_output_url, + ) + return eval_run, embeddings + + # Only embed items that succeeded in Stage 1. + embed_candidates = [r for r in response_results if not r.get("failed")] + logger.info( + f"[_stage2_embeddings] {log_prefix} Running stage 2 | " + f"items={len(embed_candidates)} | model={EMBEDDING_MODEL} | " + f"concurrency={settings.EVAL_FAST_API_CONCURRENCY}" + ) + + embedding_results: list[dict[str, Any]] = [] + max_workers = max( + 1, min(settings.EVAL_FAST_API_CONCURRENCY, len(embed_candidates) or 1) + ) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit( + _embedding_call_for_pair, + openai_client=openai_client, + embedding_model=EMBEDDING_MODEL, + item_id=r["item_id"], + output_text=r.get("generated_output", ""), + ground_truth=r.get("ground_truth", ""), + ): r["item_id"] + for r in embed_candidates + } + for future in as_completed(futures): + embedding_results.append(future.result()) + + failed_count = sum(1 for r in embedding_results if r.get("failed")) + # Failure threshold for embeddings is computed over the whole dataset, not + # over the candidate subset — items that failed Stage 1 already count as + # failures from the user's perspective. + total_failures = failed_count + sum(1 for r in response_results if r.get("failed")) + logger.info( + f"[_stage2_embeddings] {log_prefix} Stage 2 finished | " + f"total={len(embedding_results)} | failed={failed_count}" + ) + + if _is_failure_threshold_breached( + failed_rows=total_failures, total_rows=len(response_results) + ): + raise RuntimeError( + f"Fast eval Stage 2 exceeded failure threshold | " + f"failed={total_failures}/{len(response_results)} | " + f"threshold={settings.EVAL_FAST_FAILURE_THRESHOLD}" + ) + + raw_output_url = _upload_unit_to_s3( + session=session, + project_id=eval_run.project_id, + eval_run_id=eval_run.id, + filename=f"embeddings_{eval_run.id}.json", + results=embedding_results, + ) + + summed_usage = {"prompt_tokens": 0, "total_tokens": 0} + for r in embedding_results: + usage = r.get("usage") or {} + for k in summed_usage: + summed_usage[k] += int(usage.get(k, 0) or 0) + + batch_job = create_batch_job( + session=session, + batch_job_create=BatchJobCreate( + provider="openai", + job_type=JOB_TYPE_EMBEDDING_FAST, + config={ + "endpoint": "/v1/embeddings", + "run_mode": RunModeEnum.FAST.value, + "embedding_model": EMBEDDING_MODEL, + "usage": summed_usage, + }, + raw_output_url=raw_output_url, + total_items=len(embedding_results), + organization_id=eval_run.organization_id, + project_id=eval_run.project_id, + ), + ) + + eval_run = update_evaluation_run( + session=session, + eval_run=eval_run, + update=EvaluationRunUpdate(embedding_batch_job_id=batch_job.id), + ) + + return eval_run, embedding_results + + +def _stage3_score_and_trace( + *, + session: Session, + eval_run: EvaluationRun, + langfuse: Langfuse, + response_results: list[dict[str, Any]], + embedding_results: list[dict[str, Any]], + log_prefix: str, +) -> EvaluationRun: + """Stage 3 — compute cosine, create Langfuse traces, attach costs. + + No stage marker; each step is idempotent or near-idempotent: + - Cosine recomputation is deterministic (same vectors → same numbers). + - Langfuse dedupes traces on the dataset_item.observe key. + - attach_cost is idempotent on (eval_run_id, batch_job_id, usage_type). + """ + logger.info( + f"[_stage3_score_and_trace] {log_prefix} Computing cosine + creating traces" + ) + + # 1. Compute cosine similarity per pair, in the same shape as the batch + # path's `parse_embedding_results` + `calculate_average_similarity`. + item_id_to_pair = { + r["item_id"]: r for r in embedding_results if not r.get("failed") + } + + # 2. Create Langfuse traces and get item_id -> trace_id mapping. The model + # name comes from the stored config (same source as batch path) so + # Langfuse's per-generation cost calc stays consistent. + model = resolve_model_from_config(session=session, eval_run=eval_run) + trace_id_mapping = create_langfuse_dataset_run( + langfuse=langfuse, + dataset_name=eval_run.dataset_name, + run_name=eval_run.run_name, + results=response_results, + model=model, + ) + + # 3. Build the per-item score list keyed on Langfuse trace_id, then update + # traces with the cosine score. + per_item_scores: list[dict[str, Any]] = [] + similarities: list[float] = [] + for r in response_results: + item_id = r["item_id"] + pair = item_id_to_pair.get(item_id) + trace_id = trace_id_mapping.get(item_id) + if not pair or not trace_id: + continue + if ( + pair.get("output_embedding") is None + or pair.get("ground_truth_embedding") is None + ): + continue + score = calculate_cosine_similarity( + pair["output_embedding"], pair["ground_truth_embedding"] + ) + similarities.append(score) + per_item_scores.append({"trace_id": trace_id, "cosine_similarity": score}) + + if per_item_scores: + try: + update_traces_with_cosine_scores( + langfuse=langfuse, per_item_scores=per_item_scores + ) + except Exception as exc: + # Mirror the batch path: Langfuse score-update failures don't fail + # the run; they get logged and the score still lives in eval_run.score. + logger.warning( + f"[_stage3_score_and_trace] {log_prefix} " + f"Failed to update Langfuse traces with scores | error={exc}", + exc_info=True, + ) + + # 4. Aggregate similarity stats — same shape as the batch path so + # summary_scores rendering on GET /evaluations/{id} stays identical. + if similarities: + sim_array = np.array(similarities) + avg = float(np.mean(sim_array)) + std = float(np.std(sim_array)) + else: + avg = 0.0 + std = 0.0 + + score_payload = { + "summary_scores": [ + { + "name": "Cosine Similarity", + "avg": round(avg, 2), + "std": round(std, 2), + "total_pairs": len(similarities), + "data_type": "NUMERIC", + } + ] + } + + # 5. Attach costs (response stage + embedding stage). attach_cost is + # idempotent on its natural key — see "Known Limitations" in the SRD. + if response_results: + attach_cost( + session=session, + eval_run=eval_run, + log_prefix=log_prefix, + response_model=model, + response_results=response_results, + ) + + # Embedding cost expects the raw OpenAI batch shape; rebuild it from our + # in-memory embedding_results so attach_cost can price the stage uniformly. + if embedding_results: + embedding_raw = [ + { + "response": { + "body": { + "usage": r.get("usage") + or {"prompt_tokens": 0, "total_tokens": 0} + } + } + } + for r in embedding_results + if not r.get("failed") + ] + if embedding_raw: + attach_cost( + session=session, + eval_run=eval_run, + log_prefix=log_prefix, + embedding_model=EMBEDDING_MODEL, + embedding_raw_results=embedding_raw, + ) + + eval_run = update_evaluation_run( + session=session, + eval_run=eval_run, + update=EvaluationRunUpdate( + score=score_payload, + cost=eval_run.cost, + ), + ) + + return eval_run + + +def run_fast_evaluation( + *, + session: Session, + openai_client: OpenAI, + langfuse: Langfuse, + eval_run: EvaluationRun, + config: TextLLMParams, +) -> EvaluationRun: + """Run the full fast-eval pipeline for one evaluation_run. + + Called from the `run_evaluation_fast` Celery task. Each stage is skipped on + retry if its batch_job marker is already set, so re-invocation does not + re-call OpenAI for work that already succeeded. + + Raises: + Exception: If any stage fails terminally; the orchestrator marks the + run failed. + """ + log_prefix = ( + f"[org={eval_run.organization_id}]" + f"[project={eval_run.project_id}]" + f"[eval={eval_run.id}]" + ) + logger.info(f"[run_fast_evaluation] {log_prefix} Starting fast evaluation pipeline") + + # Mark as processing immediately so the GET endpoint reflects state. + if eval_run.status == "pending": + eval_run = update_evaluation_run( + session=session, + eval_run=eval_run, + update=EvaluationRunUpdate(status="processing"), + ) + + dataset_items = fetch_dataset_items( + langfuse=langfuse, dataset_name=eval_run.dataset_name + ) + if not dataset_items: + raise ValueError( + f"Dataset '{eval_run.dataset_name}' returned no items for fast eval" + ) + + # Stage 1 + eval_run, response_results = _stage1_responses( + session=session, + openai_client=openai_client, + eval_run=eval_run, + config=config, + dataset_items=dataset_items, + log_prefix=log_prefix, + ) + + # Stage 2 + eval_run, embedding_results = _stage2_embeddings( + session=session, + openai_client=openai_client, + eval_run=eval_run, + response_results=response_results, + log_prefix=log_prefix, + ) + + # Stage 3 + eval_run = _stage3_score_and_trace( + session=session, + eval_run=eval_run, + langfuse=langfuse, + response_results=response_results, + embedding_results=embedding_results, + log_prefix=log_prefix, + ) + + # Stage 4 + eval_run = update_evaluation_run( + session=session, + eval_run=eval_run, + update=EvaluationRunUpdate(status="completed"), + ) + + logger.info( + f"[run_fast_evaluation] {log_prefix} Fast evaluation completed | " + f"total_items={eval_run.total_items}" + ) + return eval_run diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index b8eec6cde..b84af7482 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -90,6 +90,7 @@ EvaluationRunCreate, EvaluationRunPublic, EvaluationRunUpdate, + RunModeEnum, ) from .feature_flag import ( FeatureFlag, diff --git a/backend/app/models/evaluation.py b/backend/app/models/evaluation.py index 8a95ff054..49a47871d 100644 --- a/backend/app/models/evaluation.py +++ b/backend/app/models/evaluation.py @@ -1,4 +1,5 @@ from datetime import datetime +from enum import Enum from typing import TYPE_CHECKING, Any, Optional from uuid import UUID @@ -16,6 +17,11 @@ from .project import Project +class RunModeEnum(str, Enum): + BATCH = "batch" + FAST = "fast" + + class DatasetItem(BaseModel): """Model for a single dataset item (Q&A pair).""" @@ -47,6 +53,10 @@ class DatasetUploadResponse(BaseModel): signed_url: str | None = Field( None, description="A signed URL for downloading the dataset" ) + eligible_for_fast: bool = Field( + False, + description="True if dataset has ≤10 unique rows and is eligible for run_mode=fast", + ) class EvaluationResult(BaseModel): @@ -192,6 +202,12 @@ class EvaluationRun(SQLModel, table=True): __table_args__ = ( Index("idx_eval_run_status_org", "status", "organization_id"), Index("idx_eval_run_status_project", "status", "project_id"), + UniqueConstraint( + "organization_id", + "project_id", + "run_name", + name="uq_evaluation_run_org_project_run_name", + ), ) id: int = SQLField( @@ -284,6 +300,12 @@ class EvaluationRun(SQLModel, table=True): "comment": "Evaluation status (pending, processing, completed, failed)" }, ) + run_mode: str = SQLField( + default="batch", + max_length=10, + nullable=False, + sa_column_kwargs={"comment": "Execution mode: batch or fast"}, + ) object_store_url: str | None = SQLField( default=None, description="Object store URL of processed evaluation results for future reference", @@ -421,6 +443,7 @@ class EvaluationRunPublic(SQLModel): batch_job_id: int | None embedding_batch_job_id: int | None status: str + run_mode: str object_store_url: str | None score_trace_url: str | None total_items: int diff --git a/backend/app/services/evaluations/__init__.py b/backend/app/services/evaluations/__init__.py index 92d88fe0b..3776e60c4 100644 --- a/backend/app/services/evaluations/__init__.py +++ b/backend/app/services/evaluations/__init__.py @@ -5,6 +5,11 @@ get_evaluation_with_scores, start_evaluation, ) +from app.services.evaluations.fast import ( + execute_fast_evaluation, + is_dataset_fast_eligible, + validate_and_start_fast_evaluation, +) from app.services.evaluations.validators import ( ALLOWED_EXTENSIONS, ALLOWED_MIME_TYPES, diff --git a/backend/app/services/evaluations/evaluation.py b/backend/app/services/evaluations/evaluation.py index 9f78b9e37..5e4abd0dc 100644 --- a/backend/app/services/evaluations/evaluation.py +++ b/backend/app/services/evaluations/evaluation.py @@ -4,6 +4,7 @@ from uuid import UUID from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError from sqlmodel import Session from app.crud.evaluations import ( @@ -129,16 +130,31 @@ def start_evaluation( ) # Step 3: Create EvaluationRun record with config references - eval_run = create_evaluation_run( - session=session, - run_name=experiment_name, - dataset_name=dataset.name, - dataset_id=dataset_id, - config_id=config_id, - config_version=config_version, - organization_id=organization_id, - project_id=project_id, - ) + try: + eval_run = create_evaluation_run( + session=session, + run_name=experiment_name, + dataset_name=dataset.name, + dataset_id=dataset_id, + config_id=config_id, + config_version=config_version, + organization_id=organization_id, + project_id=project_id, + ) + except IntegrityError: + session.rollback() + logger.warning( + f"[start_evaluation] Duplicate run_name | run_name={experiment_name} | " + f"org_id={organization_id} | project_id={project_id}" + ) + raise HTTPException( + status_code=409, + detail=( + f"run_name_already_exists: a run with name '{experiment_name}' " + "already exists for this organization and project. Pick a new " + "run_name or fetch the existing run via GET /evaluations." + ), + ) # Step 4: Start the batch evaluation try: diff --git a/backend/app/services/evaluations/fast.py b/backend/app/services/evaluations/fast.py new file mode 100644 index 000000000..d3e316b4b --- /dev/null +++ b/backend/app/services/evaluations/fast.py @@ -0,0 +1,289 @@ +"""Fast evaluation orchestration service. + +This is the only place that decides whether a /evaluations request enters the +fast-eval path. It also hosts the worker-side entry point invoked by the +`run_evaluation_fast` Celery task. + +See `Fast Evaluation SRD.md` for the full design. +""" + +import logging +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session + +from app.celery.utils import start_fast_evaluation as enqueue_fast_evaluation +from app.core.config import settings +from app.core.db import engine +from app.crud.evaluations import ( + create_evaluation_run, + get_dataset_by_id, + resolve_evaluation_config, + run_fast_evaluation, +) +from app.crud.evaluations.core import update_evaluation_run +from app.models.evaluation import EvaluationRun, EvaluationRunUpdate, RunModeEnum +from app.models.llm.request import TextLLMParams +from app.services.llm.providers import LLMProvider +from app.utils import get_langfuse_client, get_openai_client + +logger = logging.getLogger(__name__) + + +# Error codes surfaced in HTTPException.detail so the UI can localize/branch. +ERR_CONFIG_TYPE_UNSUPPORTED = "config_type_unsupported" +ERR_DATASET_TOO_LARGE_FOR_FAST = "dataset_too_large_for_fast" +ERR_RUN_NAME_ALREADY_EXISTS = "run_name_already_exists" + + +def is_dataset_fast_eligible(*, original_items_count: int) -> bool: + """A dataset is eligible for fast mode when its unique-row count is within cap.""" + return original_items_count <= settings.EVAL_FAST_MAX_UNIQUE_ROWS + + +def validate_and_start_fast_evaluation( + *, + session: Session, + dataset_id: int, + run_name: str, + config_id: UUID, + config_version: int, + organization_id: int, + project_id: int, + trace_id: str = "N/A", +) -> EvaluationRun: + """Validate + create + dispatch a fast evaluation run. + + Validation (in order): + 1. Dataset exists and has a Langfuse id. + 2. Config resolves to a text-type OpenAI config. + 3. Dataset's original_items_count <= EVAL_FAST_MAX_UNIQUE_ROWS. + 4. (organization_id, project_id, run_name) is unique — enforced by the DB + constraint added in migration 063. We translate IntegrityError to 409. + + On success the function creates the EvaluationRun row with + `run_mode="fast"`, `status="processing"`, and enqueues the orchestrator + task. The caller (route) returns the row immediately. + """ + logger.info( + f"[validate_and_start_fast_evaluation] Starting fast eval | " + f"run_name={run_name} | dataset_id={dataset_id} | " + f"org_id={organization_id} | project_id={project_id}" + ) + + # 1. Dataset must exist + have a Langfuse id (same as batch path). + dataset = get_dataset_by_id( + session=session, + dataset_id=dataset_id, + organization_id=organization_id, + project_id=project_id, + ) + if not dataset: + raise HTTPException( + status_code=404, + detail=( + f"Dataset {dataset_id} not found or not accessible to this " + "organization/project" + ), + ) + if not dataset.langfuse_dataset_id: + raise HTTPException( + status_code=400, + detail=( + f"Dataset {dataset_id} has no Langfuse dataset id; cannot run " + "evaluation." + ), + ) + + # 2. Config must resolve and be a text OpenAI config. + config_blob, error = resolve_evaluation_config( + session=session, + config_id=config_id, + config_version=config_version, + project_id=project_id, + ) + if error or config_blob is None: + raise HTTPException( + status_code=400, + detail=f"Failed to resolve config: {error}", + ) + if config_blob.completion.provider != LLMProvider.OPENAI: + raise HTTPException( + status_code=422, + detail="Only 'openai' provider is supported for evaluation configs", + ) + if config_blob.completion.type != "text": + raise HTTPException( + status_code=422, + detail=ERR_CONFIG_TYPE_UNSUPPORTED, + ) + + # 3. Dataset must be small enough for fast eval. + original_items_count = (dataset.dataset_metadata or {}).get( + "original_items_count" + ) or 0 + if not is_dataset_fast_eligible(original_items_count=original_items_count): + raise HTTPException( + status_code=422, + detail=( + f"{ERR_DATASET_TOO_LARGE_FOR_FAST}: dataset has " + f"{original_items_count} unique rows; fast mode requires at most " + f"{settings.EVAL_FAST_MAX_UNIQUE_ROWS}." + ), + ) + + # 4. Create the run; rely on the DB unique constraint to catch double-clicks. + try: + eval_run = create_evaluation_run( + session=session, + run_name=run_name, + dataset_name=dataset.name, + dataset_id=dataset_id, + config_id=config_id, + config_version=config_version, + organization_id=organization_id, + project_id=project_id, + run_mode=RunModeEnum.FAST.value, + ) + except IntegrityError: + session.rollback() + logger.warning( + f"[validate_and_start_fast_evaluation] Duplicate run_name | " + f"run_name={run_name} | org_id={organization_id} | " + f"project_id={project_id}" + ) + raise HTTPException( + status_code=409, + detail=( + f"{ERR_RUN_NAME_ALREADY_EXISTS}: a run with name '{run_name}' " + "already exists for this organization and project. Pick a new " + "run_name or fetch the existing run via GET /evaluations." + ), + ) + + # Flip to processing before dispatching the task so the GET endpoint + # reflects the correct state immediately. + eval_run = update_evaluation_run( + session=session, + eval_run=eval_run, + update=EvaluationRunUpdate(status="processing"), + ) + + # Dispatch the orchestrator. If enqueue fails, mark the run as failed so it + # doesn't linger in `processing` forever. + try: + enqueue_fast_evaluation(eval_run_id=eval_run.id, trace_id=trace_id) + except Exception as exc: + logger.error( + f"[validate_and_start_fast_evaluation] Failed to enqueue task | " + f"eval_run_id={eval_run.id} | error={exc}", + exc_info=True, + ) + update_evaluation_run( + session=session, + eval_run=eval_run, + update=EvaluationRunUpdate( + status="failed", + error_message=f"Failed to enqueue fast eval task: {exc}", + ), + ) + raise HTTPException( + status_code=500, + detail="Failed to enqueue fast evaluation task", + ) + + return eval_run + + +def execute_fast_evaluation(*, eval_run_id: int) -> None: + """Worker-side entry point: run the full fast-eval pipeline. + + Called from the `run_evaluation_fast` Celery task. Opens its own DB + session so the task is self-contained, then resolves config + clients and + delegates to `run_fast_evaluation` (CRUD). + + On terminal failure the run is marked `failed` with a descriptive + error_message and the exception is re-raised so Celery records the failure + (no automatic retry for this task — stage-level idempotency is the retry + surface). + """ + logger.info(f"[execute_fast_evaluation] Starting | eval_run_id={eval_run_id}") + + with Session(engine) as session: + eval_run = session.get(EvaluationRun, eval_run_id) + if eval_run is None: + logger.error( + f"[execute_fast_evaluation] EvaluationRun not found | " + f"eval_run_id={eval_run_id}" + ) + raise ValueError(f"EvaluationRun {eval_run_id} not found") + + if eval_run.run_mode != RunModeEnum.FAST.value: + logger.error( + f"[execute_fast_evaluation] Wrong run_mode for fast task | " + f"eval_run_id={eval_run_id} | run_mode={eval_run.run_mode}" + ) + raise ValueError( + f"EvaluationRun {eval_run_id} has run_mode={eval_run.run_mode}, " + f"expected 'fast'" + ) + + if eval_run.status == "completed": + logger.info( + f"[execute_fast_evaluation] Run already completed, skipping | " + f"eval_run_id={eval_run_id}" + ) + return + + try: + config_blob, error = resolve_evaluation_config( + session=session, + config_id=eval_run.config_id, + config_version=eval_run.config_version, + project_id=eval_run.project_id, + ) + if error or config_blob is None: + raise ValueError(f"Failed to resolve config: {error}") + + text_params = TextLLMParams.model_validate(config_blob.completion.params) + + openai_client = get_openai_client( + session=session, + org_id=eval_run.organization_id, + project_id=eval_run.project_id, + ) + langfuse_client = get_langfuse_client( + session=session, + org_id=eval_run.organization_id, + project_id=eval_run.project_id, + ) + + run_fast_evaluation( + session=session, + openai_client=openai_client, + langfuse=langfuse_client, + eval_run=eval_run, + config=text_params, + ) + + except Exception as exc: + logger.error( + f"[execute_fast_evaluation] Run failed | " + f"eval_run_id={eval_run_id} | error={exc}", + exc_info=True, + ) + # Re-fetch the row in case our session was rolled back. + session.rollback() + failed_run = session.get(EvaluationRun, eval_run_id) + if failed_run is not None: + update_evaluation_run( + session=session, + eval_run=failed_run, + update=EvaluationRunUpdate( + status="failed", + error_message=f"Fast eval failed: {exc}", + ), + ) + raise diff --git a/backend/app/tests/api/routes/test_evaluation_fast.py b/backend/app/tests/api/routes/test_evaluation_fast.py new file mode 100644 index 000000000..b5b0ffb28 --- /dev/null +++ b/backend/app/tests/api/routes/test_evaluation_fast.py @@ -0,0 +1,729 @@ +"""Tests for the fast (synchronous) evaluation path. + +Covers FR-1 through FR-15 from the Fast Evaluation SRD. External boundaries +(OpenAI, Langfuse, S3, Celery dispatch) are mocked; the DB is real (`db` +fixture). +""" + +import random +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import openai +import pytest +from fastapi.testclient import TestClient +from sqlmodel import Session + +from app.crud.evaluations.fast import ( + JOB_TYPE_EMBEDDING_FAST, + JOB_TYPE_EVALUATION_FAST, + _call_with_retry, + _is_failure_threshold_breached, + _stage1_responses, + _stage2_embeddings, + run_fast_evaluation, +) +from app.models import EvaluationRun +from app.models.batch_job import BatchJob +from app.models.evaluation import RunModeEnum +from app.models.llm.request import ( + ConfigBlob, + KaapiCompletionConfig, + TextLLMParams, +) +from app.tests.utils.auth import TestAuthContext +from app.tests.utils.test_data import ( + create_test_config, + create_test_evaluation_dataset, +) + + +def _api_error(resp_body: dict) -> str: + """Pull the human-readable error string out of an APIResponse failure body.""" + return str(resp_body.get("error") or resp_body.get("detail") or resp_body).lower() + + +@pytest.fixture(autouse=True) +def _seeded_random(): + """Make jitter / random.choice deterministic so tests are repeatable.""" + random.seed(0) + yield + + +# --------------------------------------------------------------------------- +# Pure-function helpers (no FastAPI, no DB) +# --------------------------------------------------------------------------- + + +class TestFailureThreshold: + """`_is_failure_threshold_breached` controls run-level fail-fast.""" + + def test_returns_false_when_total_is_zero(self): + assert _is_failure_threshold_breached(failed_rows=0, total_rows=0) is False + + def test_returns_true_above_threshold(self): + # default EVAL_FAST_FAILURE_THRESHOLD = 0.5 + assert _is_failure_threshold_breached(failed_rows=6, total_rows=10) is True + + def test_returns_false_at_threshold(self): + # 0.5 / 1.0 is NOT greater-than the threshold, so do not breach + assert _is_failure_threshold_breached(failed_rows=5, total_rows=10) is False + + +class TestCallWithRetry: + """FR-8: transient OpenAI errors retry; permanent ones do not.""" + + def test_returns_immediately_on_success(self): + result = _call_with_retry(label="t", fn=lambda: "ok") + assert result == "ok" + + def test_retries_on_transient_then_succeeds(self, monkeypatch): + # Avoid sleeping — make backoff a no-op. + monkeypatch.setattr("app.crud.evaluations.fast.time.sleep", lambda *_: None) + + calls = {"n": 0} + + def flaky(): + calls["n"] += 1 + if calls["n"] < 3: + # APIConnectionError needs a request; pass a minimal object. + raise openai.APIConnectionError(request=MagicMock()) + return "ok" + + result = _call_with_retry(label="t", fn=flaky) + assert result == "ok" + assert calls["n"] == 3 + + def test_does_not_retry_on_permanent_error(self): + def bad(): + # AuthenticationError is a non-retryable OpenAIError subclass. + raise openai.AuthenticationError( + message="bad key", response=MagicMock(), body=None + ) + + with pytest.raises(openai.AuthenticationError): + _call_with_retry(label="t", fn=bad) + + +# --------------------------------------------------------------------------- +# Route validation: POST /evaluations with run_mode=fast (FR-1..FR-5, FR-15) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def _patch_dispatch(): + """Stub the Celery dispatch so tests don't actually enqueue work.""" + with patch( + "app.services.evaluations.fast.enqueue_fast_evaluation", + return_value="fake-task-id", + ) as m: + yield m + + +def _make_fast_eligible_dataset( + *, + db: Session, + user_api_key: TestAuthContext, + original_items_count: int = 3, +): + return create_test_evaluation_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + original_items_count=original_items_count, + duplication_factor=1, + ) + + +def _make_text_openai_config(db: Session, project_id: int): + """Create a stored text-OpenAI Kaapi config (eligible for fast eval).""" + return create_test_config( + db=db, + project_id=project_id, + use_kaapi_schema=True, + ) + + +class TestFastEvaluationRoute: + """End-to-end validation on POST /evaluations with run_mode='fast'.""" + + def test_fr4_accepts_eligible_request_and_dispatches( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + _patch_dispatch, + ): + """FR-4: eligible request returns processing + dispatches orchestrator.""" + dataset = _make_fast_eligible_dataset(db=db, user_api_key=user_api_key) + config = _make_text_openai_config(db, user_api_key.project_id) + + resp = client.post( + "/api/v1/evaluations", + json={ + "experiment_name": "fr4-fast-run", + "dataset_id": dataset.id, + "config_id": str(config.id), + "config_version": 1, + "run_mode": "fast", + }, + headers=user_api_key_header, + ) + + assert resp.status_code == 200, resp.text + body = resp.json()["data"] + assert body["run_mode"] == "fast" + assert body["status"] == "processing" + assert body["run_name"] == "fr4-fast-run" + _patch_dispatch.assert_called_once() + + # DB state matches the response. + run = db.get(EvaluationRun, body["id"]) + assert run is not None + assert run.run_mode == RunModeEnum.FAST.value + assert run.status == "processing" + + def test_fr2_rejects_dataset_with_too_many_unique_rows( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + _patch_dispatch, + ): + """FR-2: >10 unique rows → 422 dataset_too_large_for_fast.""" + # default EVAL_FAST_MAX_UNIQUE_ROWS = 10; create 11 unique rows + dataset = _make_fast_eligible_dataset( + db=db, user_api_key=user_api_key, original_items_count=11 + ) + config = _make_text_openai_config(db, user_api_key.project_id) + + resp = client.post( + "/api/v1/evaluations", + json={ + "experiment_name": "fr2-fast-run", + "dataset_id": dataset.id, + "config_id": str(config.id), + "config_version": 1, + "run_mode": "fast", + }, + headers=user_api_key_header, + ) + + assert resp.status_code == 422 + # The route wraps HTTPException.detail into APIResponse.error. + error_str = _api_error(resp.json()) + assert "dataset_too_large_for_fast" in error_str + assert "11" in error_str # surfaces actual unique-row count + _patch_dispatch.assert_not_called() + + def test_fr1_rejects_non_text_config( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + _patch_dispatch, + ): + """FR-1: non-text config for fast mode → 422 config_type_unsupported. + + Build a stored config whose completion.type is not 'text'. The current + config factories produce text configs by default, so we patch the + resolved blob to look like an STT config. + """ + dataset = _make_fast_eligible_dataset(db=db, user_api_key=user_api_key) + config = _make_text_openai_config(db, user_api_key.project_id) + + # Patch resolve_evaluation_config to return an STT-type blob. + fake_blob = ConfigBlob( + completion=KaapiCompletionConfig( + provider="openai", + type="stt", + params={"model": "whisper-1"}, + ) + ) + + with patch( + "app.services.evaluations.fast.resolve_evaluation_config", + return_value=(fake_blob, None), + ): + resp = client.post( + "/api/v1/evaluations", + json={ + "experiment_name": "fr1-fast-run", + "dataset_id": dataset.id, + "config_id": str(config.id), + "config_version": 1, + "run_mode": "fast", + }, + headers=user_api_key_header, + ) + + assert resp.status_code == 422 + assert "config_type_unsupported" in _api_error(resp.json()) + _patch_dispatch.assert_not_called() + + def test_fr3_rejects_duplicate_run_name( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + _patch_dispatch, + ): + """FR-3: duplicate (org, project, run_name) → 409, no second dispatch.""" + dataset = _make_fast_eligible_dataset(db=db, user_api_key=user_api_key) + config = _make_text_openai_config(db, user_api_key.project_id) + payload = { + "experiment_name": "fr3-dup-run", + "dataset_id": dataset.id, + "config_id": str(config.id), + "config_version": 1, + "run_mode": "fast", + } + + first = client.post( + "/api/v1/evaluations", json=payload, headers=user_api_key_header + ) + assert first.status_code == 200, first.text + + second = client.post( + "/api/v1/evaluations", json=payload, headers=user_api_key_header + ) + assert second.status_code == 409 + assert "run_name_already_exists" in _api_error(second.json()) + # First call dispatched; second must not have. + assert _patch_dispatch.call_count == 1 + + def test_fr15_get_evaluation_returns_run_mode( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ): + """FR-15: GET /evaluations/{id} surfaces `run_mode` for both modes.""" + dataset = _make_fast_eligible_dataset(db=db, user_api_key=user_api_key) + config = _make_text_openai_config(db, user_api_key.project_id) + eval_run = EvaluationRun( + run_name="fr15-existing", + dataset_name=dataset.name, + dataset_id=dataset.id, + config_id=config.id, + config_version=1, + status="completed", + run_mode=RunModeEnum.FAST.value, + total_items=3, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + resp = client.get( + f"/api/v1/evaluations/{eval_run.id}", headers=user_api_key_header + ) + assert resp.status_code == 200 + assert resp.json()["data"]["run_mode"] == "fast" + + +# --------------------------------------------------------------------------- +# Dataset listing eligibility filter (FR-5) +# --------------------------------------------------------------------------- + + +class TestDatasetListEligibleForFast: + def test_fr5_filters_to_fast_eligible_only( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ): + """FR-5: eligible_for=fast filters list to datasets with ≤10 unique rows.""" + eligible = _make_fast_eligible_dataset( + db=db, user_api_key=user_api_key, original_items_count=5 + ) + ineligible = _make_fast_eligible_dataset( + db=db, user_api_key=user_api_key, original_items_count=20 + ) + + resp = client.get( + "/api/v1/evaluations/datasets", + params={"eligible_for": "fast"}, + headers=user_api_key_header, + ) + assert resp.status_code == 200 + data = resp.json()["data"] + ids = {d["dataset_id"] for d in data} + assert eligible.id in ids + assert ineligible.id not in ids + assert all(d["eligible_for_fast"] is True for d in data) + + +# --------------------------------------------------------------------------- +# Stage skipping on retry (FR-6, FR-7) +# --------------------------------------------------------------------------- + + +def _fake_openai_response(text: str = "answer", item_id: str = "item-1"): + """Mimic the SDK's response.responses.create return shape.""" + return SimpleNamespace( + id=f"resp_{item_id}", + output_text=text, + output=[], + usage=SimpleNamespace(input_tokens=10, output_tokens=20, total_tokens=30), + ) + + +def _fake_embedding_response(): + """Mimic openai.embeddings.create return shape (2 vectors).""" + return SimpleNamespace( + data=[ + SimpleNamespace(index=0, embedding=[1.0, 0.0, 0.0]), + SimpleNamespace(index=1, embedding=[1.0, 0.0, 0.0]), + ], + usage=SimpleNamespace(prompt_tokens=5, total_tokens=5), + ) + + +class TestStageSkipping: + """FR-6 / FR-7: stages skip on retry when their batch_job marker is set.""" + + def test_fr6_stage1_skips_when_batch_job_id_already_set( + self, + db: Session, + user_api_key: TestAuthContext, + ): + """Pre-existing batch_job → Stage 1 does not call the Responses API.""" + dataset = _make_fast_eligible_dataset(db=db, user_api_key=user_api_key) + config = _make_text_openai_config(db, user_api_key.project_id) + eval_run = EvaluationRun( + run_name="fr6-stage1-skip", + dataset_name=dataset.name, + dataset_id=dataset.id, + config_id=config.id, + config_version=1, + status="processing", + run_mode=RunModeEnum.FAST.value, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + # Pre-create a batch_job marker as if Stage 1 had already completed. + existing = BatchJob( + provider="openai", + job_type=JOB_TYPE_EVALUATION_FAST, + config={}, + raw_output_url="s3://bucket/responses_x.json", + total_items=1, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + db.add(existing) + db.commit() + db.refresh(existing) + + eval_run.batch_job_id = existing.id + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + cached = [ + { + "item_id": "item-1", + "question": "q", + "generated_output": "a", + "ground_truth": "a", + "response_id": "resp_1", + "usage": {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + "question_id": 1, + "failed": False, + } + ] + + fake_openai = MagicMock() + with patch("app.crud.evaluations.fast._load_unit_from_s3", return_value=cached): + _, results = _stage1_responses( + session=db, + openai_client=fake_openai, + eval_run=eval_run, + config=TextLLMParams(model="gpt-4o", instructions="x"), + dataset_items=[ + { + "id": "item-1", + "input": {"question": "q"}, + "expected_output": {"answer": "a"}, + "metadata": {}, + } + ], + log_prefix="[t]", + ) + + # Skip path returns the cached unit and never calls the OpenAI client. + assert results == cached + fake_openai.responses.create.assert_not_called() + + def test_fr7_stage2_skips_when_embedding_batch_job_id_already_set( + self, + db: Session, + user_api_key: TestAuthContext, + ): + """Pre-existing embedding batch_job → Stage 2 does not call the Embeddings API.""" + dataset = _make_fast_eligible_dataset(db=db, user_api_key=user_api_key) + config = _make_text_openai_config(db, user_api_key.project_id) + eval_run = EvaluationRun( + run_name="fr7-stage2-skip", + dataset_name=dataset.name, + dataset_id=dataset.id, + config_id=config.id, + config_version=1, + status="processing", + run_mode=RunModeEnum.FAST.value, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + marker = BatchJob( + provider="openai", + job_type=JOB_TYPE_EMBEDDING_FAST, + config={}, + raw_output_url="s3://bucket/embeddings_x.json", + total_items=1, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + db.add(marker) + db.commit() + db.refresh(marker) + + eval_run.embedding_batch_job_id = marker.id + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + cached = [ + { + "item_id": "item-1", + "output_embedding": [1.0, 0.0], + "ground_truth_embedding": [1.0, 0.0], + "usage": {"prompt_tokens": 5, "total_tokens": 5}, + "failed": False, + } + ] + + fake_openai = MagicMock() + with patch("app.crud.evaluations.fast._load_unit_from_s3", return_value=cached): + _, results = _stage2_embeddings( + session=db, + openai_client=fake_openai, + eval_run=eval_run, + response_results=[ + { + "item_id": "item-1", + "question": "q", + "generated_output": "a", + "ground_truth": "a", + "failed": False, + } + ], + log_prefix="[t]", + ) + + assert results == cached + fake_openai.embeddings.create.assert_not_called() + + +# --------------------------------------------------------------------------- +# End-to-end orchestrator pipeline with mocked externals (FR-9..FR-14) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def _fast_pipeline_mocks(): + """Patch external boundaries used inside `run_fast_evaluation`. + + OpenAI client returns fixed responses/embeddings, Langfuse returns trace + ids, S3 upload returns a URL, and `attach_cost` no-ops so the test does + not need a model_config row. + """ + with ( + patch("app.crud.evaluations.fast.fetch_dataset_items") as mock_fetch_items, + patch( + "app.crud.evaluations.fast._upload_unit_to_s3", + side_effect=lambda **kw: f"s3://bucket/{kw['filename']}", + ), + patch( + "app.crud.evaluations.fast.create_langfuse_dataset_run" + ) as mock_create_traces, + patch( + "app.crud.evaluations.fast.update_traces_with_cosine_scores" + ) as mock_update_traces, + patch( + "app.crud.evaluations.fast.resolve_model_from_config", return_value="gpt-4o" + ), + patch("app.crud.evaluations.fast.attach_cost") as mock_attach_cost, + ): + mock_fetch_items.return_value = [ + { + "id": "item-1", + "input": {"question": "Q1"}, + "expected_output": {"answer": "A1"}, + "metadata": {"question_id": 1}, + }, + { + "id": "item-2", + "input": {"question": "Q2"}, + "expected_output": {"answer": "A2"}, + "metadata": {"question_id": 2}, + }, + ] + mock_create_traces.return_value = { + "item-1": "trace-1", + "item-2": "trace-2", + } + yield SimpleNamespace( + fetch_items=mock_fetch_items, + create_traces=mock_create_traces, + update_traces=mock_update_traces, + attach_cost=mock_attach_cost, + ) + + +class TestFastPipelineEndToEnd: + """Run the orchestrator with mocked external boundaries (FR-9..FR-13).""" + + def test_pipeline_completes_with_scores_and_writes_batch_jobs( + self, + db: Session, + user_api_key: TestAuthContext, + _fast_pipeline_mocks, + ): + dataset = _make_fast_eligible_dataset(db=db, user_api_key=user_api_key) + config = _make_text_openai_config(db, user_api_key.project_id) + eval_run = EvaluationRun( + run_name="pipeline-happy-path", + dataset_name=dataset.name, + dataset_id=dataset.id, + config_id=config.id, + config_version=1, + status="pending", + run_mode=RunModeEnum.FAST.value, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + fake_openai = MagicMock() + fake_openai.responses.create.side_effect = ( + lambda **kwargs: _fake_openai_response(text="LLM answer", item_id="x") + ) + fake_openai.embeddings.create.return_value = _fake_embedding_response() + fake_langfuse = MagicMock() + + result = run_fast_evaluation( + session=db, + openai_client=fake_openai, + langfuse=fake_langfuse, + eval_run=eval_run, + config=TextLLMParams(model="gpt-4o", instructions="be helpful"), + ) + + # FR-11/FR-14: status completed, summary cosine ≈ 1.0 for identical vectors + assert result.status == "completed" + assert result.score is not None + cosine = result.score["summary_scores"][0] + assert cosine["name"] == "Cosine Similarity" + assert cosine["avg"] == pytest.approx(1.0, abs=0.01) + assert cosine["total_pairs"] == 2 + + # Stage markers exist (FR-6/FR-7 invariant + FR-9 — no llm_call rows). + assert result.batch_job_id is not None + assert result.embedding_batch_job_id is not None + responses_job = db.get(BatchJob, result.batch_job_id) + embeddings_job = db.get(BatchJob, result.embedding_batch_job_id) + assert responses_job.job_type == JOB_TYPE_EVALUATION_FAST + assert responses_job.raw_output_url is not None + assert embeddings_job.job_type == JOB_TYPE_EMBEDDING_FAST + assert embeddings_job.raw_output_url is not None + + # FR-12: Langfuse traces created and per-trace scores attached. + assert _fast_pipeline_mocks.create_traces.called + assert _fast_pipeline_mocks.update_traces.called + + # FR-13: attach_cost called twice (response + embedding stages). + assert _fast_pipeline_mocks.attach_cost.call_count == 2 + + +# --------------------------------------------------------------------------- +# Failure-threshold short-circuit (FR-14) +# --------------------------------------------------------------------------- + + +class TestFailureThresholdInPipeline: + def test_fr14_stage1_raises_when_failure_ratio_exceeds_threshold( + self, + db: Session, + user_api_key: TestAuthContext, + ): + """FR-14: Stage 1 raises RuntimeError when failure ratio > threshold. + + The outer orchestrator (run_fast_evaluation / execute_fast_evaluation) + catches the RuntimeError and marks the run failed; the structural + guarantee under test is that Stage 1 fails fast instead of proceeding + to Stage 2 with a mostly-broken response set. + """ + dataset = _make_fast_eligible_dataset(db=db, user_api_key=user_api_key) + config = _make_text_openai_config(db, user_api_key.project_id) + eval_run = EvaluationRun( + run_name="fr14-fail-threshold", + dataset_name=dataset.name, + dataset_id=dataset.id, + config_id=config.id, + config_version=1, + status="processing", + run_mode=RunModeEnum.FAST.value, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + db.add(eval_run) + db.commit() + db.refresh(eval_run) + + # Make all OpenAI Responses calls fail with a permanent error so retries + # short-circuit and every item gets failed=True. With every item + # failing, the failure fraction (1.0) is well above the 0.5 threshold. + fake_openai = MagicMock() + fake_openai.responses.create.side_effect = openai.AuthenticationError( + message="bad key", response=MagicMock(), body=None + ) + + dataset_items = [ + { + "id": f"item-{i}", + "input": {"question": f"Q{i}"}, + "expected_output": {"answer": f"A{i}"}, + "metadata": {}, + } + for i in range(4) + ] + + with pytest.raises(RuntimeError, match="failure threshold"): + _stage1_responses( + session=db, + openai_client=fake_openai, + eval_run=eval_run, + config=TextLLMParams(model="gpt-4o", instructions="x"), + dataset_items=dataset_items, + log_prefix="[t]", + )