From 9b40427a28b4fa1a62f6a7cd3f7d4aabdfbb0a4d Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 19 May 2026 02:39:20 +0530 Subject: [PATCH 01/10] Refactor model configuration validation and add new model configurations for STT/TTS providers --- .../061_seed_stt_tts_model_configs.py | 83 +++++++++++ backend/app/crud/config/config.py | 3 + backend/app/crud/config/version.py | 3 + backend/app/crud/model_config.py | 141 +++++++++++++++++- backend/app/models/llm/constants.py | 41 ----- backend/app/models/llm/request.py | 42 +----- backend/app/models/model_config.py | 4 +- backend/app/tests/crud/test_model_config.py | 131 ++++++++++++++++ backend/app/tests/models/llm/test_request.py | 66 +------- 9 files changed, 361 insertions(+), 153 deletions(-) create mode 100644 backend/app/alembic/versions/061_seed_stt_tts_model_configs.py diff --git a/backend/app/alembic/versions/061_seed_stt_tts_model_configs.py b/backend/app/alembic/versions/061_seed_stt_tts_model_configs.py new file mode 100644 index 000000000..2100662a1 --- /dev/null +++ b/backend/app/alembic/versions/061_seed_stt_tts_model_configs.py @@ -0,0 +1,83 @@ +"""seed stt/tts model_config rows for google, sarvamai, elevenlabs + +Revision ID: 061 +Revises: 060 +Create Date: 2026-05-19 00:00:00.000000 + +""" + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "061" +down_revision = "060" +branch_labels = None +depends_on = None + + +SEEDED_MODELS = [ + ("google", "gemini-2.5-pro"), + ("google", "gemini-3.1-pro-preview"), + ("google", "gemini-3-flash-preview"), + ("google", "gemini-2.5-flash"), + ("google", "gemini-2.5-flash-preview-tts"), + ("google", "gemini-2.5-pro-preview-tts"), + ("sarvamai", "saaras:v3"), + ("sarvamai", "bulbul:v3"), + ("elevenlabs", "scribe_v2"), + ("elevenlabs", "eleven_v3"), +] + + +def upgrade(): + # Re-align identity sequence to MAX(id) so new rows get contiguous ids + # even if dev/test DBs drifted from manual inserts/deletes. + op.execute( + "SELECT setval(pg_get_serial_sequence('global.model_config', 'id'), " + "(SELECT COALESCE(MAX(id), 1) FROM global.model_config))" + ) + + op.execute( + """ + INSERT INTO global.model_config + (provider, model_name, config, input_modalities, output_modalities, pricing, is_active, inserted_at, updated_at) + VALUES + ('google', 'gemini-2.5-pro', '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), + ('google', 'gemini-3.1-pro-preview', '{"thinking_level": {"type": "enum", "default": "high", "options": ["low", "medium", "high"], "description": "Max reasoning depth before output. high = best quality, low = faster/cheaper."}}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), + ('google', 'gemini-3-flash-preview', '{"thinking_level": {"type": "enum", "default": "high", "options": ["minimal", "low", "medium", "high"], "description": "Max reasoning depth before output."}}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), + ('google', 'gemini-2.5-flash', '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), + ('google', 'gemini-2.5-flash-preview-tts', '{"voice": {"type": "enum", "default": "Kore", "options": ["Kore", "Orus", "Leda", "Charon"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()), + ('google', 'gemini-2.5-pro-preview-tts', '{"voice": {"type": "enum", "default": "Kore", "options": ["Kore", "Orus", "Leda", "Charon"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()), + ('sarvamai', 'saaras:v3', '{}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), + ('sarvamai', 'bulbul:v3', '{"voice": {"type": "enum", "default": "simran", "options": ["simran", "shubh", "roopa"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()), + ('elevenlabs', 'scribe_v2', '{}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), + ('elevenlabs', 'eleven_v3', '{"voice": {"type": "enum", "default": "Sarah", "options": ["Sarah", "George", "Callum", "Liam"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()) + ON CONFLICT (provider, model_name) DO NOTHING + """ + ) + + # Keep sequence in sync after insert + op.execute( + "SELECT setval(pg_get_serial_sequence('global.model_config', 'id'), " + "(SELECT MAX(id) FROM global.model_config))" + ) + + +def downgrade(): + op.execute( + """ + DELETE FROM global.model_config + WHERE (provider, model_name) IN ( + ('google', 'gemini-2.5-pro'), + ('google', 'gemini-3.1-pro-preview'), + ('google', 'gemini-3-flash-preview'), + ('google', 'gemini-2.5-flash'), + ('google', 'gemini-2.5-flash-preview-tts'), + ('google', 'gemini-2.5-pro-preview-tts'), + ('sarvamai', 'saaras:v3'), + ('sarvamai', 'bulbul:v3'), + ('elevenlabs', 'scribe_v2'), + ('elevenlabs', 'eleven_v3') + ) + """ + ) diff --git a/backend/app/crud/config/config.py b/backend/app/crud/config/config.py index 12a1a60fe..fdfd055f0 100644 --- a/backend/app/crud/config/config.py +++ b/backend/app/crud/config/config.py @@ -5,6 +5,7 @@ from sqlmodel import Session, and_, select from app.core.util import now +from app.crud.model_config import validate_blob_model_or_raise from app.models import ( Config, ConfigCreate, @@ -33,6 +34,8 @@ def create_or_raise( """ self._check_unique_name_or_raise(config_create.name) + validate_blob_model_or_raise(self.session, config_create.config_blob) + try: config = Config( name=config_create.name, diff --git a/backend/app/crud/config/version.py b/backend/app/crud/config/version.py index 5812e5ed0..b3483611c 100644 --- a/backend/app/crud/config/version.py +++ b/backend/app/crud/config/version.py @@ -8,6 +8,7 @@ from sqlmodel import Session, and_, select from app.core.util import now +from app.crud.model_config import validate_blob_model_or_raise from app.models import ( Config, ConfigVersion, @@ -81,6 +82,8 @@ def create_or_raise(self, version_create: ConfigVersionUpdate) -> ConfigVersion: ) raise HTTPException(status_code=400, detail=validation_errors) + validate_blob_model_or_raise(self.session, validated_blob) + try: next_version = self._get_next_version(self.config_id) diff --git a/backend/app/crud/model_config.py b/backend/app/crud/model_config.py index 6d535240a..f4ec0daf0 100644 --- a/backend/app/crud/model_config.py +++ b/backend/app/crud/model_config.py @@ -1,13 +1,24 @@ from typing import Any, Literal +from fastapi import HTTPException +from sqlalchemy.dialects.postgresql import ARRAY +from sqlalchemy.sql import sqltypes from sqlmodel import Session, select from app.models import ModelConfig +Provider = Literal["openai", "google", "sarvamai", "elevenlabs"] +CompletionType = Literal["text", "stt", "tts"] + + +def _normalize_provider(raw: str) -> str: + """Map NativeCompletionConfig providers (e.g. 'openai-native') to model_config provider names.""" + return raw[: -len("-native")] if raw.endswith("-native") else raw + def list_active_model_configs( session: Session, - provider: Literal["openai", "google"] | None = None, + provider: Provider | None = None, skip: int = 0, limit: int = 100, ) -> tuple[list[ModelConfig], bool]: @@ -30,7 +41,7 @@ def list_active_model_configs( def list_all_active_model_configs( session: Session, - provider: Literal["openai", "google"] | None = None, + provider: Provider | None = None, ) -> list[ModelConfig]: statement = select(ModelConfig).where(ModelConfig.is_active) @@ -42,7 +53,7 @@ def list_all_active_model_configs( def get_model_config( - session: Session, provider: Literal["openai", "google"], model_name: str + session: Session, provider: Provider, model_name: str ) -> ModelConfig | None: statement = select(ModelConfig).where( ModelConfig.provider == provider, @@ -52,9 +63,127 @@ def get_model_config( return session.exec(statement).first() -def is_reasoning_model( - session: Session, provider: Literal["openai", "google"], model_name: str +def _modality_filter(stmt, completion_type: CompletionType): + """Restrict query to models matching the completion type via modalities.""" + str_array = ARRAY(sqltypes.String) + input_col = ModelConfig.input_modalities + output_col = ModelConfig.output_modalities + + if completion_type == "stt": + return stmt.where( + input_col.cast(str_array).contains(["AUDIO"]), + output_col.cast(str_array).contains(["TEXT"]), + ) + if completion_type == "tts": + return stmt.where( + input_col.cast(str_array).contains(["TEXT"]), + output_col.cast(str_array).contains(["AUDIO"]), + ) + # text: must produce TEXT and not consume/produce AUDIO + return stmt.where( + output_col.cast(str_array).contains(["TEXT"]), + ~input_col.cast(str_array).contains(["AUDIO"]), + ~output_col.cast(str_array).contains(["AUDIO"]), + ) + + +def list_supported_models( + session: Session, provider: Provider, completion_type: CompletionType +) -> list[str]: + """Return active model names for a provider+completion type.""" + stmt = select(ModelConfig.model_name).where( + ModelConfig.provider == provider, + ModelConfig.is_active, + ) + stmt = _modality_filter(stmt, completion_type) + return list(session.exec(stmt).all()) + + +def is_model_supported( + session: Session, + provider: Provider, + completion_type: CompletionType, + model_name: str, ) -> bool: + """Check whether (provider, model_name) is active and matches the completion type.""" + stmt = select(ModelConfig.id).where( + ModelConfig.provider == provider, + ModelConfig.model_name == model_name, + ModelConfig.is_active, + ) + stmt = _modality_filter(stmt, completion_type) + return session.exec(stmt).first() is not None + + +def validate_blob_model_or_raise(session: Session, blob: Any) -> None: + """Reject ConfigBlob whose completion.params.model is not in model_config. + + Native configs forward raw provider params; we still expect a `model` key + in params for text/stt/tts. Missing model is treated as a validation error. + """ + completion = blob.completion + raw_provider = completion.provider + completion_type = completion.type + if raw_provider is None: + return + + if raw_provider.endswith("-native"): + return + + provider = _normalize_provider(raw_provider) + model_name = (completion.params or {}).get("model") + if not model_name: + raise HTTPException( + status_code=400, + detail=f"completion.params.model is required for provider='{raw_provider}'", + ) + + model_row = get_model_config( + session=session, + provider=provider, # type: ignore[arg-type] + model_name=model_name, + ) + if model_row is None or not is_model_supported( + session=session, + provider=provider, # type: ignore[arg-type] + completion_type=completion_type, + model_name=model_name, + ): + allowed = list_supported_models( + session=session, + provider=provider, # type: ignore[arg-type] + completion_type=completion_type, + ) + raise HTTPException( + status_code=400, + detail=( + f"Model '{model_name}' is not supported for provider='{provider}' " + f"type='{completion_type}'. Allowed: {allowed}" + ), + ) + + # TTS voice check: voice must match options declared in model_config.config.voice + if completion_type == "tts": + voice = (completion.params or {}).get("voice") + voice_spec = ( + model_row.config.get("voice") + if isinstance(model_row.config, dict) + else None + ) + allowed_voices = ( + voice_spec.get("options") if isinstance(voice_spec, dict) else None + ) + if voice and allowed_voices and voice not in allowed_voices: + raise HTTPException( + status_code=400, + detail=( + f"Voice '{voice}' is not supported for provider='{provider}' " + f"model='{model_name}'. Allowed: {allowed_voices}" + ), + ) + + +def is_reasoning_model(session: Session, provider: Provider, model_name: str) -> bool: """Return True if the model is configured with a reasoning `effort` control. A model is considered reasoning-capable if its `config` JSON contains an @@ -69,7 +198,7 @@ def is_reasoning_model( def estimate_model_cost( session: Session, - provider: Literal["openai", "google"], + provider: Provider, model_name: str, input_tokens: int, output_tokens: int, diff --git a/backend/app/models/llm/constants.py b/backend/app/models/llm/constants.py index 399748843..1838da79d 100644 --- a/backend/app/models/llm/constants.py +++ b/backend/app/models/llm/constants.py @@ -2,47 +2,6 @@ DEFAULT_TTS_MODEL = "gemini-2.5-flash-preview-tts" DEFAULT_TTS_VOICE = "Kore" -SUPPORTED_MODELS = { - ("google", "stt"): [ - DEFAULT_STT_MODEL, - "gemini-3.1-pro-preview", - "gemini-3-flash-preview", - "gemini-2.5-flash", - ], - ("google", "tts"): [DEFAULT_TTS_MODEL, "gemini-2.5-pro-preview-tts"], - ("sarvamai", "stt"): ["saaras:v3"], - ("sarvamai", "tts"): ["bulbul:v3"], - ("elevenlabs", "stt"): ["scribe_v2"], - ("elevenlabs", "tts"): ["eleven_v3"], - ("openai", "text"): [ - "gpt-4o", - "gpt-4o-mini", - "gpt-4.1", - "gpt-4.1-mini", - "gpt-4.1-nano", - "gpt-5.4", - "gpt-5.1", - "gpt-5-mini", - "gpt-5-nano", - "o1", - "o1-preview", - "o1-mini", - "gpt-5.4-pro", - "gpt-5.4-mini", - "gpt-5.4-nano", - "gpt-5", - "gpt-4-turbo", - "gpt-4", - "gpt-3.5-turbo", - ], -} - -SUPPORTED_VOICES = { - ("google", "tts"): ["Kore", "Orus", "Leda", "Charon"], - ("sarvamai", "tts"): ["simran", "shubh", "roopa"], - ("elevenlabs", "tts"): ["Sarah", "George", "Callum", "Liam"], -} - # BCP-47 to language tag -> Gemini ISO 639-1 code (Indic + English) BCP47_LOCALE_TO_GEMINI_LANG: dict[str, str] = { "en-IN": "en", diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index a5ceedfe1..da0c18120 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -13,8 +13,6 @@ DEFAULT_STT_MODEL, DEFAULT_TTS_MODEL, DEFAULT_TTS_VOICE, - SUPPORTED_MODELS, - SUPPORTED_VOICES, ) @@ -272,50 +270,12 @@ def validate_params(self): } model_class = param_models[self.type] - provider = self.provider - provider_was_auto_assigned = False - if self.type in ("stt", "tts") and provider is None: + if self.type in ("stt", "tts") and self.provider is None: self.provider = "google" - provider = self.provider - provider_was_auto_assigned = True user_provided_temperature = "temperature" in self.params validated = model_class.model_validate(self.params) - if provider is not None: - key = (provider, self.type) - - allowed_models = SUPPORTED_MODELS.get(key) - if allowed_models and validated.model not in allowed_models: - if provider_was_auto_assigned: - raise ValueError( - f"Model '{validated.model}' is not supported. " - f"Provider was auto-defaulted to '{provider}' (for type='{self.type}'), which requires models: {allowed_models}. " - f"Either specify a supported model or explicitly set 'provider' to match your model." - ) - else: - raise ValueError( - f"Model '{validated.model}' is not supported for provider='{provider}' type='{self.type}'. " - f"Allowed: {allowed_models}" - ) - - if self.type == "tts": - # voice = self.params.get("voice") - voice = validated.voice - allowed_voices = SUPPORTED_VOICES.get(key) - if allowed_voices and voice and voice not in allowed_voices: - if provider_was_auto_assigned: - raise ValueError( - f"Voice '{voice}' is not supported. " - f"Provider was auto-defaulted to '{provider}' (for type='{self.type}'), which requires voices: {allowed_voices}. " - f"Either specify a supported voice or explicitly set 'provider' to match your voice." - ) - else: - raise ValueError( - f"Voice '{voice}' is not supported for provider='{provider}'. " - f"Allowed: {allowed_voices}" - ) - self.params = validated.model_dump(exclude_none=True) if not user_provided_temperature: self.params.pop("temperature", None) diff --git a/backend/app/models/model_config.py b/backend/app/models/model_config.py index bc1d14115..f469fafff 100644 --- a/backend/app/models/model_config.py +++ b/backend/app/models/model_config.py @@ -9,12 +9,12 @@ class ModelConfigBase(SQLModel): - provider: Literal["openai", "google"] = Field( + provider: Literal["openai", "google", "sarvamai", "elevenlabs"] = Field( default="openai", sa_column=sa.Column( sa.String, nullable=False, - comment="provider name (e.g. openai, google)", + comment="provider name (e.g. openai, google, sarvamai, elevenlabs)", ), ) diff --git a/backend/app/tests/crud/test_model_config.py b/backend/app/tests/crud/test_model_config.py index be606f296..f16f523e6 100644 --- a/backend/app/tests/crud/test_model_config.py +++ b/backend/app/tests/crud/test_model_config.py @@ -2,6 +2,7 @@ from typing import Any import pytest +from fastapi import HTTPException from app.crud import model_config as model_config_crud @@ -160,3 +161,133 @@ def test_estimate_model_cost_returns_none_for_non_numeric_prices( ) assert result is None + + +def _make_blob(provider, completion_type, params): + completion = SimpleNamespace(provider=provider, type=completion_type, params=params) + return SimpleNamespace(completion=completion) + + +def _patch_validators( + monkeypatch: pytest.MonkeyPatch, + *, + row: Any | None, + supported: bool, + allowed: list[str] | None = None, +) -> None: + monkeypatch.setattr( + model_config_crud, + "get_model_config", + lambda session, provider, model_name: row, + ) + monkeypatch.setattr( + model_config_crud, + "is_model_supported", + lambda session, provider, completion_type, model_name: supported, + ) + monkeypatch.setattr( + model_config_crud, + "list_supported_models", + lambda session, provider, completion_type: allowed or [], + ) + + +def test_validate_blob_native_provider_short_circuits( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Native pass-through never hits DB.""" + called = {"hit": False} + + def boom(*a, **kw): + called["hit"] = True + return None + + monkeypatch.setattr(model_config_crud, "get_model_config", boom) + monkeypatch.setattr(model_config_crud, "is_model_supported", boom) + + blob = _make_blob("openai-native", "text", {"model": "anything"}) + model_config_crud.validate_blob_model_or_raise(session=None, blob=blob) # type: ignore[arg-type] + + assert called["hit"] is False + + +def test_validate_blob_none_provider_skips(monkeypatch: pytest.MonkeyPatch) -> None: + blob = _make_blob(None, "text", {"model": "gpt-4o"}) + # No patches — should never reach helpers + model_config_crud.validate_blob_model_or_raise(session=None, blob=blob) # type: ignore[arg-type] + + +def test_validate_blob_missing_model_raises(monkeypatch: pytest.MonkeyPatch) -> None: + blob = _make_blob("openai", "text", {"temperature": 0.5}) + with pytest.raises(HTTPException) as exc: + model_config_crud.validate_blob_model_or_raise(session=None, blob=blob) # type: ignore[arg-type] + assert exc.value.status_code == 400 + assert "model is required" in exc.value.detail + + +def test_validate_blob_unsupported_model_raises( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_validators( + monkeypatch, + row=None, + supported=False, + allowed=["gpt-4o", "gpt-4o-mini"], + ) + blob = _make_blob("openai", "text", {"model": "gpt-4-turbo"}) + with pytest.raises(HTTPException) as exc: + model_config_crud.validate_blob_model_or_raise(session=None, blob=blob) # type: ignore[arg-type] + assert exc.value.status_code == 400 + assert "gpt-4-turbo" in exc.value.detail + assert "gpt-4o" in exc.value.detail + + +def test_validate_blob_supported_text_passes(monkeypatch: pytest.MonkeyPatch) -> None: + row = SimpleNamespace(config={}) + _patch_validators(monkeypatch, row=row, supported=True) + blob = _make_blob("openai", "text", {"model": "gpt-4o"}) + model_config_crud.validate_blob_model_or_raise(session=None, blob=blob) # type: ignore[arg-type] + + +def test_validate_blob_tts_invalid_voice_raises( + monkeypatch: pytest.MonkeyPatch, +) -> None: + row = SimpleNamespace( + config={"voice": {"type": "enum", "options": ["Kore", "Orus"]}} + ) + _patch_validators(monkeypatch, row=row, supported=True) + blob = _make_blob( + "google", + "tts", + {"model": "gemini-2.5-flash-preview-tts", "voice": "Sarah"}, + ) + with pytest.raises(HTTPException) as exc: + model_config_crud.validate_blob_model_or_raise(session=None, blob=blob) # type: ignore[arg-type] + assert exc.value.status_code == 400 + assert "Sarah" in exc.value.detail + assert "Kore" in exc.value.detail + + +def test_validate_blob_tts_valid_voice_passes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + row = SimpleNamespace( + config={"voice": {"type": "enum", "options": ["Kore", "Orus"]}} + ) + _patch_validators(monkeypatch, row=row, supported=True) + blob = _make_blob( + "google", + "tts", + {"model": "gemini-2.5-flash-preview-tts", "voice": "Kore"}, + ) + model_config_crud.validate_blob_model_or_raise(session=None, blob=blob) # type: ignore[arg-type] + + +def test_validate_blob_tts_no_voice_spec_passes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """If model_config row has no voice schema, voice value is not enforced.""" + row = SimpleNamespace(config={}) + _patch_validators(monkeypatch, row=row, supported=True) + blob = _make_blob("sarvamai", "tts", {"model": "bulbul:v3", "voice": "anything"}) + model_config_crud.validate_blob_model_or_raise(session=None, blob=blob) # type: ignore[arg-type] diff --git a/backend/app/tests/models/llm/test_request.py b/backend/app/tests/models/llm/test_request.py index 3d40f607a..a77551667 100644 --- a/backend/app/tests/models/llm/test_request.py +++ b/backend/app/tests/models/llm/test_request.py @@ -1,6 +1,3 @@ -import pytest -from pydantic import ValidationError - from app.models.llm.request import KaapiCompletionConfig @@ -49,63 +46,6 @@ def test_temperature_zero_preserved_when_explicitly_set(self) -> None: assert config.params["temperature"] == 0.0 -class TestNewSupportedModels: - """Test that newly added models are accepted for openai/text provider.""" - - @pytest.mark.parametrize( - "model", - [ - "gpt-5.4-pro", - "gpt-5.4-mini", - "gpt-5.4-nano", - "gpt-5", - "gpt-4-turbo", - "gpt-4", - "gpt-3.5-turbo", - ], - ) - def test_new_model_accepted(self, model: str) -> None: - """New models should be accepted for openai text provider.""" - config = KaapiCompletionConfig( - provider="openai", - type="text", - params={"model": model}, - ) - - assert config.params["model"] == model - - @pytest.mark.parametrize( - "model", - [ - "gpt-4o", - "gpt-4o-mini", - "gpt-4.1", - "gpt-4.1-mini", - "gpt-4.1-nano", - "gpt-5.4", - "gpt-5.1", - "gpt-5-mini", - "gpt-5-nano", - "o1", - "o1-preview", - "o1-mini", - ], - ) - def test_existing_models_still_accepted(self, model: str) -> None: - """Previously supported models should still be accepted.""" - config = KaapiCompletionConfig( - provider="openai", - type="text", - params={"model": model}, - ) - - assert config.params["model"] == model - - def test_unsupported_model_rejected(self) -> None: - """An unsupported model should raise a validation error.""" - with pytest.raises(ValidationError, match="not supported"): - KaapiCompletionConfig( - provider="openai", - type="text", - params={"model": "unsupported-model-xyz"}, - ) +# Model-allowlist enforcement moved from KaapiCompletionConfig.validate_params to +# the CRUD layer (crud.model_config.validate_blob_model_or_raise) which consults +# the model_config table. See tests/crud/config/* for coverage. From d7e7169cd6ca8eb8a0e17f2e2e95e3ed4702c800 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 19 May 2026 15:44:20 +0530 Subject: [PATCH 02/10] Remove redundant tests for changing config types in version creation --- .../tests/api/routes/configs/test_version.py | 211 ------------------ 1 file changed, 211 deletions(-) diff --git a/backend/app/tests/api/routes/configs/test_version.py b/backend/app/tests/api/routes/configs/test_version.py index 2bdbc0025..a56c7e4d7 100644 --- a/backend/app/tests/api/routes/configs/test_version.py +++ b/backend/app/tests/api/routes/configs/test_version.py @@ -571,107 +571,6 @@ def test_create_version_cannot_change_type_from_text_to_stt( assert "stt" in error_detail -def test_create_version_cannot_change_type_from_stt_to_tts( - db: Session, - client: TestClient, - user_api_key: TestAuthContext, -) -> None: - """Test that config type cannot be changed from 'stt' to 'tts' in a new version.""" - from app.models.llm.request import KaapiCompletionConfig - - # Create initial config with type='stt' - config_blob = ConfigBlob( - completion=KaapiCompletionConfig( - provider="openai", - type="stt", - params={ - "model": "whisper-1", - "instructions": "Transcribe audio", - "temperature": 0.2, - }, - ) - ) - config = create_test_config( - db=db, - project_id=user_api_key.project_id, - name="stt-config", - config_blob=config_blob, - ) - - # Try to create a new version with type='tts' - version_data = { - "config_blob": { - "completion": { - "provider": "openai", - "type": "tts", - "params": { - "model": "tts-1", - "voice": "alloy", - "language": "en", - }, - } - }, - "commit_message": "Attempting to change type to tts", - } - - response = client.post( - f"{settings.API_V1_STR}/configs/{config.id}/versions", - headers={"X-API-KEY": user_api_key.key}, - json=version_data, - ) - assert response.status_code == 400 - - -def test_create_version_cannot_change_type_from_tts_to_text( - db: Session, - client: TestClient, - user_api_key: TestAuthContext, -) -> None: - """Test that config type cannot be changed from 'tts' to 'text' in a new version.""" - from app.models.llm.request import KaapiCompletionConfig - - # Create initial config with type='tts' - config_blob = ConfigBlob( - completion=KaapiCompletionConfig( - provider="openai", - type="tts", - params={ - "model": "tts-1", - "voice": "alloy", - "language": "en", - }, - ) - ) - config = create_test_config( - db=db, - project_id=user_api_key.project_id, - name="tts-config", - config_blob=config_blob, - ) - - # Try to create a new version with type='text' - version_data = { - "config_blob": { - "completion": { - "provider": "openai", - "type": "text", - "params": { - "model": "gpt-4o", - "temperature": 0.7, - }, - } - }, - "commit_message": "Attempting to change type to text", - } - - response = client.post( - f"{settings.API_V1_STR}/configs/{config.id}/versions", - headers={"X-API-KEY": user_api_key.key}, - json=version_data, - ) - assert response.status_code == 400 - - def test_create_version_same_type_succeeds( db: Session, client: TestClient, @@ -818,113 +717,3 @@ def test_create_config_with_kaapi_provider_success( assert data["data"]["version"]["config_blob"]["completion"]["type"] == "text" -def test_create_version_with_kaapi_stt_provider_success( - db: Session, - client: TestClient, - user_api_key: TestAuthContext, -) -> None: - """Test creating STT config and version with Kaapi provider works correctly.""" - from app.models.llm.request import KaapiCompletionConfig - - # Create initial STT config with Kaapi provider - config_blob = ConfigBlob( - completion=KaapiCompletionConfig( - provider="openai", - type="stt", - params={ - "model": "whisper-1", - "instructions": "Transcribe audio accurately", - "temperature": 0.2, - }, - ) - ) - config = create_test_config( - db=db, - project_id=user_api_key.project_id, - name="kaapi-stt-config", - config_blob=config_blob, - ) - - # Create a new version with the same type='stt' - version_data = { - "config_blob": { - "completion": { - "provider": "openai", - "type": "stt", - "params": { - "model": "whisper-1", - "instructions": "Transcribe with high accuracy", - "temperature": 0.1, - }, - } - }, - "commit_message": "Updated STT instructions", - } - - response = client.post( - f"{settings.API_V1_STR}/configs/{config.id}/versions", - headers={"X-API-KEY": user_api_key.key}, - json=version_data, - ) - assert response.status_code == 201 - data = response.json() - assert data["success"] is True - assert data["data"]["version"] == 2 - assert data["data"]["config_blob"]["completion"]["provider"] == "openai" - assert data["data"]["config_blob"]["completion"]["type"] == "stt" - - -def test_create_version_with_kaapi_tts_provider_success( - db: Session, - client: TestClient, - user_api_key: TestAuthContext, -) -> None: - """Test creating TTS config and version with Kaapi provider works correctly.""" - from app.models.llm.request import KaapiCompletionConfig - - # Create initial TTS config with Kaapi provider - config_blob = ConfigBlob( - completion=KaapiCompletionConfig( - provider="openai", - type="tts", - params={ - "model": "tts-1", - "voice": "alloy", - "language": "en", - }, - ) - ) - config = create_test_config( - db=db, - project_id=user_api_key.project_id, - name="kaapi-tts-config", - config_blob=config_blob, - ) - - # Create a new version with the same type='tts' - version_data = { - "config_blob": { - "completion": { - "provider": "openai", - "type": "tts", - "params": { - "model": "tts-1-hd", - "voice": "nova", - "language": "en", - }, - } - }, - "commit_message": "Updated TTS to HD model with nova voice", - } - - response = client.post( - f"{settings.API_V1_STR}/configs/{config.id}/versions", - headers={"X-API-KEY": user_api_key.key}, - json=version_data, - ) - assert response.status_code == 201 - data = response.json() - assert data["success"] is True - assert data["data"]["version"] == 2 - assert data["data"]["config_blob"]["completion"]["provider"] == "openai" - assert data["data"]["config_blob"]["completion"]["type"] == "tts" From 06f94efe69e5e4e77c14bff9f23d311fabb99ea5 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 19 May 2026 15:56:20 +0530 Subject: [PATCH 03/10] Add migration to seed STT/TTS model configurations for Google, Sarvamai, and ElevenLabs --- .../versions/062_add_pending_job_monitoring_indexes.py | 4 ++-- ...model_configs.py => 063_seed_stt_tts_model_configs.py} | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) rename backend/app/alembic/versions/{061_seed_stt_tts_model_configs.py => 063_seed_stt_tts_model_configs.py} (98%) diff --git a/backend/app/alembic/versions/062_add_pending_job_monitoring_indexes.py b/backend/app/alembic/versions/062_add_pending_job_monitoring_indexes.py index 846f0646e..3632e5062 100644 --- a/backend/app/alembic/versions/062_add_pending_job_monitoring_indexes.py +++ b/backend/app/alembic/versions/062_add_pending_job_monitoring_indexes.py @@ -1,7 +1,7 @@ """add pending job monitoring indexes -Revision ID: 061 -Revises: 060 +Revision ID: 062 +Revises: 061 Create Date: 2026-05-13 00:00:00.000000 """ diff --git a/backend/app/alembic/versions/061_seed_stt_tts_model_configs.py b/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py similarity index 98% rename from backend/app/alembic/versions/061_seed_stt_tts_model_configs.py rename to backend/app/alembic/versions/063_seed_stt_tts_model_configs.py index 2100662a1..db35d453e 100644 --- a/backend/app/alembic/versions/061_seed_stt_tts_model_configs.py +++ b/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py @@ -1,7 +1,7 @@ """seed stt/tts model_config rows for google, sarvamai, elevenlabs -Revision ID: 061 -Revises: 060 +Revision ID: 063 +Revises: 062 Create Date: 2026-05-19 00:00:00.000000 """ @@ -9,8 +9,8 @@ from alembic import op # revision identifiers, used by Alembic. -revision = "061" -down_revision = "060" +revision = "063" +down_revision = "062" branch_labels = None depends_on = None From 3bc791a91b348f19ed8476446367b5c6c8a7bb63 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 19 May 2026 16:57:44 +0530 Subject: [PATCH 04/10] Remove unnecessary blank lines in test_create_config_with_kaapi_provider_success --- backend/app/tests/api/routes/configs/test_version.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/backend/app/tests/api/routes/configs/test_version.py b/backend/app/tests/api/routes/configs/test_version.py index a56c7e4d7..d7650af24 100644 --- a/backend/app/tests/api/routes/configs/test_version.py +++ b/backend/app/tests/api/routes/configs/test_version.py @@ -715,5 +715,3 @@ def test_create_config_with_kaapi_provider_success( assert data["data"]["name"] == config_data["name"] assert data["data"]["version"]["config_blob"]["completion"]["provider"] == "openai" assert data["data"]["version"]["config_blob"]["completion"]["type"] == "text" - - From 4be5b8caecb1a828ee4311eb2bf80ddf2418a076 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Wed, 20 May 2026 16:47:40 +0530 Subject: [PATCH 05/10] Remove commented-out code regarding model-allowlist enforcement from test_request.py --- backend/app/tests/models/llm/test_request.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/backend/app/tests/models/llm/test_request.py b/backend/app/tests/models/llm/test_request.py index a77551667..39f732de3 100644 --- a/backend/app/tests/models/llm/test_request.py +++ b/backend/app/tests/models/llm/test_request.py @@ -44,8 +44,3 @@ def test_temperature_zero_preserved_when_explicitly_set(self) -> None: assert "temperature" in config.params assert config.params["temperature"] == 0.0 - - -# Model-allowlist enforcement moved from KaapiCompletionConfig.validate_params to -# the CRUD layer (crud.model_config.validate_blob_model_or_raise) which consults -# the model_config table. See tests/crud/config/* for coverage. From 45e6b6a6c617f2d8f567a222d87c10aec1ec984d Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 21 May 2026 20:21:21 +0530 Subject: [PATCH 06/10] Refactor model validation logic and enhance tests for config type restrictions --- .../063_seed_stt_tts_model_configs.py | 13 -- backend/app/crud/model_config.py | 19 +- backend/app/services/llm/jobs.py | 24 ++- .../tests/api/routes/configs/test_version.py | 186 ++++++++++++++++++ backend/app/tests/crud/test_model_config.py | 54 ++++- 5 files changed, 272 insertions(+), 24 deletions(-) diff --git a/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py b/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py index db35d453e..a4d12cd92 100644 --- a/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py +++ b/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py @@ -30,13 +30,6 @@ def upgrade(): - # Re-align identity sequence to MAX(id) so new rows get contiguous ids - # even if dev/test DBs drifted from manual inserts/deletes. - op.execute( - "SELECT setval(pg_get_serial_sequence('global.model_config', 'id'), " - "(SELECT COALESCE(MAX(id), 1) FROM global.model_config))" - ) - op.execute( """ INSERT INTO global.model_config @@ -56,12 +49,6 @@ def upgrade(): """ ) - # Keep sequence in sync after insert - op.execute( - "SELECT setval(pg_get_serial_sequence('global.model_config', 'id'), " - "(SELECT MAX(id) FROM global.model_config))" - ) - def downgrade(): op.execute( diff --git a/backend/app/crud/model_config.py b/backend/app/crud/model_config.py index f4ec0daf0..c99dc231d 100644 --- a/backend/app/crud/model_config.py +++ b/backend/app/crud/model_config.py @@ -6,6 +6,7 @@ from sqlmodel import Session, select from app.models import ModelConfig +from app.models.llm.request import ConfigBlob Provider = Literal["openai", "google", "sarvamai", "elevenlabs"] CompletionType = Literal["text", "stt", "tts"] @@ -63,7 +64,7 @@ def get_model_config( return session.exec(statement).first() -def _modality_filter(stmt, completion_type: CompletionType): +def _modality_filter(stmt: Any, completion_type: CompletionType) -> Any: """Restrict query to models matching the completion type via modalities.""" str_array = ARRAY(sqltypes.String) input_col = ModelConfig.input_modalities @@ -115,11 +116,11 @@ def is_model_supported( return session.exec(stmt).first() is not None -def validate_blob_model_or_raise(session: Session, blob: Any) -> None: +def validate_blob_model_or_raise(session: Session, blob: ConfigBlob) -> None: """Reject ConfigBlob whose completion.params.model is not in model_config. - Native configs forward raw provider params; we still expect a `model` key - in params for text/stt/tts. Missing model is treated as a validation error. + model_config is the source of truth — all providers/types validated. + Native configs are exempt (they forward raw params to the provider). """ completion = blob.completion raw_provider = completion.provider @@ -131,6 +132,7 @@ def validate_blob_model_or_raise(session: Session, blob: Any) -> None: return provider = _normalize_provider(raw_provider) + model_name = (completion.params or {}).get("model") if not model_name: raise HTTPException( @@ -143,7 +145,13 @@ def validate_blob_model_or_raise(session: Session, blob: Any) -> None: provider=provider, # type: ignore[arg-type] model_name=model_name, ) - if model_row is None or not is_model_supported( + if model_row is None: + raise HTTPException( + status_code=400, + detail=f"Model '{model_name}' not found for provider='{provider}'.", + ) + + if not is_model_supported( session=session, provider=provider, # type: ignore[arg-type] completion_type=completion_type, @@ -162,7 +170,6 @@ def validate_blob_model_or_raise(session: Session, blob: Any) -> None: ), ) - # TTS voice check: voice must match options declared in model_config.config.voice if completion_type == "tts": voice = (completion.params or {}).get("voice") voice_spec = ( diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index f818ea489..550d7ff41 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -27,6 +27,7 @@ ) from app.crud.config import ConfigVersionCrud from app.crud.credentials import get_provider_credential +from app.crud.model_config import validate_blob_model_or_raise from app.crud.jobs import JobCrud from app.crud.llm import ( create_llm_call, @@ -131,6 +132,9 @@ def start_job( db: Session, request: LLMCallRequest, project_id: int, organization_id: int ) -> UUID: """Create an LLM job and schedule Celery task.""" + if not request.config.is_stored_config and request.config.blob: + validate_blob_model_or_raise(db, request.config.blob) + with log_context( tag="llm-call", lifecycle="llm.call.start_job", @@ -187,6 +191,10 @@ def start_chain_job( db: Session, request: LLMChainRequest, project_id: int, organization_id: int ) -> UUID: """Create an LLM Chain job and schedule Celery task.""" + for block in request.blocks: + if not block.config.is_stored_config and block.config.blob: + validate_blob_model_or_raise(db, block.config.blob) + trace_id = correlation_id.get() or "N/A" job_crud = JobCrud(session=db) job = job_crud.create( @@ -338,7 +346,7 @@ def resolve_config_blob( return None, "Unexpected error occurred while retrieving stored configuration" try: - return ConfigBlob(**config_version.config_blob), None + blob = ConfigBlob(**config_version.config_blob) except (TypeError, ValueError) as e: return None, f"Stored configuration blob is invalid: {str(e)}" except Exception: @@ -349,6 +357,13 @@ def resolve_config_blob( ) return None, "Unexpected error occurred while parsing stored configuration" + try: + validate_blob_model_or_raise(config_crud.session, blob) + except HTTPException as e: + return None, e.detail + + return blob, None + def apply_input_guardrails( *, @@ -523,6 +538,13 @@ def execute_llm_call( return BlockResult(error=error) else: config_blob = config.blob + try: + validate_blob_model_or_raise(session, config_blob) + except HTTPException as e: + cfg_span.set_status( + trace.Status(trace.StatusCode.ERROR, e.detail) + ) + return BlockResult(error=e.detail) original_input_value = ( query.input.content.value diff --git a/backend/app/tests/api/routes/configs/test_version.py b/backend/app/tests/api/routes/configs/test_version.py index d7650af24..77b2e8e62 100644 --- a/backend/app/tests/api/routes/configs/test_version.py +++ b/backend/app/tests/api/routes/configs/test_version.py @@ -682,6 +682,192 @@ def test_create_version_partial_update_params_only( assert config_blob_result["completion"]["params"]["temperature"] == 0.9 +def test_create_version_cannot_change_type_from_stt_to_tts( + db: Session, + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test that config type cannot be changed from 'stt' to 'tts' in a new version.""" + from app.models.llm.request import KaapiCompletionConfig + + config_blob = ConfigBlob( + completion=KaapiCompletionConfig( + provider="google", + type="stt", + params={"model": "gemini-2.5-pro"}, + ) + ) + config = create_test_config( + db=db, + project_id=user_api_key.project_id, + name="google-stt-config", + config_blob=config_blob, + ) + + version_data = { + "config_blob": { + "completion": { + "provider": "google", + "type": "tts", + "params": {"model": "gemini-2.5-flash-preview-tts", "voice": "Kore"}, + } + }, + "commit_message": "Attempting to change type from stt to tts", + } + + response = client.post( + f"{settings.API_V1_STR}/configs/{config.id}/versions", + headers={"X-API-KEY": user_api_key.key}, + json=version_data, + ) + assert response.status_code == 400 + error_detail = response.json().get("error", "") + assert "cannot change config type" in error_detail.lower() + assert "stt" in error_detail + assert "tts" in error_detail + + +def test_create_version_cannot_change_type_from_tts_to_text( + db: Session, + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test that config type cannot be changed from 'tts' to 'text' in a new version.""" + from app.models.llm.request import KaapiCompletionConfig + + config_blob = ConfigBlob( + completion=KaapiCompletionConfig( + provider="google", + type="tts", + params={"model": "gemini-2.5-flash-preview-tts", "voice": "Kore"}, + ) + ) + config = create_test_config( + db=db, + project_id=user_api_key.project_id, + name="google-tts-config", + config_blob=config_blob, + ) + + version_data = { + "config_blob": { + "completion": { + "provider": "openai", + "type": "text", + "params": {"model": "gpt-4o", "temperature": 0.7}, + } + }, + "commit_message": "Attempting to change type from tts to text", + } + + response = client.post( + f"{settings.API_V1_STR}/configs/{config.id}/versions", + headers={"X-API-KEY": user_api_key.key}, + json=version_data, + ) + assert response.status_code == 400 + error_detail = response.json().get("error", "") + assert "cannot change config type" in error_detail.lower() + assert "tts" in error_detail + assert "text" in error_detail + + +def test_create_version_with_kaapi_stt_provider_success( + db: Session, + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test creating a new STT version with tweaked params succeeds.""" + from app.models.llm.request import KaapiCompletionConfig + + config_blob = ConfigBlob( + completion=KaapiCompletionConfig( + provider="google", + type="stt", + params={"model": "gemini-2.5-pro"}, + ) + ) + config = create_test_config( + db=db, + project_id=user_api_key.project_id, + name="google-stt-version-config", + config_blob=config_blob, + ) + + version_data = { + "config_blob": { + "completion": { + "provider": "google", + "type": "stt", + "params": {"model": "gemini-2.5-pro", "temperature": 0.3}, + } + }, + "commit_message": "Tweak temperature for STT", + } + + response = client.post( + f"{settings.API_V1_STR}/configs/{config.id}/versions", + headers={"X-API-KEY": user_api_key.key}, + json=version_data, + ) + assert response.status_code == 201 + data = response.json() + assert data["success"] is True + assert data["data"]["version"] == 2 + assert data["data"]["config_blob"]["completion"]["type"] == "stt" + assert data["data"]["config_blob"]["completion"]["provider"] == "google" + + +def test_create_version_with_kaapi_tts_provider_success( + db: Session, + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test creating a new TTS version switching model and voice succeeds.""" + from app.models.llm.request import KaapiCompletionConfig + + config_blob = ConfigBlob( + completion=KaapiCompletionConfig( + provider="google", + type="tts", + params={"model": "gemini-2.5-flash-preview-tts", "voice": "Kore"}, + ) + ) + config = create_test_config( + db=db, + project_id=user_api_key.project_id, + name="google-tts-version-config", + config_blob=config_blob, + ) + + version_data = { + "config_blob": { + "completion": { + "provider": "google", + "type": "tts", + "params": {"model": "gemini-2.5-pro-preview-tts", "voice": "Orus"}, + } + }, + "commit_message": "Switch to pro TTS model with Orus voice", + } + + response = client.post( + f"{settings.API_V1_STR}/configs/{config.id}/versions", + headers={"X-API-KEY": user_api_key.key}, + json=version_data, + ) + assert response.status_code == 201 + data = response.json() + assert data["success"] is True + assert data["data"]["version"] == 2 + assert data["data"]["config_blob"]["completion"]["type"] == "tts" + assert ( + data["data"]["config_blob"]["completion"]["params"]["model"] + == "gemini-2.5-pro-preview-tts" + ) + assert data["data"]["config_blob"]["completion"]["params"]["voice"] == "Orus" + + def test_create_config_with_kaapi_provider_success( db: Session, client: TestClient, diff --git a/backend/app/tests/crud/test_model_config.py b/backend/app/tests/crud/test_model_config.py index f16f523e6..e94a313a8 100644 --- a/backend/app/tests/crud/test_model_config.py +++ b/backend/app/tests/crud/test_model_config.py @@ -225,20 +225,34 @@ def test_validate_blob_missing_model_raises(monkeypatch: pytest.MonkeyPatch) -> assert "model is required" in exc.value.detail -def test_validate_blob_unsupported_model_raises( +def test_validate_blob_model_not_found_raises( monkeypatch: pytest.MonkeyPatch, ) -> None: + """Model that doesn't exist in model_config raises 400 with model name in detail.""" + _patch_validators(monkeypatch, row=None, supported=False) + blob = _make_blob("openai", "text", {"model": "gpt-4-turbo"}) + with pytest.raises(HTTPException) as exc: + model_config_crud.validate_blob_model_or_raise(session=None, blob=blob) # type: ignore[arg-type] + assert exc.value.status_code == 400 + assert "gpt-4-turbo" in exc.value.detail + + +def test_validate_blob_wrong_type_for_model_raises( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Model that exists but is wrong type (e.g. TTS model used as text) raises 400 with allowed list.""" + row = SimpleNamespace(config={}) _patch_validators( monkeypatch, - row=None, + row=row, supported=False, allowed=["gpt-4o", "gpt-4o-mini"], ) - blob = _make_blob("openai", "text", {"model": "gpt-4-turbo"}) + blob = _make_blob("openai", "text", {"model": "some-audio-model"}) with pytest.raises(HTTPException) as exc: model_config_crud.validate_blob_model_or_raise(session=None, blob=blob) # type: ignore[arg-type] assert exc.value.status_code == 400 - assert "gpt-4-turbo" in exc.value.detail + assert "some-audio-model" in exc.value.detail assert "gpt-4o" in exc.value.detail @@ -291,3 +305,35 @@ def test_validate_blob_tts_no_voice_spec_passes( _patch_validators(monkeypatch, row=row, supported=True) blob = _make_blob("sarvamai", "tts", {"model": "bulbul:v3", "voice": "anything"}) model_config_crud.validate_blob_model_or_raise(session=None, blob=blob) # type: ignore[arg-type] + + +def test_validate_blob_stt_model_rejected_for_text_type( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """STT-only model (audio input) must be rejected when type=text. + + Regression: previously only stt/tts triggered is_model_supported; type=text + only checked model existence, so gemini-2.5-pro (STT) passed as a text model. + """ + row = SimpleNamespace(config={}) + _patch_validators( + monkeypatch, + row=row, + supported=False, # modality filter excludes AUDIO-input models for type=text + allowed=["gpt-4o", "gpt-4o-mini"], + ) + blob = _make_blob("google", "text", {"model": "gemini-2.5-pro"}) + with pytest.raises(HTTPException) as exc: + model_config_crud.validate_blob_model_or_raise(session=None, blob=blob) # type: ignore[arg-type] + assert exc.value.status_code == 400 + assert "gemini-2.5-pro" in exc.value.detail + + +def test_validate_blob_text_model_accepted_for_text_type( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Valid text model passes for type=text.""" + row = SimpleNamespace(config={}) + _patch_validators(monkeypatch, row=row, supported=True) + blob = _make_blob("openai", "text", {"model": "gpt-4o"}) + model_config_crud.validate_blob_model_or_raise(session=None, blob=blob) # type: ignore[arg-type] From 5fa40d22e5f6bb4195a0f72af3ebaa9790f0d210 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 21 May 2026 20:34:46 +0530 Subject: [PATCH 07/10] Update STT/TTS model configurations with detailed pricing and input/output costs --- .../063_seed_stt_tts_model_configs.py | 45 ++++++++++++++----- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py b/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py index a4d12cd92..b1af6f17c 100644 --- a/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py +++ b/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py @@ -30,21 +30,46 @@ def upgrade(): + # Pricing per 1M tokens (USD). response/batch = text i/o; audio = audio-modal i/o. op.execute( """ INSERT INTO global.model_config (provider, model_name, config, input_modalities, output_modalities, pricing, is_active, inserted_at, updated_at) VALUES - ('google', 'gemini-2.5-pro', '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), - ('google', 'gemini-3.1-pro-preview', '{"thinking_level": {"type": "enum", "default": "high", "options": ["low", "medium", "high"], "description": "Max reasoning depth before output. high = best quality, low = faster/cheaper."}}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), - ('google', 'gemini-3-flash-preview', '{"thinking_level": {"type": "enum", "default": "high", "options": ["minimal", "low", "medium", "high"], "description": "Max reasoning depth before output."}}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), - ('google', 'gemini-2.5-flash', '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), - ('google', 'gemini-2.5-flash-preview-tts', '{"voice": {"type": "enum", "default": "Kore", "options": ["Kore", "Orus", "Leda", "Charon"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()), - ('google', 'gemini-2.5-pro-preview-tts', '{"voice": {"type": "enum", "default": "Kore", "options": ["Kore", "Orus", "Leda", "Charon"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()), - ('sarvamai', 'saaras:v3', '{}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), - ('sarvamai', 'bulbul:v3', '{"voice": {"type": "enum", "default": "simran", "options": ["simran", "shubh", "roopa"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()), - ('elevenlabs', 'scribe_v2', '{}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), - ('elevenlabs', 'eleven_v3', '{"voice": {"type": "enum", "default": "Sarah", "options": ["Sarah", "George", "Callum", "Liam"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()) + ('google', 'gemini-2.5-pro', + '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', + '{AUDIO}', '{TEXT}', + '{"response": {"input_token_cost": 1.25, "output_token_cost": 10.0}, "batch": {"input_token_cost": 0.625, "output_token_cost": 5.0}, "audio": {"input_token_cost": 3.5, "output_token_cost": 10.0}}', + true, NOW(), NOW()), + ('google', 'gemini-3.1-pro-preview', + '{"thinking_level": {"type": "enum", "default": "high", "options": ["low", "medium", "high"], "description": "Max reasoning depth before output. high = best quality, low = faster/cheaper."}}', + '{AUDIO}', '{TEXT}', + '{"response": {"input_token_cost": 2.0, "output_token_cost": 12.0}, "batch": {"input_token_cost": 1.0, "output_token_cost": 6.0}, "audio": {"input_token_cost": 3.5, "output_token_cost": 12.0}}', + true, NOW(), NOW()), + ('google', 'gemini-3-flash-preview', + '{"thinking_level": {"type": "enum", "default": "high", "options": ["minimal", "low", "medium", "high"], "description": "Max reasoning depth before output."}}', + '{AUDIO}', '{TEXT}', + '{"response": {"input_token_cost": 0.5, "output_token_cost": 3.0}, "batch": {"input_token_cost": 0.25, "output_token_cost": 1.5}, "audio": {"input_token_cost": 1.0, "output_token_cost": 3.0}}', + true, NOW(), NOW()), + ('google', 'gemini-2.5-flash', + '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', + '{AUDIO}', '{TEXT}', + '{"response": {"input_token_cost": 0.3, "output_token_cost": 2.5}, "batch": {"input_token_cost": 0.15, "output_token_cost": 1.25}, "audio": {"input_token_cost": 1.0, "output_token_cost": 2.5}}', + true, NOW(), NOW()), + ('google', 'gemini-2.5-flash-preview-tts', + '{"voice": {"type": "enum", "default": "Kore", "options": ["Kore", "Orus", "Leda", "Charon"], "description": "TTS voice."}}', + '{TEXT}', '{AUDIO}', + '{"response": {"input_token_cost": 0.5, "output_token_cost": 10.0}, "batch": {"input_token_cost": 0.25, "output_token_cost": 5.0}, "audio": {"input_token_cost": 0.5, "output_token_cost": 10.0}}', + true, NOW(), NOW()), + ('google', 'gemini-2.5-pro-preview-tts', + '{"voice": {"type": "enum", "default": "Kore", "options": ["Kore", "Orus", "Leda", "Charon"], "description": "TTS voice."}}', + '{TEXT}', '{AUDIO}', + '{"response": {"input_token_cost": 1.0, "output_token_cost": 20.0}, "batch": {"input_token_cost": 0.5, "output_token_cost": 10.0}, "audio": {"input_token_cost": 1.0, "output_token_cost": 20.0}}', + true, NOW(), NOW()), + ('sarvamai', 'saaras:v3', '{}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), + ('sarvamai', 'bulbul:v3', '{"voice": {"type": "enum", "default": "simran", "options": ["simran", "shubh", "roopa"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()), + ('elevenlabs', 'scribe_v2', '{}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), + ('elevenlabs', 'eleven_v3', '{"voice": {"type": "enum", "default": "Sarah", "options": ["Sarah", "George", "Callum", "Liam"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()) ON CONFLICT (provider, model_name) DO NOTHING """ ) From 48d56ef3c16bce7a18fd5a3cf26112a56ea17fa2 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 21 May 2026 21:01:07 +0530 Subject: [PATCH 08/10] Enhance model configuration by adding completion type enum and updating database schema for STT/TTS support --- .../063_seed_stt_tts_model_configs.py | 92 ++++++++++++++++--- backend/app/crud/model_config.py | 34 +------ backend/app/models/model_config.py | 20 +++- 3 files changed, 102 insertions(+), 44 deletions(-) diff --git a/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py b/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py index b1af6f17c..64fdac8e2 100644 --- a/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py +++ b/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py @@ -30,46 +30,98 @@ def upgrade(): - # Pricing per 1M tokens (USD). response/batch = text i/o; audio = audio-modal i/o. + # 1. Create enum types + op.execute("CREATE TYPE global.provider_enum AS ENUM ('openai', 'google', 'sarvamai', 'elevenlabs')") + op.execute("CREATE TYPE global.completion_type_enum AS ENUM ('text', 'stt', 'tts')") + + # 2. Alter provider column to use enum; add completion_type column + op.execute( + """ + ALTER TABLE global.model_config + ALTER COLUMN provider TYPE global.provider_enum + USING provider::global.provider_enum, + ADD COLUMN completion_type global.completion_type_enum + """ + ) + + # 3. Backfill completion_type for pre-existing rows (openai models seeded before this migration) + op.execute( + """ + UPDATE global.model_config SET completion_type = + CASE + WHEN 'AUDIO' = ANY(input_modalities::text[]) AND NOT ('AUDIO' = ANY(output_modalities::text[])) THEN 'stt'::global.completion_type_enum + WHEN 'AUDIO' = ANY(output_modalities::text[]) AND NOT ('AUDIO' = ANY(input_modalities::text[])) THEN 'tts'::global.completion_type_enum + ELSE 'text'::global.completion_type_enum + END + WHERE completion_type IS NULL + """ + ) + + # 4. Set NOT NULL now that all rows are backfilled + op.execute( + "ALTER TABLE global.model_config ALTER COLUMN completion_type SET NOT NULL" + ) + + # 5. Add indexes + op.execute( + "CREATE INDEX ix_model_config_provider_active ON global.model_config (provider, is_active)" + ) + op.execute( + "CREATE INDEX ix_model_config_provider_type_active ON global.model_config (provider, completion_type, is_active)" + ) + op.execute( + "CREATE INDEX ix_model_config_input_modalities ON global.model_config USING gin (input_modalities)" + ) + op.execute( + "CREATE INDEX ix_model_config_output_modalities ON global.model_config USING gin (output_modalities)" + ) + + # 6. Seed rows — pricing per 1M tokens (USD): response/batch = text i/o; audio = audio-modal i/o op.execute( """ INSERT INTO global.model_config - (provider, model_name, config, input_modalities, output_modalities, pricing, is_active, inserted_at, updated_at) + (provider, model_name, completion_type, config, input_modalities, output_modalities, pricing, is_active, inserted_at, updated_at) VALUES - ('google', 'gemini-2.5-pro', + ('google', 'gemini-2.5-pro', 'stt', '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', '{AUDIO}', '{TEXT}', '{"response": {"input_token_cost": 1.25, "output_token_cost": 10.0}, "batch": {"input_token_cost": 0.625, "output_token_cost": 5.0}, "audio": {"input_token_cost": 3.5, "output_token_cost": 10.0}}', true, NOW(), NOW()), - ('google', 'gemini-3.1-pro-preview', + ('google', 'gemini-3.1-pro-preview', 'stt', '{"thinking_level": {"type": "enum", "default": "high", "options": ["low", "medium", "high"], "description": "Max reasoning depth before output. high = best quality, low = faster/cheaper."}}', '{AUDIO}', '{TEXT}', '{"response": {"input_token_cost": 2.0, "output_token_cost": 12.0}, "batch": {"input_token_cost": 1.0, "output_token_cost": 6.0}, "audio": {"input_token_cost": 3.5, "output_token_cost": 12.0}}', true, NOW(), NOW()), - ('google', 'gemini-3-flash-preview', + ('google', 'gemini-3-flash-preview', 'stt', '{"thinking_level": {"type": "enum", "default": "high", "options": ["minimal", "low", "medium", "high"], "description": "Max reasoning depth before output."}}', '{AUDIO}', '{TEXT}', '{"response": {"input_token_cost": 0.5, "output_token_cost": 3.0}, "batch": {"input_token_cost": 0.25, "output_token_cost": 1.5}, "audio": {"input_token_cost": 1.0, "output_token_cost": 3.0}}', true, NOW(), NOW()), - ('google', 'gemini-2.5-flash', + ('google', 'gemini-2.5-flash', 'stt', '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', '{AUDIO}', '{TEXT}', '{"response": {"input_token_cost": 0.3, "output_token_cost": 2.5}, "batch": {"input_token_cost": 0.15, "output_token_cost": 1.25}, "audio": {"input_token_cost": 1.0, "output_token_cost": 2.5}}', true, NOW(), NOW()), - ('google', 'gemini-2.5-flash-preview-tts', + ('google', 'gemini-2.5-flash-preview-tts', 'tts', '{"voice": {"type": "enum", "default": "Kore", "options": ["Kore", "Orus", "Leda", "Charon"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', '{"response": {"input_token_cost": 0.5, "output_token_cost": 10.0}, "batch": {"input_token_cost": 0.25, "output_token_cost": 5.0}, "audio": {"input_token_cost": 0.5, "output_token_cost": 10.0}}', true, NOW(), NOW()), - ('google', 'gemini-2.5-pro-preview-tts', + ('google', 'gemini-2.5-pro-preview-tts', 'tts', '{"voice": {"type": "enum", "default": "Kore", "options": ["Kore", "Orus", "Leda", "Charon"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', '{"response": {"input_token_cost": 1.0, "output_token_cost": 20.0}, "batch": {"input_token_cost": 0.5, "output_token_cost": 10.0}, "audio": {"input_token_cost": 1.0, "output_token_cost": 20.0}}', true, NOW(), NOW()), - ('sarvamai', 'saaras:v3', '{}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), - ('sarvamai', 'bulbul:v3', '{"voice": {"type": "enum", "default": "simran", "options": ["simran", "shubh", "roopa"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()), - ('elevenlabs', 'scribe_v2', '{}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), - ('elevenlabs', 'eleven_v3', '{"voice": {"type": "enum", "default": "Sarah", "options": ["Sarah", "George", "Callum", "Liam"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()) + ('sarvamai', 'saaras:v3', 'stt', + '{}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), + ('sarvamai', 'bulbul:v3', 'tts', + '{"voice": {"type": "enum", "default": "simran", "options": ["simran", "shubh", "roopa"], "description": "TTS voice."}}', + '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()), + ('elevenlabs', 'scribe_v2', 'stt', + '{}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), + ('elevenlabs', 'eleven_v3', 'tts', + '{"voice": {"type": "enum", "default": "Sarah", "options": ["Sarah", "George", "Callum", "Liam"], "description": "TTS voice."}}', + '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()) ON CONFLICT (provider, model_name) DO NOTHING """ ) @@ -93,3 +145,19 @@ def downgrade(): ) """ ) + + op.execute("DROP INDEX IF EXISTS global.ix_model_config_output_modalities") + op.execute("DROP INDEX IF EXISTS global.ix_model_config_input_modalities") + op.execute("DROP INDEX IF EXISTS global.ix_model_config_provider_type_active") + op.execute("DROP INDEX IF EXISTS global.ix_model_config_provider_active") + + op.execute( + """ + ALTER TABLE global.model_config + DROP COLUMN completion_type, + ALTER COLUMN provider TYPE varchar USING provider::varchar + """ + ) + + op.execute("DROP TYPE IF EXISTS global.completion_type_enum") + op.execute("DROP TYPE IF EXISTS global.provider_enum") diff --git a/backend/app/crud/model_config.py b/backend/app/crud/model_config.py index c99dc231d..9c627f7f4 100644 --- a/backend/app/crud/model_config.py +++ b/backend/app/crud/model_config.py @@ -1,15 +1,13 @@ from typing import Any, Literal from fastapi import HTTPException -from sqlalchemy.dialects.postgresql import ARRAY -from sqlalchemy.sql import sqltypes from sqlmodel import Session, select from app.models import ModelConfig from app.models.llm.request import ConfigBlob +from app.models.model_config import CompletionType Provider = Literal["openai", "google", "sarvamai", "elevenlabs"] -CompletionType = Literal["text", "stt", "tts"] def _normalize_provider(raw: str) -> str: @@ -64,39 +62,15 @@ def get_model_config( return session.exec(statement).first() -def _modality_filter(stmt: Any, completion_type: CompletionType) -> Any: - """Restrict query to models matching the completion type via modalities.""" - str_array = ARRAY(sqltypes.String) - input_col = ModelConfig.input_modalities - output_col = ModelConfig.output_modalities - - if completion_type == "stt": - return stmt.where( - input_col.cast(str_array).contains(["AUDIO"]), - output_col.cast(str_array).contains(["TEXT"]), - ) - if completion_type == "tts": - return stmt.where( - input_col.cast(str_array).contains(["TEXT"]), - output_col.cast(str_array).contains(["AUDIO"]), - ) - # text: must produce TEXT and not consume/produce AUDIO - return stmt.where( - output_col.cast(str_array).contains(["TEXT"]), - ~input_col.cast(str_array).contains(["AUDIO"]), - ~output_col.cast(str_array).contains(["AUDIO"]), - ) - - def list_supported_models( session: Session, provider: Provider, completion_type: CompletionType ) -> list[str]: - """Return active model names for a provider+completion type.""" + """Return active model names for a provider + completion type.""" stmt = select(ModelConfig.model_name).where( ModelConfig.provider == provider, + ModelConfig.completion_type == completion_type, ModelConfig.is_active, ) - stmt = _modality_filter(stmt, completion_type) return list(session.exec(stmt).all()) @@ -110,9 +84,9 @@ def is_model_supported( stmt = select(ModelConfig.id).where( ModelConfig.provider == provider, ModelConfig.model_name == model_name, + ModelConfig.completion_type == completion_type, ModelConfig.is_active, ) - stmt = _modality_filter(stmt, completion_type) return session.exec(stmt).first() is not None diff --git a/backend/app/models/model_config.py b/backend/app/models/model_config.py index f469fafff..cf48514f9 100644 --- a/backend/app/models/model_config.py +++ b/backend/app/models/model_config.py @@ -7,12 +7,14 @@ from app.core.util import now +CompletionType = Literal["text", "stt", "tts"] + class ModelConfigBase(SQLModel): provider: Literal["openai", "google", "sarvamai", "elevenlabs"] = Field( default="openai", sa_column=sa.Column( - sa.String, + sa.Enum("openai", "google", "sarvamai", "elevenlabs", name="provider_enum", schema="global"), nullable=False, comment="provider name (e.g. openai, google, sarvamai, elevenlabs)", ), @@ -27,6 +29,15 @@ class ModelConfigBase(SQLModel): ), ) + completion_type: CompletionType = Field( + ..., + sa_column=sa.Column( + sa.Enum("text", "stt", "tts", name="completion_type_enum", schema="global"), + nullable=False, + comment="text | stt | tts — drives routing and validation", + ), + ) + config: dict[str, Any] = Field( default_factory=dict, sa_column=sa.Column(JSONB, nullable=False, comment="model adhoc configuration"), @@ -60,7 +71,8 @@ class ModelConfigBase(SQLModel): comment=( "pricing per 1M tokens in USD. " "Structure: {response: {input_token_cost, output_token_cost}, " - "batch: {input_token_cost, output_token_cost}}" + "batch: {input_token_cost, output_token_cost}, " + "audio: {input_token_cost, output_token_cost}}" ), ), ) @@ -80,6 +92,10 @@ class ModelConfig(ModelConfigBase, table=True): __tablename__ = "model_config" __table_args__ = ( sa.UniqueConstraint("provider", "model_name"), + sa.Index("ix_model_config_provider_active", "provider", "is_active"), + sa.Index("ix_model_config_provider_type_active", "provider", "completion_type", "is_active"), + sa.Index("ix_model_config_input_modalities", "input_modalities", postgresql_using="gin"), + sa.Index("ix_model_config_output_modalities", "output_modalities", postgresql_using="gin"), {"schema": "global"}, ) From 944efdde251e1e3285a2d6c475c1f3768f58d17f Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 21 May 2026 21:04:29 +0530 Subject: [PATCH 09/10] Refactor downgrade function to remove obsolete model deletion logic and update type hints in test functions --- .../versions/063_seed_stt_tts_model_configs.py | 18 ------------------ backend/app/tests/crud/test_model_config.py | 6 +++--- 2 files changed, 3 insertions(+), 21 deletions(-) diff --git a/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py b/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py index 64fdac8e2..977670e7d 100644 --- a/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py +++ b/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py @@ -128,24 +128,6 @@ def upgrade(): def downgrade(): - op.execute( - """ - DELETE FROM global.model_config - WHERE (provider, model_name) IN ( - ('google', 'gemini-2.5-pro'), - ('google', 'gemini-3.1-pro-preview'), - ('google', 'gemini-3-flash-preview'), - ('google', 'gemini-2.5-flash'), - ('google', 'gemini-2.5-flash-preview-tts'), - ('google', 'gemini-2.5-pro-preview-tts'), - ('sarvamai', 'saaras:v3'), - ('sarvamai', 'bulbul:v3'), - ('elevenlabs', 'scribe_v2'), - ('elevenlabs', 'eleven_v3') - ) - """ - ) - op.execute("DROP INDEX IF EXISTS global.ix_model_config_output_modalities") op.execute("DROP INDEX IF EXISTS global.ix_model_config_input_modalities") op.execute("DROP INDEX IF EXISTS global.ix_model_config_provider_type_active") diff --git a/backend/app/tests/crud/test_model_config.py b/backend/app/tests/crud/test_model_config.py index e94a313a8..b0ac19d2b 100644 --- a/backend/app/tests/crud/test_model_config.py +++ b/backend/app/tests/crud/test_model_config.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from types import SimpleNamespace from typing import Any @@ -163,7 +164,7 @@ def test_estimate_model_cost_returns_none_for_non_numeric_prices( assert result is None -def _make_blob(provider, completion_type, params): +def _make_blob(provider: str | None, completion_type: str, params: Mapping[str, Any]) -> SimpleNamespace: completion = SimpleNamespace(provider=provider, type=completion_type, params=params) return SimpleNamespace(completion=completion) @@ -198,9 +199,8 @@ def test_validate_blob_native_provider_short_circuits( """Native pass-through never hits DB.""" called = {"hit": False} - def boom(*a, **kw): + def boom(*_args: Any, **_kwargs: Any) -> None: called["hit"] = True - return None monkeypatch.setattr(model_config_crud, "get_model_config", boom) monkeypatch.setattr(model_config_crud, "is_model_supported", boom) From e00df2e6b5987c228e33147796a9259831a814eb Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 21 May 2026 21:06:36 +0530 Subject: [PATCH 10/10] Refactor model configuration and test functions for improved readability and structure --- .../063_seed_stt_tts_model_configs.py | 4 ++- backend/app/models/model_config.py | 28 ++++++++++++++++--- backend/app/tests/crud/test_model_config.py | 4 ++- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py b/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py index 977670e7d..7c6405f3a 100644 --- a/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py +++ b/backend/app/alembic/versions/063_seed_stt_tts_model_configs.py @@ -31,7 +31,9 @@ def upgrade(): # 1. Create enum types - op.execute("CREATE TYPE global.provider_enum AS ENUM ('openai', 'google', 'sarvamai', 'elevenlabs')") + op.execute( + "CREATE TYPE global.provider_enum AS ENUM ('openai', 'google', 'sarvamai', 'elevenlabs')" + ) op.execute("CREATE TYPE global.completion_type_enum AS ENUM ('text', 'stt', 'tts')") # 2. Alter provider column to use enum; add completion_type column diff --git a/backend/app/models/model_config.py b/backend/app/models/model_config.py index cf48514f9..ef284dc5f 100644 --- a/backend/app/models/model_config.py +++ b/backend/app/models/model_config.py @@ -14,7 +14,14 @@ class ModelConfigBase(SQLModel): provider: Literal["openai", "google", "sarvamai", "elevenlabs"] = Field( default="openai", sa_column=sa.Column( - sa.Enum("openai", "google", "sarvamai", "elevenlabs", name="provider_enum", schema="global"), + sa.Enum( + "openai", + "google", + "sarvamai", + "elevenlabs", + name="provider_enum", + schema="global", + ), nullable=False, comment="provider name (e.g. openai, google, sarvamai, elevenlabs)", ), @@ -93,9 +100,22 @@ class ModelConfig(ModelConfigBase, table=True): __table_args__ = ( sa.UniqueConstraint("provider", "model_name"), sa.Index("ix_model_config_provider_active", "provider", "is_active"), - sa.Index("ix_model_config_provider_type_active", "provider", "completion_type", "is_active"), - sa.Index("ix_model_config_input_modalities", "input_modalities", postgresql_using="gin"), - sa.Index("ix_model_config_output_modalities", "output_modalities", postgresql_using="gin"), + sa.Index( + "ix_model_config_provider_type_active", + "provider", + "completion_type", + "is_active", + ), + sa.Index( + "ix_model_config_input_modalities", + "input_modalities", + postgresql_using="gin", + ), + sa.Index( + "ix_model_config_output_modalities", + "output_modalities", + postgresql_using="gin", + ), {"schema": "global"}, ) diff --git a/backend/app/tests/crud/test_model_config.py b/backend/app/tests/crud/test_model_config.py index b0ac19d2b..04857ddf5 100644 --- a/backend/app/tests/crud/test_model_config.py +++ b/backend/app/tests/crud/test_model_config.py @@ -164,7 +164,9 @@ def test_estimate_model_cost_returns_none_for_non_numeric_prices( assert result is None -def _make_blob(provider: str | None, completion_type: str, params: Mapping[str, Any]) -> SimpleNamespace: +def _make_blob( + provider: str | None, completion_type: str, params: Mapping[str, Any] +) -> SimpleNamespace: completion = SimpleNamespace(provider=provider, type=completion_type, params=params) return SimpleNamespace(completion=completion)