From 6ba7c27171c7f3980aee9f87d36dd5453cf9a0ac Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Wed, 20 May 2026 14:54:20 +0530 Subject: [PATCH 1/2] feat: Add evaluation batch submission functionality - Introduced `run_evaluation_batch_submission` task to handle evaluation batch submissions. - Created `start_evaluation_batch_submission` utility to enqueue evaluation tasks. - Updated `start_evaluation` to utilize the new Celery task for batch submissions. - Enhanced error handling in evaluation run creation and processing. - Added support for Google Gemini in evaluation batch processing. - Refactored model configuration to support multiple providers and completion types. - Improved logging for better traceability during evaluation processes. - Updated tests to reflect changes in evaluation batch submission logic. --- backend/app/celery/tasks/job_execution.py | 20 ++ backend/app/celery/utils.py | 18 ++ backend/app/core/batch/gemini.py | 2 +- backend/app/core/batch/operations.py | 3 +- backend/app/crud/evaluations/batch.py | 268 +++++++++++------- backend/app/crud/evaluations/langfuse.py | 16 +- backend/app/crud/evaluations/processing.py | 265 +++++++++++------ backend/app/crud/model_config.py | 75 ++++- backend/app/models/model_config.py | 6 +- backend/app/services/evaluations/batch_job.py | 92 ++++++ .../app/services/evaluations/evaluation.py | 69 ++--- .../app/tests/api/routes/test_evaluation.py | 2 +- 12 files changed, 592 insertions(+), 244 deletions(-) create mode 100644 backend/app/services/evaluations/batch_job.py diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index adadf1c9c..a106b88da 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -172,6 +172,26 @@ def run_delete_collection_job( ) +@celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_evaluation_batch_submission") +def run_evaluation_batch_submission( + self, project_id: int, job_id: str, trace_id: str, **kwargs +): + from app.services.evaluations.batch_job import execute_evaluation_batch_submission + + _set_trace(trace_id) + return _run_with_otel_parent( + self, + lambda: execute_evaluation_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), + ) + + @celery_app.task(bind=True, queue="low_priority", priority=1) @gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_stt_batch_submission") def run_stt_batch_submission( diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 288cba7c4..0337c696d 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -126,6 +126,24 @@ def start_delete_collection_job( return task_id +def start_evaluation_batch_submission( + project_id: int, job_id: str, trace_id: str = "N/A", **kwargs +) -> str: + from app.celery.tasks.job_execution import run_evaluation_batch_submission + + task_id = _enqueue_with_trace_context( + run_evaluation_batch_submission, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, + ) + logger.info( + f"[start_evaluation_batch_submission] Started job {job_id} with Celery task {task_id}" + ) + return task_id + + def start_stt_batch_submission( project_id: int, job_id: str, trace_id: str = "N/A", **kwargs ) -> str: diff --git a/backend/app/core/batch/gemini.py b/backend/app/core/batch/gemini.py index f121bb9ee..4222afbb6 100644 --- a/backend/app/core/batch/gemini.py +++ b/backend/app/core/batch/gemini.py @@ -349,7 +349,7 @@ def upload_file(self, content: str, purpose: str = "batch") -> str: uploaded_file = self._upload_to_gemini( content=content, suffix=".jsonl", - mime_type="jsonl", + mime_type="application/jsonl", display_name=f"batch-input-{int(time.time())}", ) diff --git a/backend/app/core/batch/operations.py b/backend/app/core/batch/operations.py index d02884fdb..bcd51866b 100644 --- a/backend/app/core/batch/operations.py +++ b/backend/app/core/batch/operations.py @@ -74,7 +74,8 @@ def start_batch_job( logger.error(f"[start_batch_job] Failed | {e}", exc_info=True) batch_job_update = BatchJobUpdate( - error_message=f"Batch creation failed: {str(e)}" + provider_status="failed", + error_message=f"Batch creation failed: {str(e)}", ) update_batch_job( session=session, batch_job=batch_job, batch_job_update=batch_job_update diff --git a/backend/app/crud/evaluations/batch.py b/backend/app/crud/evaluations/batch.py index 4181910b8..604cc146a 100644 --- a/backend/app/crud/evaluations/batch.py +++ b/backend/app/crud/evaluations/batch.py @@ -11,15 +11,23 @@ from typing import Any from langfuse import Langfuse -from openai import OpenAI from sqlmodel import Session -from app.core.batch.base import BATCH_KEY - -from app.core.batch import OpenAIBatchProvider, start_batch_job +from app.core.batch import ( + BATCH_KEY, + GeminiBatchProvider, + OpenAIBatchProvider, + start_batch_job, +) +from app.core.batch.client import GeminiClient from app.models import EvaluationRun from app.models.batch_job import BatchJobType -from app.models.llm.request import KaapiLLMParams +from app.services.llm.mappers import ( + map_kaapi_to_google_params, + map_kaapi_to_openai_params, +) +from app.services.llm.providers.registry import LLMProvider +from app.utils import get_openai_client logger = logging.getLogger(__name__) @@ -62,149 +70,202 @@ def fetch_dataset_items(langfuse: Langfuse, dataset_name: str) -> list[dict[str, return items -def build_evaluation_jsonl( - dataset_items: list[dict[str, Any]], config: KaapiLLMParams +def build_openai_evaluation_jsonl( + dataset_items: list[dict[str, Any]], openai_params: dict[str, Any] ) -> list[dict[str, Any]]: + """Build OpenAI Responses API batch JSONL from Langfuse dataset items. + + Each line: + { + BATCH_KEY: , + "method": "POST", + "url": "/v1/responses", + "body": { ...openai_params, "input": } + } """ - Build JSONL data for evaluation batch using OpenAI Responses API. + jsonl_data: list[dict[str, Any]] = [] + for item in dataset_items: + question = item["input"].get("question", "") + if not question: + logger.warning( + f"[build_openai_evaluation_jsonl] Skipping item - no question found | item_id={item['id']}" + ) + continue - Each line is a dict with: - - BATCH_KEY: Unique identifier for the request (dataset item ID) - - method: POST - - url: /v1/responses - - body: Response request using config as-is with input from dataset + body = dict(openai_params) + body["input"] = question - Args: - dataset_items: List of dataset items from Langfuse - config: Evaluation configuration dict with OpenAI Responses API parameters. - This config is used as-is in the body, with only "input" being added - from the dataset. Config can include any fields like: - - model (required) - - instructions - - tools - - reasoning - - text - - temperature - - include - etc. + jsonl_data.append( + { + BATCH_KEY: item["id"], + "method": "POST", + "url": "/v1/responses", + "body": body, + } + ) + return jsonl_data - Returns: - List of dictionaries (JSONL data) + +def build_google_evaluation_jsonl( + dataset_items: list[dict[str, Any]], google_params: dict[str, Any] +) -> list[dict[str, Any]]: + """Build Gemini batch JSONL from Langfuse dataset items. + + Each line: + { + "key": , + "request": { contents, systemInstruction?, generationConfig? } + } """ - jsonl_data = [] + jsonl_data: list[dict[str, Any]] = [] + system_instruction = google_params.get("instructions") + + generation_config: dict[str, Any] = {} + temperature = google_params.get("temperature") + if temperature is not None: + generation_config["temperature"] = temperature + reasoning = google_params.get("reasoning") + if reasoning: + generation_config["thinkingConfig"] = { + "includeThoughts": False, + "thinkingLevel": reasoning, + } + for item in dataset_items: - # Extract question from input question = item["input"].get("question", "") if not question: logger.warning( - f"[build_evaluation_jsonl] Skipping item - no question found | item_id={item['id']}" + f"[build_google_evaluation_jsonl] Skipping item - no question found | item_id={item['id']}" ) continue - # Build the batch request object for Responses API - # Use config as-is and only add the input field - body: dict[str, Any] = { - "model": config.model, - "instructions": config.instructions, - "input": question, # Add input from dataset + request: dict[str, Any] = { + "contents": [{"parts": [{"text": question}], "role": "user"}], } + if system_instruction: + request["systemInstruction"] = {"parts": [{"text": system_instruction}]} + if generation_config: + request["generationConfig"] = generation_config - if "temperature" in config.model_fields_set: - body["temperature"] = config.temperature - - # Add reasoning only if provided - if config.reasoning: - body["reasoning"] = {"effort": config.reasoning} - - # Add tools only if knowledge_base_ids are provided - if config.knowledge_base_ids: - body["tools"] = [ - { - "type": "file_search", - "vector_store_ids": config.knowledge_base_ids, - "max_num_results": config.max_num_results or 20, - } - ] - - batch_request = { - BATCH_KEY: item["id"], - "method": "POST", - "url": "/v1/responses", - "body": body, - } + jsonl_data.append({"key": item["id"], "request": request}) - jsonl_data.append(batch_request) return jsonl_data def start_evaluation_batch( langfuse: Langfuse, - openai_client: OpenAI, session: Session, eval_run: EvaluationRun, - config: KaapiLLMParams, + params: dict[str, Any], + provider: str, ) -> EvaluationRun: """ - Fetch data, build JSONL, and start evaluation batch. - - This function orchestrates the evaluation-specific logic and delegates - to the generic batch infrastructure for actual batch creation. + Fetch dataset, build JSONL, submit batch via the appropriate provider. Args: langfuse: Configured Langfuse client - openai_client: Configured OpenAI client session: Database session - eval_run: EvaluationRun database object (with run_name, dataset_name, config) - config: KaapiLLMParams with model, instructions, knowledge_base_ids, etc. + eval_run: EvaluationRun database object + params: Kaapi-standardized completion params (dict) + provider: Completion provider ("openai" or "google", with optional "-native" suffix) Returns: Updated EvaluationRun with batch_job_id populated - - Raises: - Exception: If any step fails """ try: - # Step 1: Fetch dataset items from Langfuse logger.info( - f"[start_evaluation_batch] Starting evaluation batch | run={eval_run.run_name}" + f"[start_evaluation_batch] Starting evaluation batch | run={eval_run.run_name} | provider={provider}" ) dataset_items = fetch_dataset_items( langfuse=langfuse, dataset_name=eval_run.dataset_name ) - # Step 2: Build evaluation-specific JSONL - jsonl_data = build_evaluation_jsonl(dataset_items=dataset_items, config=config) + base_provider = provider.replace("-native", "") - if not jsonl_data: - raise ValueError( - "Evaluation dataset did not produce any JSONL entries (missing questions?)." + if base_provider == LLMProvider.OPENAI: + mapped_params, warnings = map_kaapi_to_openai_params( + session=session, kaapi_params=params ) + if warnings: + logger.info("[start_evaluation_batch] Mapper warnings: %s", warnings) - # Step 3: Create batch provider - provider = OpenAIBatchProvider(client=openai_client) + jsonl_data = build_openai_evaluation_jsonl( + dataset_items=dataset_items, openai_params=mapped_params + ) + if not jsonl_data: + raise ValueError( + "Evaluation dataset did not produce any JSONL entries (missing questions?)." + ) + + openai_client = get_openai_client( + session=session, + org_id=eval_run.organization_id, + project_id=eval_run.project_id, + ) + batch_provider = OpenAIBatchProvider(client=openai_client) - # Step 4: Prepare batch configuration - batch_config = { - "endpoint": "/v1/responses", - "description": f"Evaluation: {eval_run.run_name}", - "completion_window": "24h", - # Store complete config for reference - "evaluation_config": config.model_dump(exclude_unset=True), - } + batch_config = { + "endpoint": "/v1/responses", + "description": f"Evaluation: {eval_run.run_name}", + "completion_window": "24h", + "evaluation_config": params, + } - # Step 5: Start batch job using generic infrastructure - batch_job = start_batch_job( - session=session, - provider=provider, - provider_name="openai", - job_type=BatchJobType.EVALUATION, - organization_id=eval_run.organization_id, - project_id=eval_run.project_id, - jsonl_data=jsonl_data, - config=batch_config, - ) + batch_job = start_batch_job( + session=session, + provider=batch_provider, + provider_name="openai", + job_type=BatchJobType.EVALUATION, + organization_id=eval_run.organization_id, + project_id=eval_run.project_id, + jsonl_data=jsonl_data, + config=batch_config, + ) + + elif base_provider == LLMProvider.GOOGLE: + mapped_params, warnings = map_kaapi_to_google_params( + kaapi_params=params, completion_type="text" + ) + if warnings: + logger.info("[start_evaluation_batch] Mapper warnings: %s", warnings) + + jsonl_data = build_google_evaluation_jsonl( + dataset_items=dataset_items, google_params=mapped_params + ) + if not jsonl_data: + raise ValueError( + "Evaluation dataset did not produce any JSONL entries (missing questions?)." + ) + + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=eval_run.organization_id, + project_id=eval_run.project_id, + ) + model_name = mapped_params.get("model", "gemini-2.5-pro") + batch_provider = GeminiBatchProvider( + client=gemini_client.client, model=model_name + ) + + batch_config = { + "display_name": f"evaluation-{eval_run.run_name}", + "model": f"models/{model_name}", + } + + batch_job = start_batch_job( + session=session, + provider=batch_provider, + provider_name="google", + job_type=BatchJobType.EVALUATION, + organization_id=eval_run.organization_id, + project_id=eval_run.project_id, + jsonl_data=jsonl_data, + config=batch_config, + ) + + else: + raise ValueError(f"Unsupported provider for evaluation batches: {provider}") - # Step 6: Link batch_job to evaluation_run eval_run.batch_job_id = batch_job.id eval_run.status = "processing" eval_run.total_items = batch_job.total_items @@ -217,7 +278,8 @@ def start_evaluation_batch( f"[start_evaluation_batch] Successfully started evaluation batch | " f"batch_job_id={batch_job.id} | " f"provider_batch_id={batch_job.provider_batch_id} | " - f"run={eval_run.run_name} | items={batch_job.total_items}" + f"run={eval_run.run_name} | items={batch_job.total_items} | " + f"provider={base_provider}" ) return eval_run diff --git a/backend/app/crud/evaluations/langfuse.py b/backend/app/crud/evaluations/langfuse.py index be5da4eb3..c72326829 100644 --- a/backend/app/crud/evaluations/langfuse.py +++ b/backend/app/crud/evaluations/langfuse.py @@ -145,11 +145,17 @@ def create_langfuse_dataset_run( trace_id_mapping[item_id] = trace_id except Exception as e: - logger.error( - f"[create_langfuse_dataset_run] Failed to create trace | " - f"item_id={item_id} | {e}", - exc_info=True, - ) + if getattr(e, "status_code", None) == 429: + logger.error( + f"[create_langfuse_dataset_run] Langfuse rate limit (429) | " + f"item_id={item_id}" + ) + else: + logger.error( + f"[create_langfuse_dataset_run] Failed to create trace | " + f"item_id={item_id} | {e}", + exc_info=True, + ) continue langfuse.flush() diff --git a/backend/app/crud/evaluations/processing.py b/backend/app/crud/evaluations/processing.py index 607e29e49..87671943a 100644 --- a/backend/app/crud/evaluations/processing.py +++ b/backend/app/crud/evaluations/processing.py @@ -9,6 +9,7 @@ """ import ast +import asyncio import json import logging from collections import defaultdict @@ -20,13 +21,17 @@ from sqlmodel import Session, select from app.core.batch import ( + GeminiBatchProvider, OpenAIBatchProvider, download_batch_results, poll_batch_status, upload_batch_results_to_object_store, ) -from app.core.batch.base import BATCH_KEY +from app.core.batch.base import BATCH_KEY, BatchProvider +from app.core.batch.client import GeminiClient +from app.core.batch.gemini import BatchJobState, extract_text_from_response_dict from app.crud.evaluations.batch import fetch_dataset_items +from app.services.llm.providers.registry import LLMProvider from app.crud.evaluations.core import resolve_model_from_config, update_evaluation_run from app.crud.evaluations.cost import attach_cost from app.crud.evaluations.embeddings import ( @@ -47,6 +52,32 @@ logger = logging.getLogger(__name__) +def _get_batch_provider( + session: Session, + provider_name: str, + organization_id: int, + project_id: int, +) -> BatchProvider: + """Get appropriate batch provider instance for the main response batch.""" + if provider_name in (LLMProvider.OPENAI, LLMProvider.OPENAI_NATIVE): + openai_client = get_openai_client( + session=session, + org_id=organization_id, + project_id=project_id, + ) + return OpenAIBatchProvider(client=openai_client) + + if provider_name in (LLMProvider.GOOGLE, LLMProvider.GOOGLE_NATIVE): + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=organization_id, + project_id=project_id, + ) + return GeminiBatchProvider(client=gemini_client.client) + + raise ValueError(f"Unsupported provider for evaluation polling: {provider_name}") + + def _extract_batch_error_message( provider: OpenAIBatchProvider, error_file_id: str, @@ -115,8 +146,35 @@ def _extract_batch_error_message( return error_msg +def _extract_gemini_usage(response: dict[str, Any]) -> dict[str, Any] | None: + """Map Gemini usageMetadata to OpenAI-style usage dict for cost tracking.""" + usage_metadata = response.get("usageMetadata") or response.get("usage_metadata") + if not isinstance(usage_metadata, dict): + return None + input_tokens = usage_metadata.get("promptTokenCount") or usage_metadata.get( + "prompt_token_count" + ) + output_tokens = usage_metadata.get("candidatesTokenCount") or usage_metadata.get( + "candidates_token_count" + ) + total_tokens = usage_metadata.get("totalTokenCount") or usage_metadata.get( + "total_token_count" + ) + reasoning_tokens = usage_metadata.get("thoughtsTokenCount") or usage_metadata.get( + "thoughts_token_count" + ) + return { + "input_tokens": input_tokens or 0, + "output_tokens": output_tokens or 0, + "total_tokens": total_tokens or 0, + "reasoning_tokens": reasoning_tokens or 0, + } + + def parse_evaluation_output( - raw_results: list[dict[str, Any]], dataset_items: list[dict[str, Any]] + raw_results: list[dict[str, Any]], + dataset_items: list[dict[str, Any]], + provider_name: str = LLMProvider.OPENAI, ) -> list[dict[str, Any]]: """ Parse batch output into evaluation results. @@ -148,20 +206,19 @@ def parse_evaluation_output( """ # Create lookup map for dataset items by ID dataset_map = {item["id"]: item for item in dataset_items} + is_google = provider_name in (LLMProvider.GOOGLE, LLMProvider.GOOGLE_NATIVE) results = [] for line_num, response in enumerate(raw_results, 1): try: - # Extract BATCH_KEY (which is our dataset item ID) - item_id = response.get(BATCH_KEY) + item_id = response.get(BATCH_KEY) or response.get("key") if not item_id: logger.warning( - f"[parse_evaluation_output] No {BATCH_KEY} found, skipping | line={line_num}" + f"[parse_evaluation_output] No item_id found, skipping | line={line_num}" ) continue - # Get original dataset item dataset_item = dataset_map.get(item_id) if not dataset_item: logger.warning( @@ -169,64 +226,70 @@ def parse_evaluation_output( ) continue - # Extract the response body - response_body = response.get("response", {}).get("body", {}) + generated_output = "" + response_id: str | None = None + usage: dict[str, Any] | None = None - # Extract response ID from response.body.id - response_id = response_body.get("id") + if is_google: + gemini_response = response.get("response") + error = response.get("error") + if error: + error_msg = str(error) + logger.warning( + f"[parse_evaluation_output] Item had error | item_id={item_id} | {error_msg}" + ) + generated_output = f"ERROR: {error_msg}" + elif isinstance(gemini_response, dict): + text = extract_text_from_response_dict(gemini_response) + generated_output = text or "" + response_id = gemini_response.get( + "responseId" + ) or gemini_response.get("response_id") + usage = _extract_gemini_usage(gemini_response) + else: + response_body = response.get("response", {}).get("body", {}) + response_id = response_body.get("id") + usage = response_body.get("usage") - # Extract usage information for cost tracking - usage = response_body.get("usage") + if response.get("error"): + error_msg = response["error"].get("message", "Unknown error") + logger.warning( + f"[parse_evaluation_output] Item had error | item_id={item_id} | {error_msg}" + ) + generated_output = f"ERROR: {error_msg}" + else: + output = response_body.get("output", "") - # Handle errors in batch processing - if response.get("error"): - error_msg = response["error"].get("message", "Unknown error") - logger.warning( - f"[parse_evaluation_output] Item had error | item_id={item_id} | {error_msg}" - ) - generated_output = f"ERROR: {error_msg}" - else: - # Extract text from output (can be string, list, or complex structure) - output = response_body.get("output", "") - - # If string, try to parse it (may be JSON or Python repr of list) - if isinstance(output, str): - try: - output = json.loads(output) - except (json.JSONDecodeError, ValueError): + if isinstance(output, str): try: - output = ast.literal_eval(output) - except (ValueError, SyntaxError): - # Keep as string if parsing fails - generated_output = output - output = None - - # If we have a list structure, extract text from message items - if isinstance(output, list): - generated_output = "" - for item in output: - if isinstance(item, dict) and item.get("type") == "message": - for content in item.get("content", []): - if ( - isinstance(content, dict) - and content.get("type") == "output_text" - ): - generated_output = content.get("text", "") + output = json.loads(output) + except (json.JSONDecodeError, ValueError): + try: + output = ast.literal_eval(output) + except (ValueError, SyntaxError): + generated_output = output + output = None + + if isinstance(output, list): + for item in output: + if isinstance(item, dict) and item.get("type") == "message": + for content in item.get("content", []): + if ( + isinstance(content, dict) + and content.get("type") == "output_text" + ): + generated_output = content.get("text", "") + break + if generated_output: break - if generated_output: - break - elif output is not None: - # output was not a string and not a list - generated_output = "" - logger.warning( - f"[parse_evaluation_output] Unexpected output type | item_id={item_id} | type={type(output)}" - ) + elif output is not None: + generated_output = "" + logger.warning( + f"[parse_evaluation_output] Unexpected output type | item_id={item_id} | type={type(output)}" + ) - # Extract question and ground truth from dataset item question = dataset_item["input"].get("question", "") ground_truth = dataset_item["expected_output"].get("answer", "") - - # Extract question_id from dataset item metadata question_id = dataset_item.get("metadata", {}).get("question_id") results.append( @@ -300,8 +363,15 @@ async def process_completed_evaluation( logger.info( f"[process_completed_evaluation] {log_prefix} Downloading batch results | batch_job_id={batch_job.id}" ) - provider = OpenAIBatchProvider(client=openai_client) - raw_results = download_batch_results(provider=provider, batch_job=batch_job) + provider = _get_batch_provider( + session=session, + provider_name=batch_job.provider, + organization_id=eval_run.organization_id, + project_id=eval_run.project_id, + ) + raw_results = await asyncio.to_thread( + download_batch_results, provider=provider, batch_job=batch_job + ) # Step 2a: Upload raw results to object store for evaluation_run object_store_url = None @@ -318,13 +388,17 @@ async def process_completed_evaluation( logger.info( f"[process_completed_evaluation] {log_prefix} Fetching dataset items | dataset={eval_run.dataset_name}" ) - dataset_items = fetch_dataset_items( - langfuse=langfuse, dataset_name=eval_run.dataset_name + dataset_items = await asyncio.to_thread( + fetch_dataset_items, + langfuse=langfuse, + dataset_name=eval_run.dataset_name, ) # Step 4: Parse evaluation results results = parse_evaluation_output( - raw_results=raw_results, dataset_items=dataset_items + raw_results=raw_results, + dataset_items=dataset_items, + provider_name=batch_job.provider, ) if not results: @@ -348,7 +422,8 @@ async def process_completed_evaluation( update=EvaluationRunUpdate(cost=eval_run.cost), ) - trace_id_mapping = create_langfuse_dataset_run( + trace_id_mapping = await asyncio.to_thread( + create_langfuse_dataset_run, langfuse=langfuse, dataset_name=eval_run.dataset_name, model=model, @@ -365,7 +440,8 @@ async def process_completed_evaluation( # Step 6: Start embedding batch for similarity scoring # Pass trace_id_mapping directly without storing in DB try: - eval_run = start_embedding_batch( + eval_run = await asyncio.to_thread( + start_embedding_batch, session=session, openai_client=openai_client, eval_run=eval_run, @@ -463,8 +539,10 @@ async def process_completed_embedding_batch( # Step 2: Create provider and download results provider = OpenAIBatchProvider(client=openai_client) - raw_results = download_batch_results( - provider=provider, batch_job=embedding_batch_job + raw_results = await asyncio.to_thread( + download_batch_results, + provider=provider, + batch_job=embedding_batch_job, ) # Step 3: Parse embedding results @@ -497,7 +575,8 @@ async def process_completed_embedding_batch( per_item_scores = similarity_stats.get("per_item_scores", []) if per_item_scores: try: - update_traces_with_cosine_scores( + await asyncio.to_thread( + update_traces_with_cosine_scores, langfuse=langfuse, per_item_scores=per_item_scores, ) @@ -594,8 +673,11 @@ async def check_and_process_evaluation( if embedding_batch_job: # Poll embedding batch status provider = OpenAIBatchProvider(client=openai_client) - poll_batch_status( - session=session, provider=provider, batch_job=embedding_batch_job + await asyncio.to_thread( + poll_batch_status, + session=session, + provider=provider, + batch_job=embedding_batch_job, ) session.refresh(embedding_batch_job) @@ -666,23 +748,41 @@ async def check_and_process_evaluation( f"BatchJob {eval_run.batch_job_id} not found for evaluation {eval_run.id}" ) - # IMPORTANT: Poll OpenAI to get the latest status before checking - provider = OpenAIBatchProvider(client=openai_client) - status_result = poll_batch_status( - session=session, provider=provider, batch_job=batch_job + # Build batch provider from batch_job.provider (OpenAI or Gemini) + provider = _get_batch_provider( + session=session, + provider_name=batch_job.provider, + organization_id=eval_run.organization_id, + project_id=eval_run.project_id, + ) + status_result = await asyncio.to_thread( + poll_batch_status, + session=session, + provider=provider, + batch_job=batch_job, ) # Refresh batch_job to get the updated provider_status session.refresh(batch_job) provider_status = batch_job.provider_status + is_openai = batch_job.provider in ( + LLMProvider.OPENAI, + LLMProvider.OPENAI_NATIVE, + ) # Handle different provider statuses - if provider_status == "completed": - # Check if batch completed but all requests failed - # (output_file_id is absent, error_file_id is present) - if not status_result.get( - "provider_output_file_id", batch_job.provider_output_file_id - ) and status_result.get("error_file_id"): + if ( + provider_status == "completed" + or provider_status == BatchJobState.SUCCEEDED.value + ): + # OpenAI-only: batch completed but all requests failed (error file present, output file absent) + if ( + is_openai + and not status_result.get( + "provider_output_file_id", batch_job.provider_output_file_id + ) + and status_result.get("error_file_id") + ): error_msg = _extract_batch_error_message( provider=provider, error_file_id=status_result["error_file_id"], @@ -730,7 +830,14 @@ async def check_and_process_evaluation( "action": "processed", } - elif provider_status in ["failed", "expired", "cancelled"]: + elif provider_status in ( + "failed", + "expired", + "cancelled", + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, + ): # Mark evaluation as failed based on provider status error_msg = batch_job.error_message or f"Provider batch {provider_status}" diff --git a/backend/app/crud/model_config.py b/backend/app/crud/model_config.py index 6d535240a..1ab7f923a 100644 --- a/backend/app/crud/model_config.py +++ b/backend/app/crud/model_config.py @@ -1,20 +1,28 @@ from typing import Any, Literal +from sqlalchemy.dialects.postgresql import ARRAY +from sqlalchemy.sql import sqltypes from sqlmodel import Session, select from app.models import ModelConfig +Provider = str +CompletionType = Literal["text", "stt", "tts"] + def list_active_model_configs( session: Session, - provider: Literal["openai", "google"] | None = None, + provider: Provider | None = None, skip: int = 0, limit: int = 100, ) -> tuple[list[ModelConfig], bool]: statement = select(ModelConfig).where(ModelConfig.is_active) if provider: - statement = statement.where(ModelConfig.provider == provider) + try: + statement = statement.where(ModelConfig.provider == provider) + except Exception: + return [], False statement = statement.order_by(ModelConfig.provider, ModelConfig.model_name) statement = statement.offset(skip).limit(limit + 1) @@ -30,19 +38,22 @@ def list_active_model_configs( def list_all_active_model_configs( session: Session, - provider: Literal["openai", "google"] | None = None, + provider: Provider | None = None, ) -> list[ModelConfig]: statement = select(ModelConfig).where(ModelConfig.is_active) if provider: - statement = statement.where(ModelConfig.provider == provider) + try: + statement = statement.where(ModelConfig.provider == provider) + except Exception: + return [] statement = statement.order_by(ModelConfig.provider, ModelConfig.model_name) return list(session.exec(statement).all()) def get_model_config( - session: Session, provider: Literal["openai", "google"], model_name: str + session: Session, provider: Provider, model_name: str ) -> ModelConfig | None: statement = select(ModelConfig).where( ModelConfig.provider == provider, @@ -52,9 +63,57 @@ def get_model_config( return session.exec(statement).first() -def is_reasoning_model( - session: Session, provider: Literal["openai", "google"], model_name: str +def _modality_filter(stmt, completion_type: CompletionType): + """Restrict query to models matching the completion type via modalities.""" + str_array = ARRAY(sqltypes.String) + input_col = ModelConfig.input_modalities + output_col = ModelConfig.output_modalities + + if completion_type == "stt": + return stmt.where( + input_col.cast(str_array).contains(["AUDIO"]), + output_col.cast(str_array).contains(["TEXT"]), + ) + if completion_type == "tts": + return stmt.where( + input_col.cast(str_array).contains(["TEXT"]), + output_col.cast(str_array).contains(["AUDIO"]), + ) + return stmt.where( + input_col.cast(str_array).contains(["TEXT"]), + output_col.cast(str_array).contains(["TEXT"]), + ) + + +def list_supported_models( + session: Session, provider: Provider, completion_type: CompletionType +) -> list[str]: + """Return active model names for a provider+completion type.""" + stmt = select(ModelConfig.model_name).where( + ModelConfig.provider == provider, + ModelConfig.is_active, + ) + stmt = _modality_filter(stmt, completion_type) + return list(session.exec(stmt).all()) + + +def is_model_supported( + session: Session, + provider: Provider, + completion_type: CompletionType, + model_name: str, ) -> bool: + """Check whether (provider, model_name) is active and matches the completion type.""" + stmt = select(ModelConfig.id).where( + ModelConfig.provider == provider, + ModelConfig.model_name == model_name, + ModelConfig.is_active, + ) + stmt = _modality_filter(stmt, completion_type) + return session.exec(stmt).first() is not None + + +def is_reasoning_model(session: Session, provider: Provider, model_name: str) -> bool: """Return True if the model is configured with a reasoning `effort` control. A model is considered reasoning-capable if its `config` JSON contains an @@ -69,7 +128,7 @@ def is_reasoning_model( def estimate_model_cost( session: Session, - provider: Literal["openai", "google"], + provider: Provider, model_name: str, input_tokens: int, output_tokens: int, diff --git a/backend/app/models/model_config.py b/backend/app/models/model_config.py index bc1d14115..c4a109cba 100644 --- a/backend/app/models/model_config.py +++ b/backend/app/models/model_config.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Literal +from typing import Any import sqlalchemy as sa from sqlalchemy.dialects.postgresql import ARRAY, JSONB @@ -9,12 +9,12 @@ class ModelConfigBase(SQLModel): - provider: Literal["openai", "google"] = Field( + provider: str = Field( default="openai", sa_column=sa.Column( sa.String, nullable=False, - comment="provider name (e.g. openai, google)", + comment="provider name (e.g. openai, google, sarvamai, elevenlabs)", ), ) diff --git a/backend/app/services/evaluations/batch_job.py b/backend/app/services/evaluations/batch_job.py new file mode 100644 index 000000000..ec13045ee --- /dev/null +++ b/backend/app/services/evaluations/batch_job.py @@ -0,0 +1,92 @@ +import logging +from uuid import UUID + +from celery.exceptions import SoftTimeLimitExceeded +from gevent import Timeout +from sqlmodel import Session + +from app.core.db import engine +from app.crud.evaluations import ( + get_evaluation_run_by_id, + resolve_evaluation_config, + start_evaluation_batch, +) +from app.crud.evaluations.core import update_evaluation_run +from app.models.evaluation import EvaluationRunUpdate +from app.utils import get_langfuse_client + +logger = logging.getLogger(__name__) + + +def execute_evaluation_batch_submission( + project_id: int, + job_id: str, + task_id: str, + task_instance, + organization_id: int, + config_id: str, + config_version: int, + **kwargs, +) -> dict: + run_id = int(job_id) + logger.info( + f"[execute_evaluation_batch_submission] Starting | run_id={run_id} | task={task_id}" + ) + with Session(engine) as session: + run = get_evaluation_run_by_id( + session=session, + evaluation_id=run_id, + organization_id=organization_id, + project_id=project_id, + ) + if not run: + return {"success": False, "error": "Run not found"} + try: + config, error = resolve_evaluation_config( + session=session, + config_id=UUID(str(config_id)), + config_version=config_version, + project_id=project_id, + ) + if error: + update_evaluation_run( + session=session, + eval_run=run, + update=EvaluationRunUpdate(status="failed", error_message=error), + ) + return {"success": False, "error": error} + + langfuse = get_langfuse_client( + session=session, org_id=organization_id, project_id=project_id + ) + run = start_evaluation_batch( + langfuse=langfuse, + session=session, + eval_run=run, + params=config.completion.params, + provider=config.completion.provider, + ) + return {"success": True, "batch_job_id": run.batch_job_id} + except (Timeout, SoftTimeLimitExceeded): + logger.warning( + f"[execute_evaluation_batch_submission] Timeout | run_id={run_id}" + ) + update_evaluation_run( + session=session, + eval_run=run, + update=EvaluationRunUpdate( + status="failed", error_message="Task exceeded soft time limit" + ), + ) + raise + except Exception as e: + logger.error( + f"[execute_evaluation_batch_submission] Failed | run_id={run_id} | {e}", + exc_info=True, + ) + update_evaluation_run( + session=session, + eval_run=run, + update=EvaluationRunUpdate(status="failed", error_message=str(e)), + ) + raise diff --git a/backend/app/services/evaluations/evaluation.py b/backend/app/services/evaluations/evaluation.py index 9f78b9e37..1d23b9efc 100644 --- a/backend/app/services/evaluations/evaluation.py +++ b/backend/app/services/evaluations/evaluation.py @@ -3,9 +3,11 @@ import logging from uuid import UUID +from asgi_correlation_id import correlation_id from fastapi import HTTPException from sqlmodel import Session +from app.celery.utils import start_evaluation_batch_submission from app.crud.evaluations import ( create_evaluation_run, fetch_trace_scores_from_langfuse, @@ -13,12 +15,11 @@ get_evaluation_run_by_id, resolve_evaluation_config, save_score, - start_evaluation_batch, ) -from app.models.evaluation import EvaluationRun -from app.models.llm.request import TextLLMParams, STTLLMParams, TTSLLMParams +from app.models.evaluation import EvaluationRun, EvaluationRunUpdate +from app.crud.evaluations.core import update_evaluation_run from app.services.llm.providers import LLMProvider -from app.utils import get_langfuse_client, get_openai_client +from app.utils import get_langfuse_client from app.core.cloud.storage import get_cloud_storage from app.core.storage_utils import load_json_from_object_store @@ -106,29 +107,16 @@ def start_evaluation( status_code=400, detail=f"Failed to resolve config from stored config: {error}", ) - elif config.completion.provider != LLMProvider.OPENAI: + elif config.completion.provider not in (LLMProvider.OPENAI, LLMProvider.GOOGLE): raise HTTPException( status_code=422, - detail="Only 'openai' provider is supported for evaluation configs", + detail="Only 'openai' and 'google' providers are supported for evaluation configs", ) logger.info( "[start_evaluation] Successfully resolved config from config management" ) - # Get API clients - openai_client = get_openai_client( - session=session, - org_id=organization_id, - project_id=project_id, - ) - langfuse = get_langfuse_client( - session=session, - org_id=organization_id, - project_id=project_id, - ) - - # Step 3: Create EvaluationRun record with config references eval_run = create_evaluation_run( session=session, run_name=experiment_name, @@ -140,39 +128,34 @@ def start_evaluation( project_id=project_id, ) - # Step 4: Start the batch evaluation + trace_id = correlation_id.get() or "N/A" try: - # Convert params dict to appropriate model instance based on type - param_models = { - "text": TextLLMParams, - "stt": STTLLMParams, - "tts": TTSLLMParams, - } - model_class = param_models[config.completion.type] - validated_params = model_class.model_validate(config.completion.params) - - eval_run = start_evaluation_batch( - langfuse=langfuse, - openai_client=openai_client, - session=session, - eval_run=eval_run, - config=validated_params, + celery_task_id = start_evaluation_batch_submission( + project_id=project_id, + job_id=str(eval_run.id), + trace_id=trace_id, + organization_id=organization_id, + config_id=str(config_id), + config_version=config_version, ) - logger.info( - f"[start_evaluation] Evaluation started successfully | " - f"batch_job_id={eval_run.batch_job_id} | total_items={eval_run.total_items}" + f"[start_evaluation] Batch submission queued | " + f"run_id={eval_run.id} | celery_task_id={celery_task_id}" ) - return eval_run - except Exception as e: logger.error( - f"[start_evaluation] Failed to start evaluation | run_id={eval_run.id} | {e}", + f"[start_evaluation] Failed to queue batch submission | run_id={eval_run.id} | {e}", exc_info=True, ) - # Error is already handled in start_evaluation_batch - session.refresh(eval_run) + eval_run = update_evaluation_run( + session=session, + eval_run=eval_run, + update=EvaluationRunUpdate( + status="failed", + error_message=f"Failed to queue batch submission: {e}", + ), + ) return eval_run diff --git a/backend/app/tests/api/routes/test_evaluation.py b/backend/app/tests/api/routes/test_evaluation.py index 8b4e6679c..f324111c8 100644 --- a/backend/app/tests/api/routes/test_evaluation.py +++ b/backend/app/tests/api/routes/test_evaluation.py @@ -7,7 +7,7 @@ from fastapi.testclient import TestClient from sqlmodel import Session, select -from app.crud.evaluations.batch import build_evaluation_jsonl +from app.crud.evaluations.batch import build_openai_evaluation_jsonl from app.models import EvaluationDataset, EvaluationRun from app.models.llm.request import TextLLMParams from app.tests.utils.auth import TestAuthContext From ad191d3ebd67c03e98abfbc43daeba69e05409cd Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 21 May 2026 08:18:15 +0530 Subject: [PATCH 2/2] feat: Add Langfuse flush to prevent thread accumulation in evaluation polling --- backend/app/crud/evaluations/processing.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/backend/app/crud/evaluations/processing.py b/backend/app/crud/evaluations/processing.py index 87671943a..acf2db61d 100644 --- a/backend/app/crud/evaluations/processing.py +++ b/backend/app/crud/evaluations/processing.py @@ -1039,6 +1039,16 @@ async def poll_all_pending_evaluations(session: Session) -> dict[str, Any]: ) total_failed_count += 1 + # Flush background ingestion threads before discarding this client. + # Each Langfuse() instance spawns TaskManager threads; without flush, + # they accumulate across cron runs and pile up concurrent API calls. + try: + langfuse.flush() + except Exception as flush_err: + logger.warning( + f"[poll_all_pending_evaluations] Langfuse flush failed | project_id={project_id} | {flush_err}" + ) + except Exception as e: logger.error( f"[poll_all_pending_evaluations] Failed to process project | project_id={project_id} | {e}",