diff --git a/backend/README.md b/backend/README.md index ed3d0fd..ce07174 100644 --- a/backend/README.md +++ b/backend/README.md @@ -193,7 +193,7 @@ Set the resulting digest as `AUTH_TOKEN` in your `.env` / `.env.test`. ## Multi-tenant API Key Configuration -Ban List and Topic Relevance Config APIs use `X-API-KEY` auth instead of bearer token auth. +Ban List and LLM Prompt Config APIs use `X-API-KEY` auth instead of bearer token auth. Required environment variables: - `KAAPI_AUTH_URL`: Base URL of the Kaapi auth service used to verify API keys. @@ -203,7 +203,7 @@ At runtime, the backend calls: - `GET {KAAPI_AUTH_URL}/apikeys/verify` - Header: `X-API-KEY: ` -If verification succeeds, tenant's scope (`organization_id`, `project_id`) is resolved from the auth response and applied to tenant-scoped CRUD operations (for example Ban Lists and Topic Relevance Configs). +If verification succeeds, tenant's scope (`organization_id`, `project_id`) is resolved from the auth response and applied to tenant-scoped CRUD operations (for example Ban Lists and LLM Prompt Configs). ## Guardrails AI Setup diff --git a/backend/app/alembic/versions/008_added_llm_validator_prompt.py b/backend/app/alembic/versions/008_added_llm_validator_prompt.py new file mode 100644 index 0000000..7fd2b61 --- /dev/null +++ b/backend/app/alembic/versions/008_added_llm_validator_prompt.py @@ -0,0 +1,104 @@ +"""Added llm_validator_prompt: rename topic_relevance to llm_prompt, add validator_name, rename configuration to llm_prompt + +Revision ID: 008 +Revises: 007 +Create Date: 2026-05-08 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +revision: str = "008" +down_revision = "007" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Rename table + op.rename_table("topic_relevance", "llm_prompt") + + # Rename indexes created by migration 006 + op.execute( + "ALTER INDEX idx_topic_relevance_organization RENAME TO idx_llm_prompt_organization" + ) + op.execute( + "ALTER INDEX idx_topic_relevance_project RENAME TO idx_llm_prompt_project" + ) + op.execute( + "ALTER INDEX idx_topic_relevance_prompt_schema_version " + "RENAME TO idx_llm_prompt_prompt_schema_version" + ) + op.execute( + "ALTER INDEX idx_topic_relevance_is_active RENAME TO idx_llm_prompt_is_active" + ) + + # Add validator_name column (server_default backfills existing rows as topic_relevance) + op.add_column( + "llm_prompt", + sa.Column( + "validator_name", + sa.String(), + nullable=False, + server_default="topic_relevance", + ), + ) + # Drop server_default so future rows must supply validator_name explicitly + op.alter_column("llm_prompt", "validator_name", server_default=None) + + # Rename configuration → llm_prompt column + op.alter_column("llm_prompt", "configuration", new_column_name="llm_prompt") + + # Replace unique constraint to include validator_name and use new column name + op.drop_constraint( + "uq_topic_relevance_config_org_project_prompt", + "llm_prompt", + type_="unique", + ) + op.create_unique_constraint( + "uq_llm_prompt_config", + "llm_prompt", + [ + "organization_id", + "project_id", + "validator_name", + "prompt_schema_version", + "llm_prompt", + ], + ) + + op.create_index("idx_llm_prompt_validator_name", "llm_prompt", ["validator_name"]) + + +def downgrade() -> None: + op.drop_index("idx_llm_prompt_validator_name", table_name="llm_prompt") + + op.drop_constraint("uq_llm_prompt_config", "llm_prompt", type_="unique") + op.create_unique_constraint( + "uq_topic_relevance_config_org_project_prompt", + "llm_prompt", + ["organization_id", "project_id", "prompt_schema_version", "llm_prompt"], + ) + + op.alter_column("llm_prompt", "llm_prompt", new_column_name="configuration") + + op.drop_column("llm_prompt", "validator_name") + + op.execute( + "ALTER INDEX idx_llm_prompt_is_active RENAME TO idx_topic_relevance_is_active" + ) + op.execute( + "ALTER INDEX idx_llm_prompt_prompt_schema_version " + "RENAME TO idx_topic_relevance_prompt_schema_version" + ) + op.execute( + "ALTER INDEX idx_llm_prompt_project RENAME TO idx_topic_relevance_project" + ) + op.execute( + "ALTER INDEX idx_llm_prompt_organization RENAME TO idx_topic_relevance_organization" + ) + + op.rename_table("llm_prompt", "topic_relevance") diff --git a/backend/app/alembic/versions/009_add_output_text_to_request_log.py b/backend/app/alembic/versions/009_add_output_text_to_request_log.py new file mode 100644 index 0000000..ff1cfd0 --- /dev/null +++ b/backend/app/alembic/versions/009_add_output_text_to_request_log.py @@ -0,0 +1,28 @@ +"""Add output_text to request_log + +Revision ID: 009 +Revises: 008 +Create Date: 2026-05-21 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +revision: str = "009" +down_revision = "008" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "request_log", + sa.Column("output_text", sa.String(), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("request_log", "output_text") diff --git a/backend/app/api/API_USAGE.md b/backend/app/api/API_USAGE.md index 55392b0..af2b82a 100644 --- a/backend/app/api/API_USAGE.md +++ b/backend/app/api/API_USAGE.md @@ -6,7 +6,7 @@ This guide explains how to use the current API surface for: - Runtime validator discovery - Guardrail execution - Ban list CRUD for multi-tenant projects -- Topic relevance config CRUD for multi-tenant projects +- LLM prompt config CRUD for multi-tenant projects (`topic_relevance` and `answer_relevance_custom_llm`) ## Base URL and Version @@ -24,7 +24,7 @@ This API currently uses two auth modes: - Used by validator config and guardrails endpoints. - The server validates your plaintext bearer token against a SHA-256 digest stored in `AUTH_TOKEN`. 2. multi-tenant API key auth (`X-API-KEY: `) - - Used by ban list and topic relevance config endpoints. + - Used by ban list and LLM prompt config endpoints. - The API key is verified against `KAAPI_AUTH_URL` and resolves tenant's scope (`organization_id`, `project_id`). Notes: @@ -184,8 +184,8 @@ Request fields: Important: - Runtime validators use `on_fail`. - If you pass objects from config APIs, server normalization supports `on_fail_action` and strips non-runtime fields. -- For `topic_relevance`, pass `topic_relevance_config_id` only. -- The API resolves `configuration` + `prompt_schema_version` in `guardrails.py` before validator execution, so the validator always executes with both values. +- For `topic_relevance`, pass `topic_relevance_config_id` only. The API resolves `configuration` + `prompt_schema_version` in `guardrails.py` before validator execution. +- For `answer_relevance_custom_llm`, `input` must be a JSON string `{"query": "...", "answer": "..."}`. Pass `custom_prompt_id` to use a stored tenant prompt, or omit to use the built-in default prompt. Example: @@ -342,86 +342,107 @@ curl -X DELETE "http://localhost:8001/api/v1/guardrails/ban_lists/" -H "X-API-KEY: " ``` -## 6) Topic Relevance Config APIs (multi-tenant) +## 6) LLM Prompt Config APIs (multi-tenant) -These endpoints manage tenant-scoped topic relevance presets and use `X-API-KEY` auth. +These endpoints manage tenant-scoped LLM prompt configs for the `topic_relevance` and `answer_relevance_custom_llm` validators. They use `X-API-KEY` auth. Base path: -- `/api/v1/guardrails/topic_relevance_configs` +- `/api/v1/guardrails/llm_prompt_configs` -## 6.1 Create topic relevance config +The `validator_name` field determines which validator the config applies to: +- `"topic_relevance"` — a scope description used as the LLM topic guard prompt. No placeholder requirements. +- `"answer_relevance_custom_llm"` — a custom evaluation prompt. Must contain `{query}` and `{answer}` placeholders. + +## 6.1 Create LLM prompt config Endpoint: -- `POST /api/v1/guardrails/topic_relevance_configs/` +- `POST /api/v1/guardrails/llm_prompt_configs/` -Example: +Example (topic relevance): ```bash -curl -X POST "http://localhost:8001/api/v1/guardrails/topic_relevance_configs/" \ +curl -X POST "http://localhost:8001/api/v1/guardrails/llm_prompt_configs/" \ -H "X-API-KEY: " \ -H "Content-Type: application/json" \ -d '{ + "validator_name": "topic_relevance", "name": "Maternal Health Scope", "description": "Topic guard for maternal health support bot", "prompt_schema_version": 1, - "configuration": "Pregnancy care: Questions about prenatal care, ANC visits, nutrition, supplements, danger signs. Postpartum care: Questions about recovery after delivery, breastfeeding, and mother health checks." + "llm_prompt": "Pregnancy care: Questions about prenatal care, ANC visits, nutrition, supplements, danger signs. Postpartum care: Questions about recovery after delivery, breastfeeding, and mother health checks." + }' +``` + +Example (answer relevance): + +```bash +curl -X POST "http://localhost:8001/api/v1/guardrails/llm_prompt_configs/" \ + -H "X-API-KEY: " \ + -H "Content-Type: application/json" \ + -d '{ + "validator_name": "answer_relevance_custom_llm", + "name": "Maternal Health Relevance", + "description": "Checks if LLM answer addresses a maternal health query", + "llm_prompt": "You are evaluating a maternal health assistant.\nQuery: {query}\nAnswer: {answer}\n\nDoes the answer directly address the maternal health query with accurate information?\nAnswer only YES or NO." }' ``` -## 6.2 List topic relevance configs +## 6.2 List LLM prompt configs Endpoint: -- `GET /api/v1/guardrails/topic_relevance_configs/?offset=0&limit=20` +- `GET /api/v1/guardrails/llm_prompt_configs/?offset=0&limit=20` + +Optional filter: +- `validator_name=topic_relevance|answer_relevance_custom_llm` Example: ```bash -curl -X GET "http://localhost:8001/api/v1/guardrails/topic_relevance_configs/?offset=0&limit=20" \ +curl -X GET "http://localhost:8001/api/v1/guardrails/llm_prompt_configs/?validator_name=topic_relevance&offset=0&limit=20" \ -H "X-API-KEY: " ``` -## 6.3 Get topic relevance config by id +## 6.3 Get LLM prompt config by id Endpoint: -- `GET /api/v1/guardrails/topic_relevance_configs/{id}` +- `GET /api/v1/guardrails/llm_prompt_configs/{id}` Example: ```bash -curl -X GET "http://localhost:8001/api/v1/guardrails/topic_relevance_configs/" \ +curl -X GET "http://localhost:8001/api/v1/guardrails/llm_prompt_configs/" \ -H "X-API-KEY: " ``` -## 6.4 Update topic relevance config +## 6.4 Update LLM prompt config Endpoint: -- `PATCH /api/v1/guardrails/topic_relevance_configs/{id}` +- `PATCH /api/v1/guardrails/llm_prompt_configs/{id}` Example: ```bash -curl -X PATCH "http://localhost:8001/api/v1/guardrails/topic_relevance_configs/" \ +curl -X PATCH "http://localhost:8001/api/v1/guardrails/llm_prompt_configs/" \ -H "X-API-KEY: " \ -H "Content-Type: application/json" \ -d '{ - "prompt_schema_version": 1, - "configuration": "Pregnancy care: Updated scope definition" + "llm_prompt": "Pregnancy care: Updated scope definition" }' ``` -## 6.5 Delete topic relevance config +## 6.5 Delete LLM prompt config Endpoint: -- `DELETE /api/v1/guardrails/topic_relevance_configs/{id}` +- `DELETE /api/v1/guardrails/llm_prompt_configs/{id}` Example: ```bash -curl -X DELETE "http://localhost:8001/api/v1/guardrails/topic_relevance_configs/" \ +curl -X DELETE "http://localhost:8001/api/v1/guardrails/llm_prompt_configs/" \ -H "X-API-KEY: " ``` -## 7) End-to-End Usage Pattern +## 8) End-to-End Usage Pattern Recommended request flow: 1. Create/update validator configs via `/guardrails/validators/configs`. @@ -430,16 +451,17 @@ Recommended request flow: 4. Use `safe_text` as downstream text. 5. If `rephrase_needed=true`, ask user to rephrase. 6. For `ban_list` validators without inline `banned_words`, create/manage a ban list first and pass `ban_list_id`. -7. For `topic_relevance`, create/manage a topic relevance config and pass `topic_relevance_config_id` at runtime. The server resolves the configuration string internally. +7. For `topic_relevance`, create/manage an LLM prompt config (`validator_name: "topic_relevance"`) and pass `topic_relevance_config_id` at runtime. The server resolves `llm_prompt` and `prompt_schema_version` internally. +8. For `answer_relevance_custom_llm`, format `input` as `{"query": "...", "answer": "..."}`. Optionally create an LLM prompt config (`validator_name: "answer_relevance_custom_llm"`) and pass `custom_prompt_id`. If no `custom_prompt_id` is given, the built-in default prompt is used. -## 8) Common Errors +## 9) Common Errors - `401 Missing Authorization header` - Add `Authorization: Bearer `. - `401 Invalid authorization token` - Verify plaintext token matches server-side hash. - `401 Missing X-API-KEY header` - - Add `X-API-KEY: ` for ban list and topic relevance config endpoints. + - Add `X-API-KEY: ` for ban list and LLM prompt config endpoints. - `401 Invalid API key` - Verify the API key is valid in the upstream Kaapi auth service. - `Invalid request_id` @@ -448,10 +470,10 @@ Recommended request flow: - Type+stage is unique per organization/project scope. - `Validator not found` - Confirm `id`, `organization_id`, and `project_id` match. -- `Topic relevance preset not found` - - Confirm topic relevance config `id` exists within your tenant scope. +- `LLM prompt config not found` + - Confirm the LLM prompt config `id` exists within your tenant scope. -## 9) Current Validator Types +## 10) Current Validator Types From `validators.json`: - `uli_slur_match` @@ -463,6 +485,7 @@ From `validators.json`: - `llamaguard_7b` - `profanity_free` - `nsfw_text` +- `answer_relevance_custom_llm` Source of truth: - `backend/app/core/validators/validators.json` diff --git a/backend/app/api/docs/guardrails/run_guardrails.md b/backend/app/api/docs/guardrails/run_guardrails.md index d0f1c7f..7a02220 100644 --- a/backend/app/api/docs/guardrails/run_guardrails.md +++ b/backend/app/api/docs/guardrails/run_guardrails.md @@ -6,8 +6,9 @@ Behavior notes: - The endpoint always saves a `request_log` entry for the run. - Validator logs are also saved; with `suppress_pass_logs=true`, only fail-case validator logs are persisted. Otherwise, all validator logs are added. - For `ban_list`, `ban_list_id` can be resolved to `banned_words` from tenant ban list configs. -- For `topic_relevance`, `topic_relevance_config_id` is required and is resolved to `configuration` + `prompt_schema_version` from tenant topic relevance configs. Requires `OPENAI_API_KEY` to be configured; returns a validation failure with an explicit error if missing. +- For `topic_relevance`, `topic_relevance_config_id` is required and is resolved to `llm_prompt` + `prompt_schema_version` from tenant LLM prompt configs. Requires `OPENAI_API_KEY` to be configured; returns a validation failure with an explicit error if missing. - For `llm_critic`, `OPENAI_API_KEY` must be configured; returns `success=false` with an explicit error if missing. +- For `answer_relevance_custom_llm`, `input` must be a JSON string `{"query": "...", "answer": "..."}`. Pass `custom_prompt_id` to use a tenant-stored prompt template, or `prompt_template` inline. Requires `OPENAI_API_KEY`. - For `llamaguard_7b`, `policies` accepts human-readable policy names (see table below). If omitted, all policies are enforced by default. | `policies` value | Policy enforced | diff --git a/backend/app/api/docs/llm_prompt_configs/create_config.md b/backend/app/api/docs/llm_prompt_configs/create_config.md new file mode 100644 index 0000000..b74b1fb --- /dev/null +++ b/backend/app/api/docs/llm_prompt_configs/create_config.md @@ -0,0 +1,51 @@ +Creates an LLM prompt config for the tenant resolved from `X-API-KEY`. + +Behavior notes: +- Stores a named prompt used by an LLM-backed validator (`topic_relevance` or `answer_relevance_custom_llm`). +- `validator_name` determines which validator this config applies to. +- Tenant scope is enforced from the API key context. +- Duplicate configurations (same `validator_name`, `prompt_schema_version`, and `llm_prompt`) are rejected. +- For `answer_relevance_custom_llm`, `llm_prompt` must contain both `{query}` and `{answer}` placeholders. + +Common failure cases: +- Missing or invalid API key. +- Payload schema validation errors. +- `llm_prompt` is missing `{query}` or `{answer}` placeholder (for `answer_relevance_custom_llm`). +- A config with the same configuration already exists. + +## Field glossary + +**`validator_name`** +Which LLM-backed validator this prompt config applies to. + +Accepted values: +- `topic_relevance` — scope guard; `llm_prompt` is a plain-text description of allowed topics injected at `{{TOPIC_CONFIGURATION}}`. +- `answer_relevance_custom_llm` — relevance judge; `llm_prompt` must contain `{query}` and `{answer}` placeholders. + +**`llm_prompt`** +The prompt text supplied to the LLM at evaluation time. + +For `topic_relevance`, this is a plain-text scope definition: +``` +This assistant only answers questions about maternal health and pregnancy care. +It should not respond to questions about politics or general medicine unrelated to pregnancy. +``` + +For `answer_relevance_custom_llm`, this must include `{query}` and `{answer}` placeholders: +``` +You are evaluating a maternal health assistant. +Query: {query} +Answer: {answer} + +Does the answer directly address the maternal health query? +Answer only YES or NO. +``` + +**`prompt_schema_version`** +Integer selecting the versioned prompt template. Defaults to `1`. Only relevant for `topic_relevance`; increment only when a new system prompt version has been added. + +**`name`** +Human-readable label for this config (max 100 characters). + +**`description`** +What this config evaluates or guards (max 500 characters). diff --git a/backend/app/api/docs/llm_prompt_configs/delete_config.md b/backend/app/api/docs/llm_prompt_configs/delete_config.md new file mode 100644 index 0000000..227fb94 --- /dev/null +++ b/backend/app/api/docs/llm_prompt_configs/delete_config.md @@ -0,0 +1,10 @@ +Deletes an LLM prompt config by id for the tenant resolved from `X-API-KEY`. + +Behavior notes: +- Tenant scope is enforced from the API key context. +- Deletion is permanent; any guardrail configs referencing this id will fail to resolve at runtime after deletion. + +Common failure cases: +- Missing or invalid API key. +- LLM prompt config not found in tenant's scope. +- Invalid id format. diff --git a/backend/app/api/docs/llm_prompt_configs/get_config.md b/backend/app/api/docs/llm_prompt_configs/get_config.md new file mode 100644 index 0000000..44ad12d --- /dev/null +++ b/backend/app/api/docs/llm_prompt_configs/get_config.md @@ -0,0 +1,9 @@ +Fetches a single LLM prompt config by id for the tenant resolved from `X-API-KEY`. + +Behavior notes: +- Tenant scope is enforced from the API key context. + +Common failure cases: +- Missing or invalid API key. +- LLM prompt config not found in tenant's scope. +- Invalid id format. diff --git a/backend/app/api/docs/llm_prompt_configs/list_configs.md b/backend/app/api/docs/llm_prompt_configs/list_configs.md new file mode 100644 index 0000000..873fc6d --- /dev/null +++ b/backend/app/api/docs/llm_prompt_configs/list_configs.md @@ -0,0 +1,13 @@ +Lists LLM prompt configs for the tenant resolved from `X-API-KEY`. + +Behavior notes: +- Returns configs scoped to the tenant's `organization_id` and `project_id`. +- Optionally filter by `validator_name` to retrieve configs for a specific validator. +- Supports pagination via `offset` and `limit`. +- `offset` defaults to `0`. +- `limit` is optional; when omitted, no limit is applied. +- Results are ordered by `created_at` ascending, then `id`. + +Common failure cases: +- Missing or invalid API key. +- Invalid pagination values. diff --git a/backend/app/api/docs/llm_prompt_configs/update_config.md b/backend/app/api/docs/llm_prompt_configs/update_config.md new file mode 100644 index 0000000..f13d11c --- /dev/null +++ b/backend/app/api/docs/llm_prompt_configs/update_config.md @@ -0,0 +1,13 @@ +Partially updates an LLM prompt config by id for the tenant resolved from `X-API-KEY`. + +Behavior notes: +- Supports patch-style updates; omitted fields remain unchanged. +- `validator_name` cannot be changed after creation. +- Tenant scope is enforced from the API key context. +- Duplicate configurations are rejected. + +Common failure cases: +- Missing or invalid API key. +- LLM prompt config not found in tenant's scope. +- Payload schema validation errors. +- A config with the same configuration already exists. diff --git a/backend/app/api/docs/topic_relevance_configs/create_config.md b/backend/app/api/docs/topic_relevance_configs/create_config.md deleted file mode 100644 index 07ac176..0000000 --- a/backend/app/api/docs/topic_relevance_configs/create_config.md +++ /dev/null @@ -1,27 +0,0 @@ -Creates a topic relevance configuration for the tenant resolved from `X-API-KEY`. - -Behavior notes: -- Stores a topic relevance preset with `name`, `prompt_schema_version`, and `configuration`. -- `configuration` is a plain text scope sub-prompt (string). -- Tenant scope is enforced from the API key context. -- Duplicate configurations are rejected. - -Common failure cases: -- Missing or invalid API key. -- Payload schema validation errors. -- Topic relevance with the same configuration already exists. - -## Field glossary - -**`configuration`** -A plain text string describing the topic scope the assistant is allowed to handle. This is injected into the LLM critic evaluation prompt at the `{{TOPIC_CONFIGURATION}}` placeholder to define what is considered in-scope. - -Example: -``` -This assistant only answers questions about maternal health and pregnancy care for NGO beneficiaries. It should not respond to questions about politics, general medicine unrelated to pregnancy, or financial topics. -``` - -**`prompt_schema_version`** -An integer selecting the versioned prompt template used to evaluate scope violations (e.g., `1` → `v1.md`). Controls the structure and wording of the LLM critic assessment prompt. Defaults to `1`. Only increment this when a new prompt template version has been added to the system. - -Example: `1` diff --git a/backend/app/api/docs/topic_relevance_configs/delete_config.md b/backend/app/api/docs/topic_relevance_configs/delete_config.md deleted file mode 100644 index ff45017..0000000 --- a/backend/app/api/docs/topic_relevance_configs/delete_config.md +++ /dev/null @@ -1,8 +0,0 @@ -Deletes a topic relevance configuration by id for the tenant resolved from `X-API-KEY`. - -Behavior notes: -- Tenant scope is enforced from the API key context. - -Common failure cases: -- Missing or invalid API key. -- Topic relevance preset not found in tenant's scope. diff --git a/backend/app/api/docs/topic_relevance_configs/get_config.md b/backend/app/api/docs/topic_relevance_configs/get_config.md deleted file mode 100644 index 89a3c2e..0000000 --- a/backend/app/api/docs/topic_relevance_configs/get_config.md +++ /dev/null @@ -1,9 +0,0 @@ -Fetches a single topic relevance configuration by id for the tenant resolved from `X-API-KEY`. - -Behavior notes: -- Tenant scope is enforced from the API key context. - -Common failure cases: -- Missing or invalid API key. -- Topic relevance preset not found in tenant's scope. -- Invalid id format. diff --git a/backend/app/api/docs/topic_relevance_configs/list_configs.md b/backend/app/api/docs/topic_relevance_configs/list_configs.md deleted file mode 100644 index d463c03..0000000 --- a/backend/app/api/docs/topic_relevance_configs/list_configs.md +++ /dev/null @@ -1,11 +0,0 @@ -Lists topic relevance configurations for the tenant resolved from `X-API-KEY`. - -Behavior notes: -- Supports pagination via `offset` and `limit`. -- `offset` defaults to `0`. -- `limit` is optional; when omitted, no limit is applied. -- Tenant scope is enforced from the API key context. - -Common failure cases: -- Missing or invalid API key. -- Invalid pagination values. diff --git a/backend/app/api/docs/topic_relevance_configs/update_config.md b/backend/app/api/docs/topic_relevance_configs/update_config.md deleted file mode 100644 index f9627b9..0000000 --- a/backend/app/api/docs/topic_relevance_configs/update_config.md +++ /dev/null @@ -1,13 +0,0 @@ -Partially updates a topic relevance configuration by id for the tenant resolved from `X-API-KEY`. - -Behavior notes: -- Supports patch-style updates; omitted fields remain unchanged. -- `configuration` should be provided as a plain text scope sub-prompt (string). -- Tenant scope is enforced from the API key context. -- Duplicate configurations are rejected. - -Common failure cases: -- Missing or invalid API key. -- Topic relevance preset not found in tenant's scope. -- Payload schema validation errors. -- Topic relevance with the same configuration already exists. diff --git a/backend/app/api/main.py b/backend/app/api/main.py index f3c4543..0d97b95 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -3,7 +3,7 @@ from app.api.routes import ( ban_lists, guardrails, - topic_relevance_configs, + llm_prompt_configs, validator_configs, utils, ) @@ -11,7 +11,7 @@ api_router = APIRouter() api_router.include_router(ban_lists.router) api_router.include_router(guardrails.router) -api_router.include_router(topic_relevance_configs.router) +api_router.include_router(llm_prompt_configs.router) api_router.include_router(validator_configs.router) api_router.include_router(utils.router) diff --git a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py index 7281718..0f752c8 100644 --- a/backend/app/api/routes/guardrails.py +++ b/backend/app/api/routes/guardrails.py @@ -1,7 +1,8 @@ +import json from uuid import UUID import uuid -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from guardrails.guard import Guard from guardrails.validators import FailResult, PassResult from sqlmodel import Session @@ -13,16 +14,19 @@ LLM_CRITIC_REPHRASE_MESSAGE, REPHRASE_ON_FAIL_PREFIX, ) -from app.core.enum import ValidatorType +from app.core.enum import LLMValidatorName, ValidatorType from app.core.guardrail_controller import build_guard, get_validator_config_models from app.core.exception_handlers import _safe_error_message from app.core.validators.config.ban_list_safety_validator_config import ( BanListSafetyValidatorConfig, ) from app.crud.ban_list import ban_list_crud -from app.crud.topic_relevance import topic_relevance_crud +from app.crud.llm_prompt_config import llm_prompt_config_crud from app.crud.request_log import RequestLogCrud from app.crud.validator_log import ValidatorLogCrud +from app.core.validators.config.answer_relevance_custom_llm_safety_validator_config import ( + AnswerRelevanceCustomLLMSafetyValidatorConfig, +) from app.core.validators.config.topic_relevance_safety_validator_config import ( TopicRelevanceSafetyValidatorConfig, ) @@ -58,9 +62,10 @@ def run_guardrails( except ValueError: return APIResponse.failure_response(error="Invalid request_id") - _resolve_validator_configs(payload, session) + data = _resolve_validator_configs(payload, session) return _validate_with_guard( payload, + data, request_log_crud, request_log.id, validator_log_crud, @@ -98,12 +103,19 @@ def list_validators(_: AuthDep): return {"validators": validators} -def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> None: +def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> str: """ Resolves config-backed references for all validators in-place before guard execution: - BanList: fetches banned_words from the stored BanList when not provided inline. - TopicRelevance: fetches configuration and prompt_schema_version from stored config. + - AnswerRelevance: fetches custom prompt template from stored config; returns + JSON-encoded {"query": input, "answer": output} as the guard data. + + Returns the data string to pass to guard.validate(). """ + # Input guardrails validate payload.input; output guardrails validate payload.output. + # AnswerRelevance is the exception: it needs both, encoded as JSON. + data = payload.output if payload.output is not None else payload.input for validator in payload.validators: if isinstance(validator, BanListSafetyValidatorConfig): if validator.type == BAN_LIST and validator.banned_words is None: @@ -117,18 +129,47 @@ def _resolve_validator_configs(payload: GuardrailRequest, session: Session) -> N elif isinstance(validator, TopicRelevanceSafetyValidatorConfig): if validator.topic_relevance_config_id is not None: - config = topic_relevance_crud.get( + config = llm_prompt_config_crud.get( session=session, id=validator.topic_relevance_config_id, organization_id=payload.organization_id, project_id=payload.project_id, ) - validator.configuration = config.configuration + if config.validator_name != LLMValidatorName.TopicRelevance: + raise HTTPException( + 400, + f"LLM prompt config '{config.id}' is for validator " + f"'{config.validator_name}', not 'topic_relevance'", + ) + validator.configuration = config.llm_prompt validator.prompt_schema_version = config.prompt_schema_version + elif isinstance(validator, AnswerRelevanceCustomLLMSafetyValidatorConfig): + data = json.dumps({"query": payload.input, "answer": payload.output or ""}) + if validator.custom_prompt_id is not None: + prompt_config = llm_prompt_config_crud.get( + session=session, + id=validator.custom_prompt_id, + organization_id=payload.organization_id, + project_id=payload.project_id, + ) + if ( + prompt_config.validator_name + != LLMValidatorName.AnswerRelevanceCustomLLM + ): + raise HTTPException( + 400, + f"LLM prompt config '{prompt_config.id}' is for validator " + f"'{prompt_config.validator_name}', not 'answer_relevance_custom_llm'", + ) + validator.prompt_template = prompt_config.llm_prompt + + return data + def _validate_with_guard( payload: GuardrailRequest, + data: str, request_log_crud: RequestLogCrud, request_log_id: UUID, validator_log_crud: ValidatorLogCrud, @@ -142,7 +183,6 @@ def _validate_with_guard( while still safely handling unexpected runtime errors. """ response_id = uuid.uuid4() - data = payload.input validators = payload.validators guard: Guard | None = None diff --git a/backend/app/api/routes/llm_prompt_configs.py b/backend/app/api/routes/llm_prompt_configs.py new file mode 100644 index 0000000..035f2ba --- /dev/null +++ b/backend/app/api/routes/llm_prompt_configs.py @@ -0,0 +1,121 @@ +from typing import Annotated, Optional +from uuid import UUID + +from fastapi import APIRouter, Query + +from app.api.deps import MultitenantAuthDep, SessionDep +from app.core.enum import LLMValidatorName +from app.crud.llm_prompt_config import llm_prompt_config_crud +from app.schemas.llm_prompt_config import ( + LLMPromptConfigCreate, + LLMPromptConfigResponse, + LLMPromptConfigUpdate, +) +from app.utils import APIResponse, load_description + +router = APIRouter( + prefix="/guardrails/llm_prompt_configs", + tags=["LLM Prompt Configs"], +) + + +@router.post( + "/", + description=load_description("llm_prompt_configs/create_config.md"), + response_model=APIResponse[LLMPromptConfigResponse], +) +def create_llm_prompt_config( + payload: LLMPromptConfigCreate, + session: SessionDep, + auth: MultitenantAuthDep, +) -> APIResponse[LLMPromptConfigResponse]: + obj = llm_prompt_config_crud.create( + session, + payload, + auth.organization_id, + auth.project_id, + ) + return APIResponse.success_response(data=obj) + + +@router.get( + "/", + description=load_description("llm_prompt_configs/list_configs.md"), + response_model=APIResponse[list[LLMPromptConfigResponse]], +) +def list_llm_prompt_configs( + session: SessionDep, + auth: MultitenantAuthDep, + validator_name: Annotated[Optional[LLMValidatorName], Query()] = None, + offset: Annotated[int, Query(ge=0)] = 0, + limit: Annotated[int | None, Query(ge=1, le=100)] = None, +) -> APIResponse[list[LLMPromptConfigResponse]]: + objs = llm_prompt_config_crud.list( + session, + auth.organization_id, + auth.project_id, + validator_name=validator_name, + offset=offset, + limit=limit, + ) + return APIResponse.success_response(data=objs) + + +@router.get( + "/{id}", + description=load_description("llm_prompt_configs/get_config.md"), + response_model=APIResponse[LLMPromptConfigResponse], +) +def get_llm_prompt_config( + id: UUID, + session: SessionDep, + auth: MultitenantAuthDep, +) -> APIResponse[LLMPromptConfigResponse]: + obj = llm_prompt_config_crud.get( + session, + id, + auth.organization_id, + auth.project_id, + ) + return APIResponse.success_response(data=obj) + + +@router.patch( + "/{id}", + description=load_description("llm_prompt_configs/update_config.md"), + response_model=APIResponse[LLMPromptConfigResponse], +) +def update_llm_prompt_config( + id: UUID, + payload: LLMPromptConfigUpdate, + session: SessionDep, + auth: MultitenantAuthDep, +) -> APIResponse[LLMPromptConfigResponse]: + obj = llm_prompt_config_crud.update( + session, + id, + auth.organization_id, + auth.project_id, + payload, + ) + return APIResponse.success_response(data=obj) + + +@router.delete( + "/{id}", + description=load_description("llm_prompt_configs/delete_config.md"), + response_model=APIResponse[dict], +) +def delete_llm_prompt_config( + id: UUID, + session: SessionDep, + auth: MultitenantAuthDep, +) -> APIResponse[dict]: + obj = llm_prompt_config_crud.get( + session, + id, + auth.organization_id, + auth.project_id, + ) + llm_prompt_config_crud.delete(session, obj) + return APIResponse.success_response(data={"message": "Config deleted successfully"}) diff --git a/backend/app/api/routes/topic_relevance_configs.py b/backend/app/api/routes/topic_relevance_configs.py deleted file mode 100644 index b855a58..0000000 --- a/backend/app/api/routes/topic_relevance_configs.py +++ /dev/null @@ -1,118 +0,0 @@ -from typing import Annotated -from uuid import UUID - -from fastapi import APIRouter, Query - -from app.api.deps import MultitenantAuthDep, SessionDep -from app.crud.topic_relevance import topic_relevance_crud -from app.schemas.topic_relevance import ( - TopicRelevanceCreate, - TopicRelevanceUpdate, - TopicRelevanceResponse, -) -from app.utils import APIResponse, load_description - -router = APIRouter( - prefix="/guardrails/topic_relevance_configs", - tags=["Topic Relevance Configs"], -) - - -@router.post( - "/", - description=load_description("topic_relevance_configs/create_config.md"), - response_model=APIResponse[TopicRelevanceResponse], -) -def create_topic_relevance_config( - payload: TopicRelevanceCreate, - session: SessionDep, - auth: MultitenantAuthDep, -) -> APIResponse[TopicRelevanceResponse]: - topic_relevance_config = topic_relevance_crud.create( - session, - payload, - auth.organization_id, - auth.project_id, - ) - return APIResponse.success_response(data=topic_relevance_config) - - -@router.get( - "/", - description=load_description("topic_relevance_configs/list_configs.md"), - response_model=APIResponse[list[TopicRelevanceResponse]], -) -def list_topic_relevance_configs( - session: SessionDep, - auth: MultitenantAuthDep, - offset: Annotated[int, Query(ge=0)] = 0, - limit: Annotated[int | None, Query(ge=1, le=100)] = None, -) -> APIResponse[list[TopicRelevanceResponse]]: - topic_relevance_configs = topic_relevance_crud.list( - session, - auth.organization_id, - auth.project_id, - offset, - limit, - ) - return APIResponse.success_response(data=topic_relevance_configs) - - -@router.get( - "/{id}", - description=load_description("topic_relevance_configs/get_config.md"), - response_model=APIResponse[TopicRelevanceResponse], -) -def get_topic_relevance_config( - id: UUID, - session: SessionDep, - auth: MultitenantAuthDep, -) -> APIResponse[TopicRelevanceResponse]: - topic_relevance_config = topic_relevance_crud.get( - session, - id, - auth.organization_id, - auth.project_id, - ) - return APIResponse.success_response(data=topic_relevance_config) - - -@router.patch( - "/{id}", - description=load_description("topic_relevance_configs/update_config.md"), - response_model=APIResponse[TopicRelevanceResponse], -) -def update_topic_relevance_config( - id: UUID, - payload: TopicRelevanceUpdate, - session: SessionDep, - auth: MultitenantAuthDep, -) -> APIResponse[TopicRelevanceResponse]: - topic_relevance_config = topic_relevance_crud.update( - session, - id, - auth.organization_id, - auth.project_id, - payload, - ) - return APIResponse.success_response(data=topic_relevance_config) - - -@router.delete( - "/{id}", - description=load_description("topic_relevance_configs/delete_config.md"), - response_model=APIResponse[dict], -) -def delete_topic_relevance_config( - id: UUID, - session: SessionDep, - auth: MultitenantAuthDep, -) -> APIResponse[dict]: - obj = topic_relevance_crud.get( - session, - id, - auth.organization_id, - auth.project_id, - ) - topic_relevance_crud.delete(session, obj) - return APIResponse.success_response(data={"message": "Config deleted successfully"}) diff --git a/backend/app/core/enum.py b/backend/app/core/enum.py index ff653c5..936c629 100644 --- a/backend/app/core/enum.py +++ b/backend/app/core/enum.py @@ -1,6 +1,11 @@ from enum import Enum +class LLMValidatorName(str, Enum): + TopicRelevance = "topic_relevance" + AnswerRelevanceCustomLLM = "answer_relevance_custom_llm" + + class SlurSeverity(Enum): Low = "low" Medium = "medium" @@ -36,3 +41,4 @@ class ValidatorType(Enum): LlamaGuard7B = "llamaguard_7b" ProfanityFree = "profanity_free" NSFWText = "nsfw_text" + AnswerRelevanceCustomLLM = "answer_relevance_custom_llm" diff --git a/backend/app/core/validators/README.md b/backend/app/core/validators/README.md index f843d8e..e31d14b 100644 --- a/backend/app/core/validators/README.md +++ b/backend/app/core/validators/README.md @@ -15,6 +15,7 @@ Current validator manifest: - `llamaguard_7b` (source: `hub://guardrails/llamaguard_7b`) - `profanity_free` (source: `hub://guardrails/profanity_free`) - `nsfw_text` (source: `hub://guardrails/nsfw_text`) +- `answer_relevance_custom_llm` (source: `local`) ## Configuration Model @@ -302,7 +303,7 @@ What it does: Why this is used: - Enables flexible, prompt-driven content evaluation for use cases not covered by rule-based validators. -- All configuration is passed inline in the runtime request — there is no stored config object to resolve. Unlike `topic_relevance`, which looks up scope text from a persisted `TopicRelevanceConfig`, `llm_critic` receives `metrics`, `max_score`, and `llm_callable` directly in the guardrail request payload. +- All configuration is passed inline in the runtime request — there is no stored config object to resolve. Unlike `topic_relevance`, which looks up scope text from a persisted LLM prompt config, `llm_critic` receives `metrics`, `max_score`, and `llm_callable` directly in the guardrail request payload. Recommendation: @@ -359,7 +360,7 @@ Notes / limitations: - Runtime validation requires `topic_relevance_config_id`. - **Requires `OPENAI_API_KEY` to be set in environment variables.** If the key is not configured, validation returns a `FailResult` with an explicit message. -- Configuration is resolved in `backend/app/api/routes/guardrails.py` from tenant Topic Relevance Config APIs. +- Configuration is resolved in `backend/app/api/routes/guardrails.py` from tenant LLM Prompt Config APIs (`/guardrails/llm_prompt_configs`). - Prompt templates must include the `{{TOPIC_CONFIGURATION}}` placeholder. ### 7) LlamaGuard 7B Validator (`llamaguard_7b`) @@ -483,6 +484,54 @@ Notes / limitations: - No programmatic fix is applied — with `on_fail=fix`, `safe_text` will be `""` and the response `metadata.reason` will identify this validator as the cause. - English-focused; cross-lingual profanity may not be detected. +### 10) Answer Relevance Custom LLM Validator (`answer_relevance_custom_llm`) + +Code: + +- Config: `backend/app/core/validators/config/answer_relevance_custom_llm_safety_validator_config.py` +- Runtime validator: `backend/app/core/validators/answer_relevance_custom_llm.py` + +What it does: + +- Evaluates whether an LLM's answer is relevant to the user's query by asking a configurable LLM to respond YES or NO. +- Accepts `input` as a JSON string `{"query": "...", "answer": "..."}`. +- Uses a customizable prompt template with `{query}` and `{answer}` placeholders; falls back to a built-in default prompt if none is provided. +- Supports per-tenant custom prompts stored via the LLM Prompt Config APIs and referenced by `custom_prompt_id`. + +Why this is used: + +- Detects hallucinated or off-topic LLM responses before they are shown to users. +- Each NGO can tune the relevance criteria via a custom prompt without code changes (e.g. stricter domain constraints, language-specific phrasing). + +Recommendation: + +- primarily `output` + - Why `output`: answer relevance is a property of the LLM's generated response relative to the user's query. + +Parameters / customization: + +- `llm_callable: str` (default: `gpt-4o-mini`) — model identifier passed to LiteLLM for the YES/NO evaluation +- `prompt_template: str` (optional) — inline prompt with `{query}` and `{answer}` placeholders +- `custom_prompt_id: UUID` (optional) — reference to a tenant-stored prompt config; resolved to `prompt_template` before execution +- `on_fail` + +Default prompt: +``` +Query: {query} +Answer: {answer} + +Does the answer fully satisfy the query and constraints? +Answer only YES or NO. +``` + +Notes / limitations: + +- **Requires `OPENAI_API_KEY` to be set in environment variables.** +- `input` to the guardrail endpoint must be a JSON string: `{"query": "...", "answer": "..."}`. Both fields must be non-empty. +- LLM-judge responses can vary; YES/NO parsing uses prefix matching. +- `on_fail=fix` has no programmatic fix for irrelevant answers — `safe_text` will be `""` and `metadata.reason` will identify this validator. +- If `custom_prompt_id` is deleted after being referenced, the guardrail will return a 404 at resolution time. + ## Example Config Payloads Example: create validator config (stored shape) @@ -514,7 +563,7 @@ Example: runtime guardrail validator object (execution shape) Default stage strategy: - Input guardrails: `pii_remover`, `uli_slur_match`, `ban_list`, `topic_relevance` (when scope enforcement is needed), `profanity_free`, `llamaguard_7b` -- Output guardrails: `pii_remover`, `uli_slur_match`, `gender_assumption_bias`, `ban_list`, `profanity_free`, `llamaguard_7b` +- Output guardrails: `pii_remover`, `uli_slur_match`, `gender_assumption_bias`, `ban_list`, `profanity_free`, `llamaguard_7b`, `answer_relevance_custom_llm` (when answer quality must be verified) Tuning strategy: @@ -534,5 +583,9 @@ Tuning strategy: - `backend/app/core/validators/config/llamaguard_7b_safety_validator_config.py` - `backend/app/core/validators/config/nsfw_text_safety_validator_config.py` - `backend/app/core/validators/config/profanity_free_safety_validator_config.py` +- `backend/app/core/validators/config/answer_relevance_custom_llm_safety_validator_config.py` +- `backend/app/core/validators/answer_relevance_custom_llm.py` +- `backend/app/models/config/llm_prompt_config.py` +- `backend/app/crud/llm_prompt_config.py` - `backend/app/schemas/guardrail_config.py` - `backend/app/schemas/validator_config.py` diff --git a/backend/app/core/validators/answer_relevance_custom_llm.py b/backend/app/core/validators/answer_relevance_custom_llm.py new file mode 100644 index 0000000..a03029b --- /dev/null +++ b/backend/app/core/validators/answer_relevance_custom_llm.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import json +from typing import Callable, Optional + +from litellm import completion +from guardrails import OnFailAction +from guardrails.validators import ( + FailResult, + PassResult, + ValidationResult, + Validator, + register_validator, +) + +DEFAULT_PROMPT_TEMPLATE = ( + "Query: {query}\n" + "Answer: {answer}\n\n" + "Does the answer fully satisfy the query and constraints?\n" + "Answer only YES or NO." +) + + +@register_validator(name="answer-relevance-custom-llm", data_type="string") +class AnswerRelevanceCustomLLM(Validator): + """ + Validates whether an LLM answer is relevant to the user query. + + Expects `value` to be a JSON string: {"query": "...", "answer": "..."}. + Uses a configurable prompt template with {query} and {answer} placeholders. + Returns PassResult for YES, FailResult for NO. + """ + + def __init__( + self, + prompt_template: str = DEFAULT_PROMPT_TEMPLATE, + llm_callable: str = "gpt-4o-mini", + on_fail: Optional[Callable] = OnFailAction.NOOP, + ): + super().__init__(on_fail=on_fail) + self.prompt_template = prompt_template + self.llm_callable = llm_callable + + def _validate(self, value: str, metadata: dict | None = None) -> ValidationResult: + try: + data = json.loads(value) + query = data.get("query", "") + answer = data.get("answer", "") + except (json.JSONDecodeError, TypeError): + return FailResult( + error_message="Input must be a JSON string with 'query' and 'answer' fields." + ) + + if not query.strip() or not answer.strip(): + return FailResult( + error_message="Both 'query' and 'answer' fields must be non-empty." + ) + + try: + prompt = self.prompt_template.format(query=query, answer=answer) + except KeyError as e: + return FailResult(error_message=f"Prompt template missing placeholder: {e}") + + try: + response = completion( + model=self.llm_callable, + messages=[{"role": "user", "content": prompt}], + max_tokens=10, + ) + response_text = response.choices[0].message.content.strip().upper() + except Exception as e: + return FailResult(error_message=f"LLM call failed: {e}") + + if response_text.startswith("YES"): + return PassResult(value=value) + + if response_text.startswith("NO"): + return FailResult( + error_message="The answer is not relevant to the query.", + ) + + return FailResult( + error_message=f"Unexpected LLM response for relevance check: {response_text}" + ) diff --git a/backend/app/core/validators/config/answer_relevance_custom_llm_safety_validator_config.py b/backend/app/core/validators/config/answer_relevance_custom_llm_safety_validator_config.py new file mode 100644 index 0000000..ab1d7db --- /dev/null +++ b/backend/app/core/validators/config/answer_relevance_custom_llm_safety_validator_config.py @@ -0,0 +1,30 @@ +from typing import Literal, Optional +from uuid import UUID + +from app.core.config import settings +from app.core.validators.answer_relevance_custom_llm import AnswerRelevanceCustomLLM +from app.core.validators.config.base_validator_config import BaseValidatorConfig + + +class AnswerRelevanceCustomLLMSafetyValidatorConfig(BaseValidatorConfig): + type: Literal["answer_relevance_custom_llm"] + llm_callable: str = "gpt-4o-mini" + # Inline prompt template with {query} and {answer} placeholders. + # If None, the validator uses its built-in default. + prompt_template: Optional[str] = None + # Reference to a stored custom prompt; resolved to prompt_template before build(). + custom_prompt_id: Optional[UUID] = None + + def build(self): + if not settings.OPENAI_API_KEY: + raise ValueError( + "OPENAI_API_KEY is not configured. " + "Answer relevance validation requires an OpenAI API key." + ) + kwargs = dict( + llm_callable=self.llm_callable, + on_fail=self.resolve_on_fail(), + ) + if self.prompt_template: + kwargs["prompt_template"] = self.prompt_template + return AnswerRelevanceCustomLLM(**kwargs) diff --git a/backend/app/core/validators/gender_assumption_bias.py b/backend/app/core/validators/gender_assumption_bias.py index 2165067..a012cfc 100644 --- a/backend/app/core/validators/gender_assumption_bias.py +++ b/backend/app/core/validators/gender_assumption_bias.py @@ -30,7 +30,7 @@ def __init__( self.gender_bias_list = self.load_gender_bias_list(self.categories) super().__init__(on_fail=on_fail) - def _validate(self, value: str, metadata: dict = None) -> ValidationResult: + def _validate(self, value: str, metadata: dict | None = None) -> ValidationResult: detected_biased_words = [] bias_check = False diff --git a/backend/app/core/validators/lexical_slur.py b/backend/app/core/validators/lexical_slur.py index 42d7596..2270d52 100644 --- a/backend/app/core/validators/lexical_slur.py +++ b/backend/app/core/validators/lexical_slur.py @@ -39,7 +39,7 @@ def __init__( self._compile_slur_patterns() super().__init__(on_fail=on_fail, search_words=self.slur_list) - def _validate(self, value: str, metadata: dict = None) -> ValidationResult: + def _validate(self, value: str, metadata: dict | None = None) -> ValidationResult: original_text = value normalized_text = self.normalize_for_matching(value) detected_slurs = [] diff --git a/backend/app/core/validators/pii_remover.py b/backend/app/core/validators/pii_remover.py index 5a93a73..efe8ad0 100644 --- a/backend/app/core/validators/pii_remover.py +++ b/backend/app/core/validators/pii_remover.py @@ -109,7 +109,7 @@ def __init__( self.analyzer = _get_cached_analyzer(self.entity_types) self.anonymizer = AnonymizerEngine() - def _validate(self, value: str, metadata: dict = None) -> ValidationResult: + def _validate(self, value: str, metadata: dict | None = None) -> ValidationResult: text = value results = self.analyzer.analyze( text=text, entities=self.entity_types, language="en" diff --git a/backend/app/core/validators/topic_relevance.py b/backend/app/core/validators/topic_relevance.py index 22d2bcc..721697f 100644 --- a/backend/app/core/validators/topic_relevance.py +++ b/backend/app/core/validators/topic_relevance.py @@ -107,7 +107,7 @@ def __init__( ), ) - def _validate(self, value: str, metadata: dict = None) -> ValidationResult: + def _validate(self, value: str, metadata: dict | None = None) -> ValidationResult: """Run the LLMCritic and return a PassResult or FailResult with the scope score.""" if self._invalid_config_reason: return FailResult(error_message=self._invalid_config_reason) diff --git a/backend/app/crud/llm_prompt_config.py b/backend/app/crud/llm_prompt_config.py new file mode 100644 index 0000000..9405e94 --- /dev/null +++ b/backend/app/crud/llm_prompt_config.py @@ -0,0 +1,138 @@ +from typing import List, Optional +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, select + +from app.core.enum import LLMValidatorName +from app.models.config.llm_prompt_config import LLMPromptConfig +from app.utils import now + + +class LLMPromptConfigCrud: + def create( + self, + session: Session, + payload, + organization_id: int, + project_id: int, + ) -> LLMPromptConfig: + obj = LLMPromptConfig( + **payload.model_dump(), + organization_id=organization_id, + project_id=project_id, + ) + session.add(obj) + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + 400, + "A prompt config with the same configuration already exists", + ) + except Exception: + session.rollback() + raise + + session.refresh(obj) + return obj + + def get( + self, + session: Session, + id: UUID, + organization_id: int, + project_id: int, + ) -> LLMPromptConfig: + query = select(LLMPromptConfig).where( + LLMPromptConfig.id == id, + LLMPromptConfig.organization_id == organization_id, + LLMPromptConfig.project_id == project_id, + ) + obj = session.exec(query).first() + if not obj: + raise HTTPException(404, "LLM prompt config not found") + return obj + + def list( + self, + session: Session, + organization_id: int, + project_id: int, + validator_name: Optional[LLMValidatorName] = None, + offset: int = 0, + limit: Optional[int] = None, + ) -> List[LLMPromptConfig]: + query = select(LLMPromptConfig).where( + LLMPromptConfig.organization_id == organization_id, + LLMPromptConfig.project_id == project_id, + ) + + if validator_name is not None: + query = query.where(LLMPromptConfig.validator_name == validator_name) + + query = query.order_by(LLMPromptConfig.created_at, LLMPromptConfig.id) + + if offset: + query = query.offset(offset) + if limit: + query = query.limit(limit) + + return list(session.exec(query).all()) + + def update( + self, + session: Session, + id: UUID, + organization_id: int, + project_id: int, + payload, + ) -> LLMPromptConfig: + obj = self.get(session, id, organization_id, project_id) + + update_data = payload.model_dump(exclude_unset=True) + + if ( + "llm_prompt" in update_data + and obj.validator_name == LLMValidatorName.AnswerRelevanceCustomLLM + ): + new_prompt = update_data["llm_prompt"] + missing = [p for p in ("{query}", "{answer}") if p not in new_prompt] + if missing: + raise HTTPException( + 422, + f"llm_prompt must contain the placeholders: {', '.join(missing)}", + ) + + for key, value in update_data.items(): + setattr(obj, key, value) + + obj.updated_at = now() + session.add(obj) + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + 400, + "A prompt config with the same configuration already exists", + ) + except Exception: + session.rollback() + raise + + session.refresh(obj) + return obj + + def delete(self, session: Session, obj: LLMPromptConfig) -> None: + session.delete(obj) + try: + session.commit() + except Exception: + session.rollback() + raise + + +llm_prompt_config_crud = LLMPromptConfigCrud() diff --git a/backend/app/crud/request_log.py b/backend/app/crud/request_log.py index 2a283cb..ce21da9 100644 --- a/backend/app/crud/request_log.py +++ b/backend/app/crud/request_log.py @@ -16,6 +16,7 @@ def create(self, payload: GuardrailRequest) -> RequestLog: create_request_log = RequestLog( request_id=request_id, request_text=payload.input, + output_text=payload.output, organization_id=payload.organization_id, project_id=payload.project_id, ) diff --git a/backend/app/crud/topic_relevance.py b/backend/app/crud/topic_relevance.py deleted file mode 100644 index c6455d0..0000000 --- a/backend/app/crud/topic_relevance.py +++ /dev/null @@ -1,120 +0,0 @@ -from typing import List -from uuid import UUID - -from fastapi import HTTPException -from sqlalchemy.exc import IntegrityError -from sqlmodel import Session, select - -from app.models.config.topic_relevance import TopicRelevance -from app.schemas.topic_relevance import ( - TopicRelevanceCreate, - TopicRelevanceUpdate, -) -from app.utils import now - - -class TopicRelevanceCrud: - def create( - self, - session: Session, - payload: TopicRelevanceCreate, - organization_id: int, - project_id: int, - ) -> TopicRelevance: - topic_relevance_obj = TopicRelevance( - **payload.model_dump(), - organization_id=organization_id, - project_id=project_id, - ) - session.add(topic_relevance_obj) - try: - session.commit() - except IntegrityError: - session.rollback() - raise HTTPException( - 400, "Topic relevance with the same configuration already exists" - ) - except Exception: - session.rollback() - raise - - session.refresh(topic_relevance_obj) - return topic_relevance_obj - - def get( - self, session: Session, id: UUID, organization_id: int, project_id: int - ) -> TopicRelevance: - query = select(TopicRelevance).where( - TopicRelevance.id == id, - TopicRelevance.organization_id == organization_id, - TopicRelevance.project_id == project_id, - ) - topic_relevance_obj = session.exec(query).first() - if not topic_relevance_obj: - raise HTTPException(404, "Topic relevance preset not found") - return topic_relevance_obj - - def list( - self, - session: Session, - organization_id: int, - project_id: int, - offset: int = 0, - limit: int | None = None, - ) -> List[TopicRelevance]: - query = ( - select(TopicRelevance) - .where( - TopicRelevance.organization_id == organization_id, - TopicRelevance.project_id == project_id, - ) - .order_by(TopicRelevance.created_at, TopicRelevance.id) - ) - - if offset: - query = query.offset(offset) - if limit: - query = query.limit(limit) - - return list(session.exec(query).all()) - - def update( - self, - session: Session, - id: UUID, - organization_id: int, - project_id: int, - payload: TopicRelevanceUpdate, - ) -> TopicRelevance: - topic_relevance_obj = self.get(session, id, organization_id, project_id) - - update_data = payload.model_dump(exclude_unset=True) - for key, value in update_data.items(): - setattr(topic_relevance_obj, key, value) - - topic_relevance_obj.updated_at = now() - session.add(topic_relevance_obj) - try: - session.commit() - except IntegrityError: - session.rollback() - raise HTTPException( - 400, "Topic relevance with the same configuration already exists" - ) - except Exception: - session.rollback() - raise - - session.refresh(topic_relevance_obj) - return topic_relevance_obj - - def delete(self, session: Session, topic_relevance_obj: TopicRelevance): - session.delete(topic_relevance_obj) - try: - session.commit() - except Exception: - session.rollback() - raise - - -topic_relevance_crud = TopicRelevanceCrud() diff --git a/backend/app/models/config/topic_relevance.py b/backend/app/models/config/llm_prompt_config.py similarity index 57% rename from backend/app/models/config/topic_relevance.py rename to backend/app/models/config/llm_prompt_config.py index a044e91..f513f1d 100644 --- a/backend/app/models/config/topic_relevance.py +++ b/backend/app/models/config/llm_prompt_config.py @@ -4,16 +4,17 @@ from sqlalchemy import UniqueConstraint from sqlmodel import SQLModel, Field +from app.core.enum import LLMValidatorName from app.utils import now -class TopicRelevance(SQLModel, table=True): - __tablename__ = "topic_relevance" +class LLMPromptConfig(SQLModel, table=True): + __tablename__ = "llm_prompt" id: UUID = Field( default_factory=uuid4, primary_key=True, - sa_column_kwargs={"comment": "Unique identifier for the topic relevance entry"}, + sa_column_kwargs={"comment": "Unique identifier for the LLM prompt config"}, ) organization_id: int = Field( @@ -28,51 +29,52 @@ class TopicRelevance(SQLModel, table=True): sa_column_kwargs={"comment": "Identifier for the project"}, ) + validator_name: LLMValidatorName = Field( + nullable=False, + index=True, + sa_column_kwargs={"comment": "Validator type this prompt config belongs to"}, + ) + name: str = Field( nullable=False, - sa_column_kwargs={"comment": "Name of the topic relevance entry"}, + sa_column_kwargs={"comment": "Human-readable name for this prompt config"}, ) description: str = Field( nullable=False, - sa_column_kwargs={"comment": "Description of the topic relevance entry"}, + sa_column_kwargs={"comment": "Description of what this prompt evaluates"}, ) prompt_schema_version: int = Field( + default=1, index=True, nullable=False, - sa_column_kwargs={"comment": "Version of the topic relevance prompt to use"}, + sa_column_kwargs={"comment": "Version of the prompt schema"}, ) - configuration: str = Field( + llm_prompt: str = Field( nullable=False, - sa_column_kwargs={ - "comment": "Prompt text blob containing topic relevance scope definition" - }, + sa_column_kwargs={"comment": "Prompt text used by the LLM validator"}, ) is_active: bool = Field( default=True, index=True, nullable=False, - sa_column_kwargs={ - "comment": "Whether the topic relevance entry is active or not" - }, + sa_column_kwargs={"comment": "Whether this prompt config is active"}, ) created_at: datetime = Field( default_factory=now, nullable=False, - sa_column_kwargs={ - "comment": "Timestamp when the topic configuration entry was created" - }, + sa_column_kwargs={"comment": "Timestamp when the entry was created"}, ) updated_at: datetime = Field( default_factory=now, nullable=False, sa_column_kwargs={ - "comment": "Timestamp when the topic configuration entry was last updated", + "comment": "Timestamp when the entry was last updated", "onupdate": now, }, ) @@ -81,8 +83,9 @@ class TopicRelevance(SQLModel, table=True): UniqueConstraint( "organization_id", "project_id", + "validator_name", "prompt_schema_version", - "configuration", - name="uq_topic_relevance_config_org_project_prompt", + "llm_prompt", + name="uq_validator_prompt_config", ), ) diff --git a/backend/app/models/logging/request_log.py b/backend/app/models/logging/request_log.py index bda3ad7..648e44d 100644 --- a/backend/app/models/logging/request_log.py +++ b/backend/app/models/logging/request_log.py @@ -55,6 +55,12 @@ class RequestLog(SQLModel, table=True): sa_column_kwargs={"comment": "Text of the request made"}, ) + output_text: Optional[str] = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "LLM output text passed for output guardrails"}, + ) + response_text: Optional[str] = Field( default=None, nullable=True, diff --git a/backend/app/schemas/guardrail_config.py b/backend/app/schemas/guardrail_config.py index 968c260..4460702 100644 --- a/backend/app/schemas/guardrail_config.py +++ b/backend/app/schemas/guardrail_config.py @@ -33,9 +33,13 @@ from app.core.validators.config.profanity_free_safety_validator_config import ( ProfanityFreeSafetyValidatorConfig, ) +from app.core.validators.config.answer_relevance_custom_llm_safety_validator_config import ( + AnswerRelevanceCustomLLMSafetyValidatorConfig, +) ValidatorConfigItem = Annotated[ Union[ + AnswerRelevanceCustomLLMSafetyValidatorConfig, BanListSafetyValidatorConfig, GenderAssumptionBiasSafetyValidatorConfig, LexicalSlurSafetyValidatorConfig, @@ -56,6 +60,7 @@ class GuardrailRequest(SQLModel): organization_id: int project_id: int input: str + output: Optional[str] = None validators: List[ValidatorConfigItem] @model_validator(mode="before") diff --git a/backend/app/schemas/llm_prompt_config.py b/backend/app/schemas/llm_prompt_config.py new file mode 100644 index 0000000..080e6f5 --- /dev/null +++ b/backend/app/schemas/llm_prompt_config.py @@ -0,0 +1,70 @@ +from datetime import datetime +from typing import Annotated, Optional +from uuid import UUID + +from pydantic import StringConstraints, model_validator +from sqlmodel import Field, SQLModel + +from app.core.enum import LLMValidatorName + +MAX_NAME_LENGTH = 100 +MAX_DESCRIPTION_LENGTH = 500 + +LLMPromptName = Annotated[ + str, + StringConstraints(strip_whitespace=True, min_length=1, max_length=MAX_NAME_LENGTH), +] + +LLMPromptDescription = Annotated[ + str, + StringConstraints( + strip_whitespace=True, min_length=1, max_length=MAX_DESCRIPTION_LENGTH + ), +] + +LLMPromptText = Annotated[ + str, + StringConstraints(strip_whitespace=True, min_length=1), +] + +_ANSWER_RELEVANCE_PLACEHOLDERS = ("{query}", "{answer}") + + +class LLMPromptConfigCreate(SQLModel): + validator_name: LLMValidatorName + name: LLMPromptName + description: LLMPromptDescription + prompt_schema_version: int = Field(default=1, ge=1) + llm_prompt: LLMPromptText + + @model_validator(mode="after") + def validate_answer_relevance_placeholders(self) -> "LLMPromptConfigCreate": + if self.validator_name == LLMValidatorName.AnswerRelevanceCustomLLM: + missing = [ + p for p in _ANSWER_RELEVANCE_PLACEHOLDERS if p not in self.llm_prompt + ] + if missing: + raise ValueError( + f"llm_prompt must contain the placeholders: {', '.join(missing)}" + ) + return self + + +class LLMPromptConfigUpdate(SQLModel): + name: Optional[LLMPromptName] = None + description: Optional[LLMPromptDescription] = None + prompt_schema_version: Optional[int] = Field(default=None, ge=1) + llm_prompt: Optional[LLMPromptText] = None + is_active: Optional[bool] = None + + +class LLMPromptConfigResponse(SQLModel): + id: UUID + validator_name: LLMValidatorName + name: str + description: str + prompt_schema_version: int + llm_prompt: str + is_active: bool + created_at: datetime + updated_at: datetime diff --git a/backend/app/schemas/topic_relevance.py b/backend/app/schemas/topic_relevance.py deleted file mode 100644 index aabe9d3..0000000 --- a/backend/app/schemas/topic_relevance.py +++ /dev/null @@ -1,52 +0,0 @@ -from datetime import datetime -from typing import Annotated, Optional -from uuid import UUID - -from pydantic import StringConstraints -from sqlmodel import Field, SQLModel - -MAX_TOPIC_RELEVANCE_NAME_LENGTH = 100 -MAX_TOPIC_RELEVANCE_DESCRIPTION_LENGTH = 500 - -TopicsName = Annotated[ - str, - StringConstraints( - strip_whitespace=True, - min_length=1, - max_length=MAX_TOPIC_RELEVANCE_NAME_LENGTH, - ), -] - -TopicConfiguration = Annotated[ - str, - StringConstraints( - strip_whitespace=True, - min_length=1, - ), -] - - -class TopicRelevanceBase(SQLModel): - name: TopicsName - prompt_schema_version: int = Field(ge=1) - configuration: TopicConfiguration - - -class TopicRelevanceCreate(TopicRelevanceBase): - description: str - - -class TopicRelevanceUpdate(SQLModel): - name: Optional[TopicsName] = None - description: Optional[str] = None - prompt_schema_version: Optional[int] = Field(default=None, ge=1) - configuration: Optional[TopicConfiguration] = None - is_active: Optional[bool] = None - - -class TopicRelevanceResponse(TopicRelevanceBase): - description: str - id: UUID - is_active: bool - created_at: datetime - updated_at: datetime diff --git a/backend/app/tests/test_llm_prompt_configs_api.py b/backend/app/tests/test_llm_prompt_configs_api.py new file mode 100644 index 0000000..d142644 --- /dev/null +++ b/backend/app/tests/test_llm_prompt_configs_api.py @@ -0,0 +1,221 @@ +from unittest.mock import MagicMock, patch +from uuid import UUID + +import pytest +from sqlmodel import Session + +from app.api.deps import TenantContext +from app.api.routes.llm_prompt_configs import ( + create_llm_prompt_config, + delete_llm_prompt_config, + get_llm_prompt_config, + list_llm_prompt_configs, + update_llm_prompt_config, +) +from app.core.enum import LLMValidatorName +from app.schemas.llm_prompt_config import LLMPromptConfigCreate, LLMPromptConfigUpdate + +TEST_ID = UUID("223e4567-e89b-12d3-a456-426614174111") +TEST_ORG_ID = 101 +TEST_PROJECT_ID = 202 + +TOPIC_PROMPT = "Pregnancy care: Questions related to prenatal care and supplements." +ANSWER_PROMPT = "Query: {query}\nAnswer: {answer}\nRelevant? YES or NO." + + +@pytest.fixture +def mock_session(): + return MagicMock(spec=Session) + + +@pytest.fixture +def sample_topic_config(): + obj = MagicMock() + obj.id = TEST_ID + obj.validator_name = LLMValidatorName.TopicRelevance + obj.name = "Maternal Health Scope" + obj.description = "Topic scope for maternal health bot" + obj.prompt_schema_version = 1 + obj.llm_prompt = TOPIC_PROMPT + obj.is_active = True + obj.organization_id = TEST_ORG_ID + obj.project_id = TEST_PROJECT_ID + return obj + + +@pytest.fixture +def sample_answer_config(): + obj = MagicMock() + obj.id = TEST_ID + obj.validator_name = LLMValidatorName.AnswerRelevanceCustomLLM + obj.name = "Health Relevance" + obj.description = "Checks answer relevance for health queries" + obj.prompt_schema_version = 1 + obj.llm_prompt = ANSWER_PROMPT + obj.is_active = True + obj.organization_id = TEST_ORG_ID + obj.project_id = TEST_PROJECT_ID + return obj + + +@pytest.fixture +def topic_create_payload(): + return LLMPromptConfigCreate( + validator_name=LLMValidatorName.TopicRelevance, + name="Maternal Health Scope", + description="Topic scope for maternal health bot", + prompt_schema_version=1, + llm_prompt=TOPIC_PROMPT, + ) + + +@pytest.fixture +def answer_create_payload(): + return LLMPromptConfigCreate( + validator_name=LLMValidatorName.AnswerRelevanceCustomLLM, + name="Health Relevance", + description="Checks answer relevance for health queries", + llm_prompt=ANSWER_PROMPT, + ) + + +@pytest.fixture +def auth_context(): + return TenantContext( + organization_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + ) + + +def test_create_topic_relevance_config( + mock_session, topic_create_payload, sample_topic_config, auth_context +): + with patch("app.api.routes.llm_prompt_configs.llm_prompt_config_crud") as crud: + crud.create.return_value = sample_topic_config + + result = create_llm_prompt_config( + payload=topic_create_payload, + session=mock_session, + auth=auth_context, + ) + + crud.create.assert_called_once_with( + mock_session, + topic_create_payload, + TEST_ORG_ID, + TEST_PROJECT_ID, + ) + assert result.data == sample_topic_config + + +def test_create_answer_relevance_config( + mock_session, answer_create_payload, sample_answer_config, auth_context +): + with patch("app.api.routes.llm_prompt_configs.llm_prompt_config_crud") as crud: + crud.create.return_value = sample_answer_config + + result = create_llm_prompt_config( + payload=answer_create_payload, + session=mock_session, + auth=auth_context, + ) + + assert result.data == sample_answer_config + + +def test_list_all_configs( + mock_session, sample_topic_config, sample_answer_config, auth_context +): + with patch("app.api.routes.llm_prompt_configs.llm_prompt_config_crud") as crud: + crud.list.return_value = [sample_topic_config, sample_answer_config] + + result = list_llm_prompt_configs( + session=mock_session, + auth=auth_context, + ) + + crud.list.assert_called_once_with( + mock_session, + TEST_ORG_ID, + TEST_PROJECT_ID, + validator_name=None, + offset=0, + limit=None, + ) + assert len(result.data) == 2 + + +def test_list_filtered_by_validator_name( + mock_session, sample_topic_config, auth_context +): + with patch("app.api.routes.llm_prompt_configs.llm_prompt_config_crud") as crud: + crud.list.return_value = [sample_topic_config] + + result = list_llm_prompt_configs( + session=mock_session, + auth=auth_context, + validator_name=LLMValidatorName.TopicRelevance, + ) + + crud.list.assert_called_once_with( + mock_session, + TEST_ORG_ID, + TEST_PROJECT_ID, + validator_name=LLMValidatorName.TopicRelevance, + offset=0, + limit=None, + ) + assert len(result.data) == 1 + + +def test_get_success(mock_session, sample_topic_config, auth_context): + with patch("app.api.routes.llm_prompt_configs.llm_prompt_config_crud") as crud: + crud.get.return_value = sample_topic_config + + result = get_llm_prompt_config( + id=TEST_ID, + session=mock_session, + auth=auth_context, + ) + + crud.get.assert_called_once_with( + mock_session, TEST_ID, TEST_ORG_ID, TEST_PROJECT_ID + ) + assert result.data == sample_topic_config + + +def test_update_success(mock_session, sample_topic_config, auth_context): + with patch("app.api.routes.llm_prompt_configs.llm_prompt_config_crud") as crud: + crud.update.return_value = sample_topic_config + + result = update_llm_prompt_config( + id=TEST_ID, + payload=LLMPromptConfigUpdate(name="updated"), + session=mock_session, + auth=auth_context, + ) + + crud.update.assert_called_once() + args, _ = crud.update.call_args + assert args[1] == TEST_ID + assert args[2] == TEST_ORG_ID + assert args[3] == TEST_PROJECT_ID + assert args[4].name == "updated" + assert result.data == sample_topic_config + + +def test_delete_success(mock_session, sample_topic_config, auth_context): + with patch("app.api.routes.llm_prompt_configs.llm_prompt_config_crud") as crud: + crud.get.return_value = sample_topic_config + + result = delete_llm_prompt_config( + id=TEST_ID, + session=mock_session, + auth=auth_context, + ) + + crud.get.assert_called_once_with( + mock_session, TEST_ID, TEST_ORG_ID, TEST_PROJECT_ID + ) + crud.delete.assert_called_once_with(mock_session, sample_topic_config) + assert result.success is True diff --git a/backend/app/tests/test_llm_prompt_configs_api_integration.py b/backend/app/tests/test_llm_prompt_configs_api_integration.py new file mode 100644 index 0000000..3eb770c --- /dev/null +++ b/backend/app/tests/test_llm_prompt_configs_api_integration.py @@ -0,0 +1,394 @@ +import uuid + +import pytest + +from app.schemas.llm_prompt_config import MAX_NAME_LENGTH, MAX_DESCRIPTION_LENGTH + +pytestmark = pytest.mark.integration + +BASE_URL = "/api/v1/guardrails/llm_prompt_configs/" +DEFAULT_API_KEY = "org1_project1" +ALT_API_KEY = "org999_project999" + +TOPIC_PROMPT = ( + "Pregnancy care: Questions about prenatal care, supplements, and danger signs. " + "Postpartum care: Questions about recovery after delivery and breastfeeding." +) +ANSWER_PROMPT = "Query: {query}\nAnswer: {answer}\nIs the answer relevant? YES or NO." +CUSTOM_ANSWER_PROMPT = ( + "You are evaluating a health assistant.\n" + "Query: {query}\n" + "Answer: {answer}\n" + "Does the answer address the health query? YES or NO." +) + + +class BaseLLMPromptConfigTest: + def _headers(self, api_key=DEFAULT_API_KEY): + return {"X-API-Key": api_key} + + def create_topic(self, client, api_key=DEFAULT_API_KEY, **overrides): + name = overrides.get("name", "Maternal Health Scope") + payload = { + "validator_name": "topic_relevance", + "name": name, + "description": "Topic guard for maternal health support bot", + "prompt_schema_version": 1, + "llm_prompt": f"{TOPIC_PROMPT} Scope name: {name}.", + **overrides, + } + return client.post(BASE_URL, json=payload, headers=self._headers(api_key)) + + def create_answer(self, client, api_key=DEFAULT_API_KEY, **overrides): + payload = { + "validator_name": "answer_relevance_custom_llm", + "name": "Health Relevance", + "description": "Checks LLM answer relevance for health queries", + "llm_prompt": ANSWER_PROMPT, + **overrides, + } + return client.post(BASE_URL, json=payload, headers=self._headers(api_key)) + + def list(self, client, api_key=DEFAULT_API_KEY, **filters): + return client.get(BASE_URL, params=filters, headers=self._headers(api_key)) + + def get(self, client, id, api_key=DEFAULT_API_KEY): + return client.get(f"{BASE_URL}{id}", headers=self._headers(api_key)) + + def update(self, client, id, payload, api_key=DEFAULT_API_KEY): + return client.patch( + f"{BASE_URL}{id}", + json=payload, + headers=self._headers(api_key), + ) + + def delete(self, client, id, api_key=DEFAULT_API_KEY): + return client.delete(f"{BASE_URL}{id}", headers=self._headers(api_key)) + + +class TestCreateLLMPromptConfig(BaseLLMPromptConfigTest): + def test_create_topic_relevance_success(self, integration_client, clear_database): + response = self.create_topic(integration_client) + + assert response.status_code == 200 + data = response.json()["data"] + assert data["validator_name"] == "topic_relevance" + assert data["name"] == "Maternal Health Scope" + assert "Pregnancy care" in data["llm_prompt"] + assert data["prompt_schema_version"] == 1 + assert data["is_active"] is True + assert "id" in data + + def test_create_answer_relevance_success(self, integration_client, clear_database): + response = self.create_answer(integration_client) + + assert response.status_code == 200 + data = response.json()["data"] + assert data["validator_name"] == "answer_relevance_custom_llm" + assert "{query}" in data["llm_prompt"] + assert "{answer}" in data["llm_prompt"] + assert data["is_active"] is True + + def test_create_answer_relevance_custom_prompt( + self, integration_client, clear_database + ): + response = self.create_answer( + integration_client, + name="Custom Health Prompt", + llm_prompt=CUSTOM_ANSWER_PROMPT, + ) + + assert response.status_code == 200 + assert "health assistant" in response.json()["data"]["llm_prompt"] + + def test_create_validation_error_missing_required_fields( + self, integration_client, clear_database + ): + response = integration_client.post( + BASE_URL, + json={"name": "incomplete"}, + headers=self._headers(), + ) + assert response.status_code == 422 + + def test_create_validation_error_invalid_validator_name( + self, integration_client, clear_database + ): + response = integration_client.post( + BASE_URL, + json={ + "validator_name": "unknown_validator", + "name": "test", + "description": "test", + "llm_prompt": "test prompt", + }, + headers=self._headers(), + ) + assert response.status_code == 422 + + def test_create_answer_relevance_missing_query_placeholder( + self, integration_client, clear_database + ): + response = self.create_answer( + integration_client, + llm_prompt="Answer: {answer}\nRelevant? YES or NO.", + ) + assert response.status_code == 422 + + def test_create_answer_relevance_missing_answer_placeholder( + self, integration_client, clear_database + ): + response = self.create_answer( + integration_client, + llm_prompt="Query: {query}\nRelevant? YES or NO.", + ) + assert response.status_code == 422 + + def test_create_topic_relevance_no_placeholder_validation( + self, integration_client, clear_database + ): + response = self.create_topic( + integration_client, + llm_prompt="A plain scope description without any placeholders.", + ) + assert response.status_code == 200 + + def test_create_validation_error_name_too_long( + self, integration_client, clear_database + ): + response = self.create_topic( + integration_client, + name="n" * (MAX_NAME_LENGTH + 1), + ) + assert response.status_code == 422 + + def test_create_validation_error_description_too_long( + self, integration_client, clear_database + ): + response = self.create_topic( + integration_client, + description="d" * (MAX_DESCRIPTION_LENGTH + 1), + ) + assert response.status_code == 422 + + +class TestListLLMPromptConfigs(BaseLLMPromptConfigTest): + def test_list_all_success(self, integration_client, clear_database): + self.create_topic(integration_client, name="Scope 1") + self.create_topic(integration_client, name="Scope 2") + self.create_answer(integration_client, name="Answer Config 1") + + response = self.list(integration_client) + + assert response.status_code == 200 + assert len(response.json()["data"]) == 3 + + def test_list_filtered_by_validator_name(self, integration_client, clear_database): + self.create_topic(integration_client, name="Scope 1") + self.create_topic(integration_client, name="Scope 2") + self.create_answer(integration_client, name="Answer Config") + + response = self.list(integration_client, validator_name="topic_relevance") + + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 2 + assert all(d["validator_name"] == "topic_relevance" for d in data) + + def test_list_empty(self, integration_client, clear_database): + response = self.list(integration_client) + + assert response.status_code == 200 + assert response.json()["data"] == [] + + def test_list_pagination_with_limit(self, integration_client, clear_database): + for i in range(4): + self.create_topic(integration_client, name=f"Scope {i}") + + response = self.list(integration_client, limit=2) + + assert response.status_code == 200 + assert len(response.json()["data"]) == 2 + + def test_list_is_tenant_scoped(self, integration_client, clear_database): + self.create_topic(integration_client, name="Tenant1 scope") + + response = self.list(integration_client, api_key=ALT_API_KEY) + + assert response.status_code == 200 + assert response.json()["data"] == [] + + +class TestGetLLMPromptConfig(BaseLLMPromptConfigTest): + def test_get_success(self, integration_client, clear_database): + create_resp = self.create_topic(integration_client) + config_id = create_resp.json()["data"]["id"] + + response = self.get(integration_client, config_id) + + assert response.status_code == 200 + data = response.json()["data"] + assert data["id"] == config_id + assert data["validator_name"] == "topic_relevance" + + def test_get_not_found(self, integration_client, clear_database): + response = self.get(integration_client, uuid.uuid4()) + body = response.json() + + assert response.status_code == 404 + assert body["success"] is False + assert "LLM prompt config not found" in body["error"] + + def test_get_other_tenant_not_found(self, integration_client, clear_database): + create_resp = self.create_topic(integration_client) + config_id = create_resp.json()["data"]["id"] + + response = self.get(integration_client, config_id, api_key=ALT_API_KEY) + + assert response.status_code == 404 + assert response.json()["success"] is False + + +class TestUpdateLLMPromptConfig(BaseLLMPromptConfigTest): + def test_update_name_success(self, integration_client, clear_database): + create_resp = self.create_topic(integration_client) + config_id = create_resp.json()["data"]["id"] + + response = self.update(integration_client, config_id, {"name": "Updated scope"}) + + assert response.status_code == 200 + assert response.json()["data"]["name"] == "Updated scope" + + def test_update_is_active_false(self, integration_client, clear_database): + create_resp = self.create_topic(integration_client) + config_id = create_resp.json()["data"]["id"] + + response = self.update(integration_client, config_id, {"is_active": False}) + + assert response.status_code == 200 + assert response.json()["data"]["is_active"] is False + + def test_partial_update_preserves_other_fields( + self, integration_client, clear_database + ): + create_resp = self.create_answer(integration_client) + original = create_resp.json()["data"] + config_id = original["id"] + + self.update(integration_client, config_id, {"name": "New Name"}) + response = self.get(integration_client, config_id) + data = response.json()["data"] + + assert data["name"] == "New Name" + assert data["llm_prompt"] == original["llm_prompt"] + assert data["description"] == original["description"] + + def test_update_not_found(self, integration_client, clear_database): + response = self.update(integration_client, uuid.uuid4(), {"name": "x"}) + + assert response.status_code == 404 + assert "LLM prompt config not found" in response.json()["error"] + + def test_update_other_tenant_not_found(self, integration_client, clear_database): + create_resp = self.create_topic(integration_client) + config_id = create_resp.json()["data"]["id"] + + response = self.update( + integration_client, + config_id, + {"name": "other-tenant-update"}, + api_key=ALT_API_KEY, + ) + + assert response.status_code == 404 + + def test_update_answer_relevance_llm_prompt_missing_query_placeholder( + self, integration_client, clear_database + ): + create_resp = self.create_answer(integration_client) + config_id = create_resp.json()["data"]["id"] + + response = self.update( + integration_client, + config_id, + {"llm_prompt": "Answer: {answer}\nRelevant? YES or NO."}, + ) + + assert response.status_code == 422 + + def test_update_answer_relevance_llm_prompt_missing_answer_placeholder( + self, integration_client, clear_database + ): + create_resp = self.create_answer(integration_client) + config_id = create_resp.json()["data"]["id"] + + response = self.update( + integration_client, + config_id, + {"llm_prompt": "Query: {query}\nRelevant? YES or NO."}, + ) + + assert response.status_code == 422 + + def test_update_answer_relevance_valid_llm_prompt_succeeds( + self, integration_client, clear_database + ): + create_resp = self.create_answer(integration_client) + config_id = create_resp.json()["data"]["id"] + + new_prompt = "Q: {query}\nA: {answer}\nYES or NO." + response = self.update( + integration_client, config_id, {"llm_prompt": new_prompt} + ) + + assert response.status_code == 200 + assert response.json()["data"]["llm_prompt"] == new_prompt + + def test_update_topic_relevance_llm_prompt_no_placeholder_required( + self, integration_client, clear_database + ): + create_resp = self.create_topic(integration_client) + config_id = create_resp.json()["data"]["id"] + + response = self.update( + integration_client, + config_id, + {"llm_prompt": "A plain scope with no placeholders."}, + ) + + assert response.status_code == 200 + + +class TestDeleteLLMPromptConfig(BaseLLMPromptConfigTest): + def test_delete_success(self, integration_client, clear_database): + create_resp = self.create_topic(integration_client) + config_id = create_resp.json()["data"]["id"] + + response = self.delete(integration_client, config_id) + + assert response.status_code == 200 + assert response.json()["success"] is True + assert "deleted" in response.json()["data"]["message"].lower() + + def test_delete_removes_from_list(self, integration_client, clear_database): + create_resp = self.create_topic(integration_client) + config_id = create_resp.json()["data"]["id"] + + self.delete(integration_client, config_id) + + ids = [item["id"] for item in self.list(integration_client).json()["data"]] + assert config_id not in ids + + def test_delete_not_found(self, integration_client, clear_database): + response = self.delete(integration_client, uuid.uuid4()) + + assert response.status_code == 404 + assert "LLM prompt config not found" in response.json()["error"] + + def test_delete_other_tenant_not_found(self, integration_client, clear_database): + create_resp = self.create_topic(integration_client) + config_id = create_resp.json()["data"]["id"] + + response = self.delete(integration_client, config_id, api_key=ALT_API_KEY) + + assert response.status_code == 404 diff --git a/backend/app/tests/test_llm_validators.py b/backend/app/tests/test_llm_validators.py index 5834843..e300724 100644 --- a/backend/app/tests/test_llm_validators.py +++ b/backend/app/tests/test_llm_validators.py @@ -3,6 +3,9 @@ import pytest from guardrails.validators import FailResult +from app.core.validators.config.answer_relevance_custom_llm_safety_validator_config import ( + AnswerRelevanceCustomLLMSafetyValidatorConfig, +) from app.core.validators.config.topic_relevance_safety_validator_config import ( TopicRelevanceSafetyValidatorConfig, ) @@ -118,3 +121,94 @@ def test__normalize_llm_critic_error_passes_through_unknown_messages(): _normalize_llm_critic_error(raw) == "The query did not meet the required quality criteria." ) + + +# --------------------------------------------------------------------------- +# AnswerRelevanceCustomLLMSafetyValidatorConfig +# --------------------------------------------------------------------------- + +_ANSWER_RELEVANCE_SETTINGS_PATH = ( + "app.core.validators.config" + ".answer_relevance_custom_llm_safety_validator_config.settings" +) + +_SAMPLE_ANSWER_RELEVANCE_CONFIG = dict(type="answer_relevance_custom_llm") + + +def test_answer_relevance_build_raises_when_openai_key_missing(): + config = AnswerRelevanceCustomLLMSafetyValidatorConfig( + **_SAMPLE_ANSWER_RELEVANCE_CONFIG + ) + + with patch(_ANSWER_RELEVANCE_SETTINGS_PATH) as mock_settings: + mock_settings.OPENAI_API_KEY = None + + with pytest.raises(ValueError) as exc: + config.build() + + assert "OPENAI_API_KEY" in str(exc.value) + assert "not configured" in str(exc.value) + + +def test_answer_relevance_build_proceeds_when_openai_key_present(): + config = AnswerRelevanceCustomLLMSafetyValidatorConfig( + **_SAMPLE_ANSWER_RELEVANCE_CONFIG + ) + + with patch(_ANSWER_RELEVANCE_SETTINGS_PATH) as mock_settings, patch( + "app.core.validators.config" + ".answer_relevance_custom_llm_safety_validator_config.AnswerRelevanceCustomLLM" + ) as mock_validator: + mock_settings.OPENAI_API_KEY = "sk-test-key" + config.build() + + mock_validator.assert_called_once() + + +def test_answer_relevance_build_uses_default_prompt_when_none(): + config = AnswerRelevanceCustomLLMSafetyValidatorConfig( + **_SAMPLE_ANSWER_RELEVANCE_CONFIG + ) + + with patch(_ANSWER_RELEVANCE_SETTINGS_PATH) as mock_settings, patch( + "app.core.validators.config" + ".answer_relevance_custom_llm_safety_validator_config.AnswerRelevanceCustomLLM" + ) as mock_validator: + mock_settings.OPENAI_API_KEY = "sk-test-key" + config.build() + + _, kwargs = mock_validator.call_args + assert "prompt_template" not in kwargs + + +def test_answer_relevance_build_passes_inline_prompt_template(): + custom = "Q: {query}\nA: {answer}\nYES or NO." + config = AnswerRelevanceCustomLLMSafetyValidatorConfig( + **{**_SAMPLE_ANSWER_RELEVANCE_CONFIG, "prompt_template": custom} + ) + + with patch(_ANSWER_RELEVANCE_SETTINGS_PATH) as mock_settings, patch( + "app.core.validators.config" + ".answer_relevance_custom_llm_safety_validator_config.AnswerRelevanceCustomLLM" + ) as mock_validator: + mock_settings.OPENAI_API_KEY = "sk-test-key" + config.build() + + _, kwargs = mock_validator.call_args + assert kwargs["prompt_template"] == custom + + +def test_answer_relevance_build_passes_llm_callable(): + config = AnswerRelevanceCustomLLMSafetyValidatorConfig( + **{**_SAMPLE_ANSWER_RELEVANCE_CONFIG, "llm_callable": "gpt-4o"} + ) + + with patch(_ANSWER_RELEVANCE_SETTINGS_PATH) as mock_settings, patch( + "app.core.validators.config" + ".answer_relevance_custom_llm_safety_validator_config.AnswerRelevanceCustomLLM" + ) as mock_validator: + mock_settings.OPENAI_API_KEY = "sk-test-key" + config.build() + + _, kwargs = mock_validator.call_args + assert kwargs["llm_callable"] == "gpt-4o" diff --git a/backend/app/tests/test_topic_relevance_configs_api.py b/backend/app/tests/test_topic_relevance_configs_api.py deleted file mode 100644 index c8c166c..0000000 --- a/backend/app/tests/test_topic_relevance_configs_api.py +++ /dev/null @@ -1,145 +0,0 @@ -from unittest.mock import MagicMock, patch -from uuid import UUID - -import pytest -from sqlmodel import Session - -from app.api.deps import TenantContext -from app.api.routes.topic_relevance_configs import ( - create_topic_relevance_config, - delete_topic_relevance_config, - get_topic_relevance_config, - list_topic_relevance_configs, - update_topic_relevance_config, -) -from app.schemas.topic_relevance import TopicRelevanceCreate, TopicRelevanceUpdate - -TOPIC_RELEVANCE_TEST_ID = UUID("223e4567-e89b-12d3-a456-426614174111") -TOPIC_RELEVANCE_TEST_ORGANIZATION_ID = 101 -TOPIC_RELEVANCE_TEST_PROJECT_ID = 202 - - -@pytest.fixture -def mock_session(): - return MagicMock(spec=Session) - - -@pytest.fixture -def sample_topic_relevance(): - obj = MagicMock() - obj.id = TOPIC_RELEVANCE_TEST_ID - obj.name = "Maternal Health Scope" - obj.description = "Topic scope for maternal health bot" - obj.prompt_schema_version = 1 - obj.configuration = ( - "Pregnancy care: Questions related to prenatal care and supplements." - ) - obj.is_active = True - obj.organization_id = TOPIC_RELEVANCE_TEST_ORGANIZATION_ID - obj.project_id = TOPIC_RELEVANCE_TEST_PROJECT_ID - return obj - - -@pytest.fixture -def create_payload(): - return TopicRelevanceCreate( - name="Maternal Health Scope", - description="Topic scope for maternal health bot", - prompt_schema_version=1, - configuration="Pregnancy care: Questions related to prenatal care and supplements.", - ) - - -@pytest.fixture -def auth_context(): - return TenantContext( - organization_id=TOPIC_RELEVANCE_TEST_ORGANIZATION_ID, - project_id=TOPIC_RELEVANCE_TEST_PROJECT_ID, - ) - - -def test_create_calls_crud( - mock_session, create_payload, sample_topic_relevance, auth_context -): - with patch("app.api.routes.topic_relevance_configs.topic_relevance_crud") as crud: - crud.create.return_value = sample_topic_relevance - - result = create_topic_relevance_config( - payload=create_payload, - session=mock_session, - auth=auth_context, - ) - - assert result.data == sample_topic_relevance - - -def test_list_returns_data(mock_session, sample_topic_relevance, auth_context): - with patch("app.api.routes.topic_relevance_configs.topic_relevance_crud") as crud: - crud.list.return_value = [sample_topic_relevance] - - result = list_topic_relevance_configs( - session=mock_session, - auth=auth_context, - ) - - crud.list.assert_called_once_with( - mock_session, - TOPIC_RELEVANCE_TEST_ORGANIZATION_ID, - TOPIC_RELEVANCE_TEST_PROJECT_ID, - 0, - None, - ) - assert len(result.data) == 1 - - -def test_get_success(mock_session, sample_topic_relevance, auth_context): - with patch("app.api.routes.topic_relevance_configs.topic_relevance_crud") as crud: - crud.get.return_value = sample_topic_relevance - - result = get_topic_relevance_config( - id=TOPIC_RELEVANCE_TEST_ID, - session=mock_session, - auth=auth_context, - ) - - assert result.data == sample_topic_relevance - - -def test_update_success(mock_session, sample_topic_relevance, auth_context): - with patch("app.api.routes.topic_relevance_configs.topic_relevance_crud") as crud: - crud.update.return_value = sample_topic_relevance - - result = update_topic_relevance_config( - id=TOPIC_RELEVANCE_TEST_ID, - payload=TopicRelevanceUpdate(name="updated"), - session=mock_session, - auth=auth_context, - ) - - crud.update.assert_called_once() - args, _ = crud.update.call_args - assert args[1] == TOPIC_RELEVANCE_TEST_ID - assert args[2] == TOPIC_RELEVANCE_TEST_ORGANIZATION_ID - assert args[3] == TOPIC_RELEVANCE_TEST_PROJECT_ID - assert args[4].name == "updated" - assert result.data == sample_topic_relevance - - -def test_delete_success(mock_session, sample_topic_relevance, auth_context): - with patch("app.api.routes.topic_relevance_configs.topic_relevance_crud") as crud: - crud.get.return_value = sample_topic_relevance - - result = delete_topic_relevance_config( - id=TOPIC_RELEVANCE_TEST_ID, - session=mock_session, - auth=auth_context, - ) - - crud.get.assert_called_once_with( - mock_session, - TOPIC_RELEVANCE_TEST_ID, - TOPIC_RELEVANCE_TEST_ORGANIZATION_ID, - TOPIC_RELEVANCE_TEST_PROJECT_ID, - ) - crud.delete.assert_called_once_with(mock_session, sample_topic_relevance) - assert result.success is True diff --git a/backend/app/tests/test_topic_relevance_configs_api_integration.py b/backend/app/tests/test_topic_relevance_configs_api_integration.py deleted file mode 100644 index 8f31ec8..0000000 --- a/backend/app/tests/test_topic_relevance_configs_api_integration.py +++ /dev/null @@ -1,261 +0,0 @@ -import uuid - -import pytest - -from app.schemas.topic_relevance import MAX_TOPIC_RELEVANCE_NAME_LENGTH - -pytestmark = pytest.mark.integration - -BASE_URL = "/api/v1/guardrails/topic_relevance_configs/" -DEFAULT_API_KEY = "org1_project1" -ALT_API_KEY = "org999_project999" - - -class BaseTopicRelevanceTest: - def _headers(self, api_key=DEFAULT_API_KEY): - return {"X-API-Key": api_key} - - def create(self, client, api_key=DEFAULT_API_KEY, **kwargs): - name = kwargs.get("name", "Maternal Health Scope") - payload = { - "name": name, - "description": "Topic guard for maternal health support bot", - "prompt_schema_version": 1, - "configuration": ( - "Pregnancy care: Questions about prenatal care, supplements, and " - "danger signs. Postpartum care: Questions about recovery after " - f"delivery and breastfeeding. Scope name: {name}." - ), - **kwargs, - } - return client.post(BASE_URL, json=payload, headers=self._headers(api_key)) - - def list(self, client, api_key=DEFAULT_API_KEY, **filters): - return client.get(BASE_URL, params=filters, headers=self._headers(api_key)) - - def get(self, client, id, api_key=DEFAULT_API_KEY): - return client.get(f"{BASE_URL}{id}", headers=self._headers(api_key)) - - def update(self, client, id, payload, api_key=DEFAULT_API_KEY): - return client.patch( - f"{BASE_URL}{id}", - json=payload, - headers=self._headers(api_key), - ) - - def delete(self, client, id, api_key=DEFAULT_API_KEY): - return client.delete(f"{BASE_URL}{id}", headers=self._headers(api_key)) - - -class TestCreateTopicRelevanceConfig(BaseTopicRelevanceTest): - def test_create_success(self, integration_client, clear_database): - response = self.create(integration_client) - - assert response.status_code == 200 - data = response.json()["data"] - - assert data["name"] == "Maternal Health Scope" - assert data["prompt_schema_version"] == 1 - assert "Pregnancy care" in data["configuration"] - - def test_create_validation_error_missing_required_fields( - self, integration_client, clear_database - ): - response = integration_client.post( - BASE_URL, - json={"name": "missing config"}, - headers=self._headers(), - ) - - assert response.status_code == 422 - - def test_create_validation_error_name_too_long( - self, integration_client, clear_database - ): - response = self.create( - integration_client, - name="n" * (MAX_TOPIC_RELEVANCE_NAME_LENGTH + 1), - ) - - assert response.status_code == 422 - - -class TestListTopicRelevanceConfigs(BaseTopicRelevanceTest): - def test_list_success(self, integration_client, clear_database): - assert self.create(integration_client, name="Scope 1").status_code == 200 - assert self.create(integration_client, name="Scope 2").status_code == 200 - assert self.create(integration_client, name="Scope 3").status_code == 200 - - response = self.list(integration_client) - - assert response.status_code == 200 - data = response.json()["data"] - assert len(data) == 3 - - def test_list_empty(self, integration_client, clear_database): - response = self.list(integration_client) - - assert response.status_code == 200 - assert response.json()["data"] == [] - - def test_list_pagination_with_limit(self, integration_client, clear_database): - assert self.create(integration_client, name="Scope 1").status_code == 200 - assert self.create(integration_client, name="Scope 2").status_code == 200 - assert self.create(integration_client, name="Scope 3").status_code == 200 - - response = self.list(integration_client, limit=2) - - assert response.status_code == 200 - assert len(response.json()["data"]) == 2 - - def test_list_pagination_with_offset_and_limit( - self, integration_client, clear_database - ): - assert self.create(integration_client, name="Scope 1").status_code == 200 - assert self.create(integration_client, name="Scope 2").status_code == 200 - assert self.create(integration_client, name="Scope 3").status_code == 200 - assert self.create(integration_client, name="Scope 4").status_code == 200 - - full_response = self.list(integration_client) - full_data = full_response.json()["data"] - - response = self.list(integration_client, offset=2, limit=2) - - assert response.status_code == 200 - paged_data = response.json()["data"] - assert len(paged_data) == 2 - assert [item["id"] for item in paged_data] == [ - item["id"] for item in full_data[2:4] - ] - - def test_list_is_tenant_scoped(self, integration_client, clear_database): - self.create(integration_client, name="Tenant1 scope") - - response = self.list(integration_client, api_key=ALT_API_KEY) - - assert response.status_code == 200 - assert response.json()["data"] == [] - - -class TestGetTopicRelevanceConfig(BaseTopicRelevanceTest): - def test_get_success(self, integration_client, clear_database): - create_resp = self.create(integration_client) - config_id = create_resp.json()["data"]["id"] - - response = self.get(integration_client, config_id) - - assert response.status_code == 200 - assert response.json()["data"]["id"] == config_id - - def test_get_not_found(self, integration_client, clear_database): - fake = uuid.uuid4() - - response = self.get(integration_client, fake) - body = response.json() - - assert response.status_code == 404 - assert body["success"] is False - assert "Topic relevance preset not found" in body["error"] - - def test_get_other_tenant_not_found(self, integration_client, clear_database): - create_resp = self.create(integration_client) - config_id = create_resp.json()["data"]["id"] - - response = self.get(integration_client, config_id, api_key=ALT_API_KEY) - body = response.json() - - assert response.status_code == 404 - assert body["success"] is False - assert "Topic relevance preset not found" in body["error"] - - -class TestUpdateTopicRelevanceConfig(BaseTopicRelevanceTest): - def test_update_success(self, integration_client, clear_database): - create_resp = self.create(integration_client) - config_id = create_resp.json()["data"]["id"] - - response = self.update( - integration_client, - config_id, - {"name": "Updated scope", "prompt_schema_version": 1}, - ) - - assert response.status_code == 200 - data = response.json()["data"] - assert data["name"] == "Updated scope" - assert data["prompt_schema_version"] == 1 - - def test_partial_update(self, integration_client, clear_database): - create_resp = self.create(integration_client) - config_id = create_resp.json()["data"]["id"] - - response = self.update( - integration_client, - config_id, - {"is_active": False}, - ) - - assert response.status_code == 200 - assert response.json()["data"]["is_active"] is False - - def test_update_not_found(self, integration_client, clear_database): - fake = uuid.uuid4() - - response = self.update(integration_client, fake, {"name": "x"}) - body = response.json() - - assert response.status_code == 404 - assert body["success"] is False - assert "Topic relevance preset not found" in body["error"] - - def test_update_other_tenant_not_found(self, integration_client, clear_database): - create_resp = self.create(integration_client) - config_id = create_resp.json()["data"]["id"] - - response = self.update( - integration_client, - config_id, - {"name": "updated-by-other-tenant"}, - api_key=ALT_API_KEY, - ) - body = response.json() - - assert response.status_code == 404 - assert body["success"] is False - assert "Topic relevance preset not found" in body["error"] - - -class TestDeleteTopicRelevanceConfig(BaseTopicRelevanceTest): - def test_delete_success(self, integration_client, clear_database): - create_resp = self.create(integration_client) - config_id = create_resp.json()["data"]["id"] - - response = self.delete(integration_client, config_id) - - assert response.status_code == 200 - assert response.json()["success"] is True - - def test_delete_not_found(self, integration_client, clear_database): - fake = uuid.uuid4() - - response = self.delete(integration_client, fake) - body = response.json() - - assert response.status_code == 404 - assert body["success"] is False - assert "Topic relevance preset not found" in body["error"] - - def test_delete_other_tenant_not_found(self, integration_client, clear_database): - create_resp = self.create(integration_client) - config_id = create_resp.json()["data"]["id"] - - response = self.delete( - integration_client, - config_id, - api_key=ALT_API_KEY, - ) - body = response.json() - - assert response.status_code == 404 - assert body["success"] is False - assert "Topic relevance preset not found" in body["error"] diff --git a/backend/app/tests/test_validate_with_guard.py b/backend/app/tests/test_validate_with_guard.py index 2956512..797d6f6 100644 --- a/backend/app/tests/test_validate_with_guard.py +++ b/backend/app/tests/test_validate_with_guard.py @@ -1,12 +1,15 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 +import pytest +from fastapi import HTTPException from guardrails.validators import FailResult as GRFailResult from app.api.routes.guardrails import ( _resolve_validator_configs, _validate_with_guard, ) +from app.core.enum import LLMValidatorName from app.schemas.guardrail_config import GuardrailRequest from app.tests.guardrails_mocks import MockResult from app.tests.seed_data import ( @@ -35,12 +38,14 @@ class MockGuard: def validate(self, data): return MockResult(validated_output="clean text") + payload = _build_payload("hello") with patch( "app.api.routes.guardrails.build_guard", return_value=MockGuard(), ): response = _validate_with_guard( - payload=_build_payload("hello"), + payload=payload, + data=payload.input, request_log_crud=mock_request_log_crud, request_log_id=mock_request_log_id, validator_log_crud=mock_validator_log_crud, @@ -57,12 +62,14 @@ class MockGuard: def validate(self, data): return MockResult(validated_output=None) + payload = _build_payload("bad text") with patch( "app.api.routes.guardrails.build_guard", return_value=MockGuard(), ): response = _validate_with_guard( - payload=_build_payload("bad text"), + payload=payload, + data=payload.input, request_log_crud=mock_request_log_crud, request_log_id=mock_request_log_id, validator_log_crud=mock_validator_log_crud, @@ -75,12 +82,14 @@ def validate(self, data): def test_validate_with_guard_exception(): + payload = _build_payload("text") with patch( "app.api.routes.guardrails.build_guard", side_effect=Exception("Invalid config"), ): response = _validate_with_guard( - payload=_build_payload("text"), + payload=payload, + data=payload.input, request_log_crud=mock_request_log_crud, request_log_id=mock_request_log_id, validator_log_crud=mock_validator_log_crud, @@ -117,11 +126,13 @@ class MockGuard: def validate(self, data): return MockResult(validated_output=None) + payload = _build_payload("bad text") with patch( "app.api.routes.guardrails.build_guard", return_value=MockGuard() ), patch("app.api.routes.guardrails.add_validator_logs"): response = _validate_with_guard( - payload=_build_payload("bad text"), + payload=payload, + data=payload.input, request_log_crud=mock_request_log_crud, request_log_id=mock_request_log_id, validator_log_crud=mock_validator_log_crud, @@ -143,12 +154,14 @@ class last: def validate(self, data): return MockResult(validated_output=None) + payload = _build_payload("bad text") with patch( "app.api.routes.guardrails.build_guard", return_value=MockGuard(), ): response = _validate_with_guard( - payload=_build_payload("bad text"), + payload=payload, + data=payload.input, request_log_crud=mock_request_log_crud, request_log_id=mock_request_log_id, validator_log_crud=mock_validator_log_crud, @@ -213,9 +226,10 @@ def test_resolve_validator_configs_topic_relevance_from_config_id(): ) mock_session = MagicMock() - with patch("app.api.routes.guardrails.topic_relevance_crud.get") as mock_get: + with patch("app.api.routes.guardrails.llm_prompt_config_crud.get") as mock_get: mock_get.return_value = MagicMock( - configuration="Topic scope prompt text", + validator_name=LLMValidatorName.TopicRelevance, + llm_prompt="Topic scope prompt text", prompt_schema_version=2, ) _resolve_validator_configs(payload, mock_session) @@ -241,7 +255,7 @@ def test_resolve_validator_configs_skips_topic_relevance_lookup_when_no_config_i ) mock_session = MagicMock() - with patch("app.api.routes.guardrails.topic_relevance_crud.get") as mock_get: + with patch("app.api.routes.guardrails.llm_prompt_config_crud.get") as mock_get: _resolve_validator_configs(payload, mock_session) mock_get.assert_not_called() @@ -262,7 +276,7 @@ def test_resolve_validator_configs_uses_inline_topic_relevance_without_lookup(): ) mock_session = MagicMock() - with patch("app.api.routes.guardrails.topic_relevance_crud.get") as mock_get: + with patch("app.api.routes.guardrails.llm_prompt_config_crud.get") as mock_get: _resolve_validator_configs(payload, mock_session) validator = payload.validators[0] @@ -270,6 +284,79 @@ def test_resolve_validator_configs_uses_inline_topic_relevance_without_lookup(): mock_get.assert_not_called() +def test_resolve_validator_configs_answer_relevance_from_custom_prompt_id(): + custom_prompt_id = str(uuid4()) + payload = GuardrailRequest( + request_id=str(uuid4()), + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input="{}", + validators=[ + { + "type": "answer_relevance_custom_llm", + "custom_prompt_id": custom_prompt_id, + } + ], + ) + mock_session = MagicMock() + + with patch("app.api.routes.guardrails.llm_prompt_config_crud.get") as mock_get: + mock_get.return_value = MagicMock( + validator_name=LLMValidatorName.AnswerRelevanceCustomLLM, + llm_prompt="Q: {query}\nA: {answer}\nYES or NO.", + ) + _resolve_validator_configs(payload, mock_session) + + validator = payload.validators[0] + assert validator.prompt_template == "Q: {query}\nA: {answer}\nYES or NO." + mock_get.assert_called_once_with( + session=mock_session, + id=validator.custom_prompt_id, + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + ) + + +def test_resolve_validator_configs_skips_answer_relevance_lookup_when_no_prompt_id(): + payload = GuardrailRequest( + request_id=str(uuid4()), + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input="{}", + validators=[{"type": "answer_relevance_custom_llm"}], + ) + mock_session = MagicMock() + + with patch("app.api.routes.guardrails.llm_prompt_config_crud.get") as mock_get: + _resolve_validator_configs(payload, mock_session) + + mock_get.assert_not_called() + + +def test_resolve_validator_configs_uses_inline_answer_relevance_prompt_without_lookup(): + inline_template = "Query: {query}\nAnswer: {answer}\nYES or NO." + payload = GuardrailRequest( + request_id=str(uuid4()), + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input="{}", + validators=[ + { + "type": "answer_relevance_custom_llm", + "prompt_template": inline_template, + } + ], + ) + mock_session = MagicMock() + + with patch("app.api.routes.guardrails.llm_prompt_config_crud.get") as mock_get: + _resolve_validator_configs(payload, mock_session) + + validator = payload.validators[0] + assert validator.prompt_template == inline_template + mock_get.assert_not_called() + + def _build_mock_guard_with_fail_result(validator_name: str, error_message: str): mock_log = MagicMock() mock_log.validator_name = validator_name @@ -308,6 +395,7 @@ def test_nsfw_error_message_redacts_input(): ), patch("app.api.routes.guardrails.add_validator_logs"): response = _validate_with_guard( payload=_build_payload(unsafe_input), + data=unsafe_input, request_log_crud=mock_request_log_crud, request_log_id=mock_request_log_id, validator_log_crud=mock_validator_log_crud, @@ -329,6 +417,7 @@ def test_all_validators_redact_input_from_error_message(): ), patch("app.api.routes.guardrails.add_validator_logs"): response = _validate_with_guard( payload=_build_payload(input_text), + data=input_text, request_log_crud=mock_request_log_crud, request_log_id=mock_request_log_id, validator_log_crud=mock_validator_log_crud, @@ -351,6 +440,7 @@ def test_profanity_free_error_message_redacts_input(): ), patch("app.api.routes.guardrails.add_validator_logs"): response = _validate_with_guard( payload=_build_payload(unsafe_input), + data=unsafe_input, request_log_crud=mock_request_log_crud, request_log_id=mock_request_log_id, validator_log_crud=mock_validator_log_crud, @@ -374,6 +464,7 @@ def test_nsfw_exception_redacts_input(): ): response = _validate_with_guard( payload=_build_payload(unsafe_input), + data=unsafe_input, request_log_crud=mock_request_log_crud, request_log_id=mock_request_log_id, validator_log_crud=mock_validator_log_crud, @@ -397,6 +488,7 @@ def test_profanity_free_exception_redacts_input(): ): response = _validate_with_guard( payload=_build_payload(unsafe_input), + data=unsafe_input, request_log_crud=mock_request_log_crud, request_log_id=mock_request_log_id, validator_log_crud=mock_validator_log_crud, @@ -404,3 +496,60 @@ def test_profanity_free_exception_redacts_input(): assert response.success is False assert unsafe_input not in response.error + + +def test_resolve_validator_configs_rejects_topic_relevance_config_used_for_answer_relevance(): + """Passing an answer_relevance_custom_llm config ID to the topic_relevance validator + must raise a 400 — validator_name mismatch.""" + config_id = str(uuid4()) + payload = GuardrailRequest( + request_id=str(uuid4()), + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input="test", + validators=[ + {"type": "topic_relevance", "topic_relevance_config_id": config_id} + ], + ) + mock_session = MagicMock() + + with patch("app.api.routes.guardrails.llm_prompt_config_crud.get") as mock_get: + mock_get.return_value = MagicMock( + id=config_id, + validator_name=LLMValidatorName.AnswerRelevanceCustomLLM, + llm_prompt="Q: {query}\nA: {answer}\nYES or NO.", + prompt_schema_version=1, + ) + with pytest.raises(HTTPException) as exc_info: + _resolve_validator_configs(payload, mock_session) + + assert exc_info.value.status_code == 400 + assert "topic_relevance" in exc_info.value.detail + + +def test_resolve_validator_configs_rejects_topic_relevance_config_used_for_answer_relevance_prompt(): + """Passing a topic_relevance config ID to the answer_relevance_custom_llm validator + must raise a 400 — validator_name mismatch.""" + config_id = str(uuid4()) + payload = GuardrailRequest( + request_id=str(uuid4()), + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + input="{}", + validators=[ + {"type": "answer_relevance_custom_llm", "custom_prompt_id": config_id} + ], + ) + mock_session = MagicMock() + + with patch("app.api.routes.guardrails.llm_prompt_config_crud.get") as mock_get: + mock_get.return_value = MagicMock( + id=config_id, + validator_name=LLMValidatorName.TopicRelevance, + llm_prompt="A plain scope description.", + ) + with pytest.raises(HTTPException) as exc_info: + _resolve_validator_configs(payload, mock_session) + + assert exc_info.value.status_code == 400 + assert "answer_relevance_custom_llm" in exc_info.value.detail diff --git a/backend/app/tests/validators/test_answer_relevance_custom_llm.py b/backend/app/tests/validators/test_answer_relevance_custom_llm.py new file mode 100644 index 0000000..9427003 --- /dev/null +++ b/backend/app/tests/validators/test_answer_relevance_custom_llm.py @@ -0,0 +1,237 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest +from guardrails.validators import FailResult, PassResult + +from app.core.validators.answer_relevance_custom_llm import ( + DEFAULT_PROMPT_TEMPLATE, + AnswerRelevanceCustomLLM, +) + +VALID_INPUT = json.dumps( + {"query": "What causes fever?", "answer": "Infections cause fever."} +) +VALID_INPUT_YES = VALID_INPUT +VALID_INPUT_NO = json.dumps( + {"query": "What causes fever?", "answer": "The sky is blue."} +) + + +def _make_llm_response(text: str): + choice = MagicMock() + choice.message.content = text + result = MagicMock() + result.choices = [choice] + return result + + +@pytest.fixture +def validator(): + return AnswerRelevanceCustomLLM() + + +# --------------------------------------------------------------------------- +# Default prompt template shape +# --------------------------------------------------------------------------- + + +def test_default_prompt_template_has_query_placeholder(): + assert "{query}" in DEFAULT_PROMPT_TEMPLATE + + +def test_default_prompt_template_has_answer_placeholder(): + assert "{answer}" in DEFAULT_PROMPT_TEMPLATE + + +# --------------------------------------------------------------------------- +# PassResult on YES +# --------------------------------------------------------------------------- + + +def test_passes_when_llm_returns_yes(validator): + with patch( + "app.core.validators.answer_relevance_custom_llm.completion" + ) as mock_llm: + mock_llm.return_value = _make_llm_response("YES") + result = validator._validate(VALID_INPUT_YES) + + assert isinstance(result, PassResult) + + +def test_passes_when_llm_returns_yes_lowercase(validator): + with patch( + "app.core.validators.answer_relevance_custom_llm.completion" + ) as mock_llm: + mock_llm.return_value = _make_llm_response("yes") + result = validator._validate(VALID_INPUT_YES) + + assert isinstance(result, PassResult) + + +def test_passes_when_llm_returns_yes_with_trailing_text(validator): + with patch( + "app.core.validators.answer_relevance_custom_llm.completion" + ) as mock_llm: + mock_llm.return_value = _make_llm_response("YES.") + result = validator._validate(VALID_INPUT_YES) + + assert isinstance(result, PassResult) + + +# --------------------------------------------------------------------------- +# FailResult on NO +# --------------------------------------------------------------------------- + + +def test_fails_when_llm_returns_no(validator): + with patch( + "app.core.validators.answer_relevance_custom_llm.completion" + ) as mock_llm: + mock_llm.return_value = _make_llm_response("NO") + result = validator._validate(VALID_INPUT_NO) + + assert isinstance(result, FailResult) + assert "not relevant" in result.error_message + + +def test_fails_when_llm_returns_no_lowercase(validator): + with patch( + "app.core.validators.answer_relevance_custom_llm.completion" + ) as mock_llm: + mock_llm.return_value = _make_llm_response("no") + result = validator._validate(VALID_INPUT_NO) + + assert isinstance(result, FailResult) + + +# --------------------------------------------------------------------------- +# Input parsing errors +# --------------------------------------------------------------------------- + + +def test_fails_with_non_json_input(validator): + result = validator._validate("this is not json") + + assert isinstance(result, FailResult) + assert "JSON" in result.error_message + + +def test_fails_with_empty_query(validator): + value = json.dumps({"query": "", "answer": "Some answer."}) + result = validator._validate(value) + + assert isinstance(result, FailResult) + assert "non-empty" in result.error_message + + +def test_fails_with_whitespace_only_query(validator): + value = json.dumps({"query": " ", "answer": "Some answer."}) + result = validator._validate(value) + + assert isinstance(result, FailResult) + + +def test_fails_with_empty_answer(validator): + value = json.dumps({"query": "What is fever?", "answer": ""}) + result = validator._validate(value) + + assert isinstance(result, FailResult) + assert "non-empty" in result.error_message + + +def test_fails_with_missing_query_key(validator): + value = json.dumps({"answer": "Some answer."}) + result = validator._validate(value) + + assert isinstance(result, FailResult) + + +def test_fails_with_missing_answer_key(validator): + value = json.dumps({"query": "What is fever?"}) + result = validator._validate(value) + + assert isinstance(result, FailResult) + + +# --------------------------------------------------------------------------- +# Custom prompt template +# --------------------------------------------------------------------------- + + +def test_custom_prompt_template_is_used(): + custom_template = "Q: {query}\nA: {answer}\nRelevant? YES or NO." + validator = AnswerRelevanceCustomLLM(prompt_template=custom_template) + + with patch( + "app.core.validators.answer_relevance_custom_llm.completion" + ) as mock_llm: + mock_llm.return_value = _make_llm_response("YES") + validator._validate(VALID_INPUT_YES) + + call_args = mock_llm.call_args + prompt_sent = call_args.kwargs["messages"][0]["content"] + + assert "Q: What causes fever?" in prompt_sent + assert "A: Infections cause fever." in prompt_sent + + +def test_custom_prompt_with_unknown_placeholder_returns_fail_result(): + # str.format() raises KeyError for *unknown* keys, not for missing {answer}/{query}. + bad_template = "Query: {query} Answer: {answer} Extra: {unknown_field}" + validator = AnswerRelevanceCustomLLM(prompt_template=bad_template) + + result = validator._validate(VALID_INPUT_YES) + + assert isinstance(result, FailResult) + assert "placeholder" in result.error_message + + +# --------------------------------------------------------------------------- +# LLM call failure +# --------------------------------------------------------------------------- + + +def test_fails_when_llm_raises(validator): + with patch( + "app.core.validators.answer_relevance_custom_llm.completion" + ) as mock_llm: + mock_llm.side_effect = Exception("network error") + result = validator._validate(VALID_INPUT_YES) + + assert isinstance(result, FailResult) + assert "LLM call failed" in result.error_message + + +# --------------------------------------------------------------------------- +# Unexpected LLM response +# --------------------------------------------------------------------------- + + +def test_fails_on_unexpected_llm_response(validator): + with patch( + "app.core.validators.answer_relevance_custom_llm.completion" + ) as mock_llm: + mock_llm.return_value = _make_llm_response("MAYBE") + result = validator._validate(VALID_INPUT_YES) + + assert isinstance(result, FailResult) + assert "Unexpected" in result.error_message + + +# --------------------------------------------------------------------------- +# llm_callable is forwarded +# --------------------------------------------------------------------------- + + +def test_llm_callable_is_forwarded(): + validator = AnswerRelevanceCustomLLM(llm_callable="gpt-4o") + + with patch( + "app.core.validators.answer_relevance_custom_llm.completion" + ) as mock_llm: + mock_llm.return_value = _make_llm_response("YES") + validator._validate(VALID_INPUT_YES) + + call_args = mock_llm.call_args + assert call_args.kwargs["model"] == "gpt-4o"