From fe896bc70febac466248924cafdbcd153779929e Mon Sep 17 00:00:00 2001 From: Juan Sugg Date: Thu, 12 Mar 2026 22:29:19 -0300 Subject: [PATCH] refactor: internalize runtime and transcription boundaries --- .github/workflows/ci.yml | 14 +- .../workflows/darwin-x86_64-validation.yml | 1 + .../full-dataset-quality-gate-regression.yml | 18 + .../linux-python-3_13-cli-validation.yml | 14 + .../linux-selfhosted-gpu-validation.yml | 36 +- .github/workflows/macos15-mps-validation.yml | 18 + .github/workflows/python-publish-testpypi.yml | 14 +- .github/workflows/python-publish.yml | 49 +- pyproject.toml | 4 + .../runtime/accurate_public_boundary.py | 425 +++- ser/_internal/runtime/fast_public_boundary.py | 437 ++++ .../runtime/medium_public_boundary.py | 400 +++- ser/_internal/transcription/compatibility.py | 20 +- .../transcription/public_boundary_runtime.py | 2 +- .../transcription/public_boundary_support.py | 135 +- ser/runtime/accurate_inference.py | 500 +--- ser/runtime/fast_inference.py | 417 +--- ser/runtime/medium_inference.py | 496 +--- .../backends/stable_whisper_mps_compat.py | 182 +- ser/transcript/transcript_extractor.py | 420 +--- tests/__init__.py | 1 + tests/conftest.py | 2 + tests/fixtures/__init__.py | 1 + tests/fixtures/settings.py | 17 + .../integration/test_accurate_inference.py | 364 +++ .../integration}/test_backend_hooks.py | 10 +- .../suites/integration/test_fast_inference.py | 220 ++ .../integration/test_medium_inference.py | 364 +++ .../integration/test_process_isolation.py | 314 +++ .../integration}/test_runtime_pipeline.py | 10 +- .../integration}/test_runtime_registry.py | 10 +- .../integration/test_transcript_extractor.py | 271 +++ .../smoke/test_cli_runtime_workflows.py | 110 + tests/suites/unit/__init__.py | 1 + .../unit}/test_accurate_execution.py | 12 +- .../unit}/test_accurate_execution_flow.py | 3 + .../unit}/test_accurate_operation_setup.py | 2 + .../unit}/test_accurate_process_timeout.py | 2 + .../unit}/test_accurate_retry_operation.py 
| 2 + .../unit}/test_accurate_runtime_support.py | 2 + .../unit}/test_accurate_worker_lifecycle.py | 2 + .../unit}/test_accurate_worker_operation.py | 2 + .../unit}/test_medium_execution.py | 12 +- .../unit}/test_medium_execution_context.py | 4 + .../unit}/test_medium_execution_flow.py | 4 + .../unit}/test_medium_process_operation.py | 2 + .../unit}/test_medium_process_timeout.py | 2 + .../unit}/test_medium_retry_operation.py | 2 + .../unit}/test_medium_runtime_support.py | 2 + .../unit}/test_medium_worker_lifecycle.py | 2 + .../unit}/test_medium_worker_operation.py | 2 + .../test_runtime_worker_error_timeout.py | 2 + tests/test_accurate_inference.py | 893 ------- tests/test_fast_inference.py | 331 --- tests/test_medium_inference.py | 740 ------ tests/test_medium_timeout_and_fallback.py | 574 ----- ...est_runtime_worker_lifecycle_delegation.py | 209 -- tests/test_stable_whisper_mps_compat.py | 96 +- tests/test_transcript_extractor.py | 2053 ----------------- tests/utils/__init__.py | 1 + tests/utils/helpers/__init__.py | 1 + tests/utils/helpers/process_spawn_support.py | 115 + 62 files changed, 3733 insertions(+), 6638 deletions(-) create mode 100644 ser/_internal/runtime/fast_public_boundary.py create mode 100644 tests/__init__.py create mode 100644 tests/fixtures/__init__.py create mode 100644 tests/fixtures/settings.py create mode 100644 tests/suites/integration/test_accurate_inference.py rename tests/{ => suites/integration}/test_backend_hooks.py (97%) create mode 100644 tests/suites/integration/test_fast_inference.py create mode 100644 tests/suites/integration/test_medium_inference.py create mode 100644 tests/suites/integration/test_process_isolation.py rename tests/{ => suites/integration}/test_runtime_pipeline.py (99%) rename tests/{ => suites/integration}/test_runtime_registry.py (97%) create mode 100644 tests/suites/integration/test_transcript_extractor.py create mode 100644 tests/suites/smoke/test_cli_runtime_workflows.py create mode 100644 
tests/suites/unit/__init__.py rename tests/{ => suites/unit}/test_accurate_execution.py (95%) rename tests/{ => suites/unit}/test_accurate_execution_flow.py (99%) rename tests/{ => suites/unit}/test_accurate_operation_setup.py (99%) rename tests/{ => suites/unit}/test_accurate_process_timeout.py (99%) rename tests/{ => suites/unit}/test_accurate_retry_operation.py (99%) rename tests/{ => suites/unit}/test_accurate_runtime_support.py (99%) rename tests/{ => suites/unit}/test_accurate_worker_lifecycle.py (99%) rename tests/{ => suites/unit}/test_accurate_worker_operation.py (99%) rename tests/{ => suites/unit}/test_medium_execution.py (95%) rename tests/{ => suites/unit}/test_medium_execution_context.py (99%) rename tests/{ => suites/unit}/test_medium_execution_flow.py (99%) rename tests/{ => suites/unit}/test_medium_process_operation.py (99%) rename tests/{ => suites/unit}/test_medium_process_timeout.py (99%) rename tests/{ => suites/unit}/test_medium_retry_operation.py (99%) rename tests/{ => suites/unit}/test_medium_runtime_support.py (99%) rename tests/{ => suites/unit}/test_medium_worker_lifecycle.py (99%) rename tests/{ => suites/unit}/test_medium_worker_operation.py (99%) rename tests/{ => suites/unit}/test_runtime_worker_error_timeout.py (99%) delete mode 100644 tests/test_accurate_inference.py delete mode 100644 tests/test_fast_inference.py delete mode 100644 tests/test_medium_inference.py delete mode 100644 tests/test_medium_timeout_and_fallback.py delete mode 100644 tests/test_runtime_worker_lifecycle_delegation.py delete mode 100644 tests/test_transcript_extractor.py create mode 100644 tests/utils/__init__.py create mode 100644 tests/utils/helpers/__init__.py create mode 100644 tests/utils/helpers/process_spawn_support.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6a5f1e5..1f2a2f4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -135,15 +135,11 @@ jobs: - name: Run tests run: uv run --frozen --python ${{ 
matrix.python-version }} --extra dev pytest -q - - name: Enforce Python 3.13 medium contract checks + - name: Enforce Python 3.13 organized workflow suites if: matrix.python-version == '3.13' run: > uv run --frozen --python ${{ matrix.python-version }} --extra dev --extra medium pytest -q - tests/test_medium_inference.py - tests/test_medium_timeout_and_fallback.py - tests/test_backend_hooks.py - tests/test_runtime_registry.py - tests/test_runtime_pipeline.py + tests/suites contract-gates: runs-on: ubuntu-latest @@ -173,6 +169,12 @@ jobs: - name: Enforce API import boundary lint gate run: make import-lint + - name: Enforce organized smoke and process-isolation suites + run: > + uv run --frozen --python 3.12 --extra dev pytest -q + tests/suites + -m "smoke or process_isolation" + - name: Enforce transcription benchmark contract gate run: > uv run --frozen --python 3.12 --extra dev pytest -q diff --git a/.github/workflows/darwin-x86_64-validation.yml b/.github/workflows/darwin-x86_64-validation.yml index 07dc519..4a11c7d 100644 --- a/.github/workflows/darwin-x86_64-validation.yml +++ b/.github/workflows/darwin-x86_64-validation.yml @@ -1,6 +1,7 @@ name: Darwin x86_64 Validation on: + workflow_call: workflow_dispatch: permissions: diff --git a/.github/workflows/full-dataset-quality-gate-regression.yml b/.github/workflows/full-dataset-quality-gate-regression.yml index 2c34b4b..5a4ad64 100644 --- a/.github/workflows/full-dataset-quality-gate-regression.yml +++ b/.github/workflows/full-dataset-quality-gate-regression.yml @@ -1,6 +1,24 @@ name: Full-Dataset Quality Gate Regression on: + workflow_call: + inputs: + run_training: + required: false + default: true + type: boolean + dataset_glob: + required: false + default: "ser/dataset/ravdess/Actor_*/*.wav" + type: string + out_file: + required: false + default: "profile_quality_gate_report_full.json" + type: string + progress_every: + required: false + default: "120" + type: string workflow_dispatch: inputs: run_training: 
diff --git a/.github/workflows/linux-python-3_13-cli-validation.yml b/.github/workflows/linux-python-3_13-cli-validation.yml index 2152a55..97df5f4 100644 --- a/.github/workflows/linux-python-3_13-cli-validation.yml +++ b/.github/workflows/linux-python-3_13-cli-validation.yml @@ -1,6 +1,20 @@ name: Linux Python 3.13 CLI Validation on: + workflow_call: + inputs: + run_accurate_research: + required: false + default: false + type: boolean + accurate_model_id: + required: false + default: "openai/whisper-tiny" + type: string + accurate_research_model_id: + required: false + default: "iic/emotion2vec_plus_large" + type: string workflow_dispatch: inputs: run_accurate_research: diff --git a/.github/workflows/linux-selfhosted-gpu-validation.yml b/.github/workflows/linux-selfhosted-gpu-validation.yml index b0ce6bc..15b3380 100644 --- a/.github/workflows/linux-selfhosted-gpu-validation.yml +++ b/.github/workflows/linux-selfhosted-gpu-validation.yml @@ -1,6 +1,40 @@ name: Linux Self-Hosted GPU Validation on: + workflow_call: + inputs: + python_version: + required: false + default: "3.12" + type: string + run_cuda: + required: false + default: true + type: boolean + run_xpu: + required: false + default: false + type: boolean + cuda_runner_labels_json: + required: false + default: '["self-hosted","linux","x64","cuda"]' + type: string + xpu_runner_labels_json: + required: false + default: '["self-hosted","linux","x64","xpu"]' + type: string + accurate_model_id: + required: false + default: "openai/whisper-tiny" + type: string + run_accurate_research: + required: false + default: false + type: boolean + accurate_research_model_id: + required: false + default: "iic/emotion2vec_plus_large" + type: string workflow_dispatch: inputs: python_version: @@ -139,7 +173,7 @@ jobs: tests/test_torch_inference.py tests/test_feature_runtime_policy.py tests/test_transcription_runtime_policy.py - tests/test_transcript_extractor.py + tests/suites/integration/test_transcript_extractor.py - name: 
Medium profile train and predict (CUDA lane) run: | diff --git a/.github/workflows/macos15-mps-validation.yml b/.github/workflows/macos15-mps-validation.yml index f4838d0..4795ee0 100644 --- a/.github/workflows/macos15-mps-validation.yml +++ b/.github/workflows/macos15-mps-validation.yml @@ -1,6 +1,24 @@ name: macOS 15 MPS Validation on: + workflow_call: + inputs: + python_version: + required: false + default: "3.12" + type: string + accurate_model_id: + required: false + default: "openai/whisper-tiny" + type: string + run_accurate_research: + required: false + default: false + type: boolean + accurate_research_model_id: + required: false + default: "iic/emotion2vec_plus_large" + type: string workflow_dispatch: inputs: python_version: diff --git a/.github/workflows/python-publish-testpypi.yml b/.github/workflows/python-publish-testpypi.yml index cfbe2f7..1327559 100644 --- a/.github/workflows/python-publish-testpypi.yml +++ b/.github/workflows/python-publish-testpypi.yml @@ -58,8 +58,18 @@ jobs: `CI verified for commit ${sha} via run #${ok.run_number}.` ); - build-distributions: + linux-cli-validation: + needs: verify-ci + uses: ./.github/workflows/linux-python-3_13-cli-validation.yml + with: + run_accurate_research: false + + darwin-validation: needs: verify-ci + uses: ./.github/workflows/darwin-x86_64-validation.yml + + build-distributions: + needs: [verify-ci, linux-cli-validation, darwin-validation] runs-on: ubuntu-latest steps: @@ -84,7 +94,7 @@ jobs: path: dist/ publish-to-testpypi: - needs: [verify-ci, build-distributions] + needs: [verify-ci, linux-cli-validation, darwin-validation, build-distributions] runs-on: ubuntu-latest environment: diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index c1ae090..57caf21 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -58,8 +58,46 @@ jobs: `CI verified for commit ${sha} via run #${ok.run_number}.` ); - build-distributions: + 
linux-cli-validation: + needs: verify-ci + uses: ./.github/workflows/linux-python-3_13-cli-validation.yml + with: + run_accurate_research: false + + darwin-validation: needs: verify-ci + uses: ./.github/workflows/darwin-x86_64-validation.yml + + macos-mps-validation: + needs: verify-ci + uses: ./.github/workflows/macos15-mps-validation.yml + with: + run_accurate_research: false + + selfhosted-gpu-validation: + needs: verify-ci + continue-on-error: true + uses: ./.github/workflows/linux-selfhosted-gpu-validation.yml + with: + run_cuda: true + run_xpu: false + run_accurate_research: false + + full-dataset-quality-gate: + needs: verify-ci + continue-on-error: true + uses: ./.github/workflows/full-dataset-quality-gate-regression.yml + with: + run_training: true + + build-distributions: + needs: + [ + verify-ci, + linux-cli-validation, + darwin-validation, + macos-mps-validation, + ] runs-on: ubuntu-latest steps: @@ -84,7 +122,14 @@ jobs: path: dist/ publish-to-pypi: - needs: [verify-ci, build-distributions] + needs: + [ + verify-ci, + linux-cli-validation, + darwin-validation, + macos-mps-validation, + build-distributions, + ] runs-on: ubuntu-latest environment: diff --git a/pyproject.toml b/pyproject.toml index 11e2f09..cb6d4d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -183,5 +183,9 @@ reportMissingTypeStubs = "none" testpaths = ["tests"] addopts = "-ra --strict-config --strict-markers" markers = [ + "unit: narrow owner or helper tests with no multi-module workflow orchestration", + "integration: integration tests that cross module or workflow boundaries", + "process_isolation: tests covering spawned-process orchestration and contracts", + "smoke: fast user-path smoke tests through public workflows", "topology_contract: topology, boundary, or architecture behavior tests", ] diff --git a/ser/_internal/runtime/accurate_public_boundary.py b/ser/_internal/runtime/accurate_public_boundary.py index eba8681..93b2969 100644 --- 
a/ser/_internal/runtime/accurate_public_boundary.py +++ b/ser/_internal/runtime/accurate_public_boundary.py @@ -3,20 +3,57 @@ from __future__ import annotations import logging +import multiprocessing as mp from collections.abc import Callable -from dataclasses import replace +from dataclasses import dataclass, replace from functools import partial -from typing import Any, Protocol, TypeVar, cast +from multiprocessing.connection import Connection +from multiprocessing.process import BaseProcess +from typing import Any, Literal, Protocol, TypeVar, cast import numpy as np from numpy.typing import NDArray +from ser._internal.runtime.process_timeout import ( + run_with_process_timeout as _run_with_process_timeout_impl, +) +from ser._internal.runtime.single_flight import SingleFlightRegistry +from ser._internal.runtime.worker_bindings import ( + is_setup_complete_message as _is_setup_complete_message_binding, +) +from ser._internal.runtime.worker_bindings import ( + parse_worker_completion_message as _parse_worker_completion_message_binding, +) +from ser._internal.runtime.worker_bindings import raise_worker_error as _raise_worker_error_binding +from ser._internal.runtime.worker_bindings import ( + recv_worker_message as _recv_worker_message_binding, +) +from ser._internal.runtime.worker_bindings import run_worker_entry as _run_worker_entry_binding +from ser._internal.runtime.worker_bindings import ( + terminate_worker_process as _terminate_worker_process_binding, +) +from ser._internal.runtime.worker_lifecycle import ( + is_setup_complete_message as _is_setup_complete_message_impl, +) +from ser._internal.runtime.worker_lifecycle import ( + parse_worker_completion_message as _parse_worker_completion_message_impl, +) +from ser._internal.runtime.worker_lifecycle import raise_worker_error as _raise_worker_error_impl +from ser._internal.runtime.worker_lifecycle import recv_worker_message as _recv_worker_message_impl +from ser._internal.runtime.worker_lifecycle import ( + 
run_process_setup_compute_handshake as _run_process_setup_compute_handshake_impl, +) +from ser._internal.runtime.worker_lifecycle import run_with_timeout as _run_with_timeout_impl +from ser._internal.runtime.worker_lifecycle import ( + terminate_worker_process as _terminate_worker_process_impl, +) from ser.config import AppConfig, ProfileRuntimeConfig +from ser.models.emotion_model import load_model from ser.models.profile_runtime import ( resolve_accurate_model_id, resolve_accurate_research_model_id, ) -from ser.repr import FeatureBackend +from ser.repr import Emotion2VecBackend, FeatureBackend, WhisperBackend from ser.repr.runtime_policy import resolve_feature_runtime_policy from ser.runtime import mps_oom as mps_oom_helpers from ser.runtime.accurate_backend_runtime import ( @@ -62,6 +99,9 @@ from ser.runtime.accurate_runtime_support import ( build_cpu_settings_snapshot as _build_cpu_settings_snapshot_impl, ) +from ser.runtime.accurate_runtime_support import ( + build_process_settings_snapshot as _build_process_settings_snapshot_impl, +) from ser.runtime.accurate_runtime_support import ( encode_accurate_sequence as _encode_accurate_sequence_impl, ) @@ -76,6 +116,11 @@ from ser.runtime.accurate_worker_operation import ( finalize_in_process_setup as _finalize_in_process_setup_impl, ) +from ser.runtime.accurate_worker_operation import prepare_retry_state as _prepare_retry_state_impl +from ser.runtime.accurate_worker_operation import ( + run_inference_operation as _run_inference_operation_impl, +) +from ser.runtime.contracts import InferenceRequest from ser.runtime.phase_contract import PHASE_EMOTION_INFERENCE, PHASE_EMOTION_SETUP from ser.runtime.phase_timing import ( log_phase_completed, @@ -87,6 +132,7 @@ jittered_retry_delay_seconds as _jittered_retry_delay_seconds_impl, ) from ser.runtime.schema import InferenceResult +from ser.utils.audio_utils import read_audio_file class _AccurateLoadedModel(ArtifactMetadataCarrier, _AccurateLoadedModelLike, Protocol): @@ 
-96,6 +142,118 @@ class _AccurateLoadedModel(ArtifactMetadataCarrier, _AccurateLoadedModelLike, Pr _AccurateLoadedModelT = TypeVar("_AccurateLoadedModelT", bound=_AccurateLoadedModel) _AccuratePayloadT = TypeVar("_AccuratePayloadT", bound=_AccuratePayloadLike) +type WorkerPhaseMessage = tuple[Literal["phase"], Literal["setup_complete"]] +type WorkerSuccessMessage = tuple[Literal["ok"], InferenceResult] +type WorkerErrorMessage = tuple[Literal["err"], str, str] +type WorkerMessage = WorkerPhaseMessage | WorkerSuccessMessage | WorkerErrorMessage + +_TERMINATE_GRACE_SECONDS = 0.5 +_KILL_GRACE_SECONDS = 0.5 +_SINGLE_FLIGHT_REGISTRY = SingleFlightRegistry() +_WORKER_LOGGER = logging.getLogger("ser.runtime.accurate_inference") + + +@dataclass(frozen=True) +class AccurateProcessPayload: + """Serializable payload for one process-isolated accurate inference attempt.""" + + request: InferenceRequest + settings: AppConfig + expected_backend_id: str + expected_profile: str + expected_backend_model_id: str | None + + +class AccurateModelUnavailableError(FileNotFoundError): + """Spawn-safe accurate worker error marker for unavailable model artifacts.""" + + +class AccurateRuntimeDependencyError(RuntimeError): + """Spawn-safe accurate worker error marker for missing runtime dependencies.""" + + +class AccurateModelLoadError(RuntimeError): + """Spawn-safe accurate worker error marker for model load failures.""" + + +class AccurateTransientBackendError(RuntimeError): + """Spawn-safe accurate worker error marker for transient backend failures.""" + + +def _build_backend_for_worker_profile( + *, + expected_backend_id: str, + expected_backend_model_id: str | None, + settings: AppConfig, +) -> FeatureBackend: + """Builds one accurate worker backend with spawn-safe error types.""" + return build_backend_for_profile( + expected_backend_id=expected_backend_id, + expected_backend_model_id=expected_backend_model_id, + settings=settings, + whisper_backend_factory=WhisperBackend, + 
emotion2vec_backend_factory=Emotion2VecBackend, + unsupported_backend_error=AccurateModelUnavailableError, + ) + + +def _prepare_accurate_process_operation( + payload: AccurateProcessPayload, +) -> PreparedAccurateOperation[_AccurateLoadedModel, FeatureBackend]: + """Builds one accurate worker operation using module-level collaborators.""" + return prepare_process_operation( + payload, + load_model_fn=load_model, + read_audio_file_fn=read_audio_file, + build_backend_for_profile_fn=_build_backend_for_worker_profile, + logger=_WORKER_LOGGER, + model_unavailable_error_factory=AccurateModelUnavailableError, + model_load_error_factory=AccurateModelLoadError, + runtime_dependency_error_factory=AccurateRuntimeDependencyError, + transient_error_factory=AccurateTransientBackendError, + ) + + +def _run_accurate_process_inference_once( + *, + loaded_model: _AccurateLoadedModel, + backend: FeatureBackend, + audio: NDArray[np.float32], + sample_rate: int, + runtime_config: ProfileRuntimeConfig, +) -> InferenceResult: + """Runs one accurate worker-process inference attempt.""" + return run_accurate_inference_once( + loaded_model=loaded_model, + backend=backend, + audio=audio, + sample_rate=sample_rate, + runtime_config=runtime_config, + logger=_WORKER_LOGGER, + dependency_error_factory=AccurateRuntimeDependencyError, + transient_error_factory=AccurateTransientBackendError, + ) + + +def _run_accurate_process_operation( + prepared: PreparedAccurateOperation[_AccurateLoadedModel, FeatureBackend], +) -> InferenceResult: + """Runs one accurate compute phase inside the spawned worker process.""" + return run_process_operation( + prepared, + run_accurate_inference_once=_run_accurate_process_inference_once, + ) + + +def _accurate_worker_entry(payload: AccurateProcessPayload, connection: Connection) -> None: + """Executes one spawned accurate worker using module-level collaborators only.""" + _run_worker_entry_binding( + payload=payload, + connection=connection, + 
prepare_process_operation=_prepare_accurate_process_operation, + run_process_operation=_run_accurate_process_operation, + ) + def run_accurate_inference_once( *, @@ -135,6 +293,266 @@ def run_accurate_inference_once( ) +def run_accurate_inference_from_public_boundary( + request: InferenceRequest, + settings: AppConfig, + *, + loaded_model: _AccurateLoadedModel | None = None, + backend: FeatureBackend | None = None, + enforce_timeout: bool = True, + allow_retries: bool = True, + expected_backend_id: str = "hf_whisper", + expected_profile: str = "accurate", + expected_backend_model_id: str | None = None, + logger: logging.Logger, + model_unavailable_error_type: type[Exception], + runtime_dependency_error_type: type[Exception], + model_load_error_type: type[Exception], + timeout_error_type: type[Exception], + inference_execution_error_type: type[Exception], + transient_backend_error_type: type[Exception], +) -> InferenceResult: + """Runs accurate inference through the internal public-boundary owner.""" + + worker_error_factories: dict[str, Callable[[str], Exception]] = { + "ValueError": ValueError, + runtime_dependency_error_type.__name__: runtime_dependency_error_type, + transient_backend_error_type.__name__: transient_backend_error_type, + model_unavailable_error_type.__name__: model_unavailable_error_type, + model_load_error_type.__name__: model_load_error_type, + timeout_error_type.__name__: timeout_error_type, + "RuntimeError": RuntimeError, + } + + def _run_with_process_timeout( + payload: AccurateProcessPayload, + *, + timeout_seconds: float, + ) -> InferenceResult: + return _run_with_process_timeout_impl( + payload=payload, + resolve_profile=lambda active_payload: active_payload.expected_profile, + timeout_seconds=timeout_seconds, + get_context=mp.get_context, + logger=logger, + setup_phase_name=PHASE_EMOTION_SETUP, + inference_phase_name=PHASE_EMOTION_INFERENCE, + log_phase_started=log_phase_started, + log_phase_completed=log_phase_completed, + 
log_phase_failed=log_phase_failed, + run_process_setup_compute_handshake=_run_process_setup_compute_handshake_impl, + worker_target=_accurate_worker_entry, + recv_worker_message=_recv_worker_message, + is_setup_complete_message=_is_setup_complete_message, + terminate_worker_process=_terminate_worker_process, + timeout_error_factory=timeout_error_type, + execution_error_factory=inference_execution_error_type, + worker_label="Accurate inference", + process_join_grace_seconds=_TERMINATE_GRACE_SECONDS, + parse_worker_completion_message=_parse_worker_completion_message, + ) + + def _recv_worker_message( + connection: Connection, + *, + stage: str, + ) -> WorkerMessage: + return _recv_worker_message_binding( + connection=connection, + stage=stage, + impl=_recv_worker_message_impl, + worker_label="Accurate inference", + error_factory=inference_execution_error_type, + ) + + def _is_setup_complete_message(message: WorkerMessage) -> bool: + return _is_setup_complete_message_binding( + message=message, + impl=_is_setup_complete_message_impl, + worker_label="Accurate inference", + error_factory=inference_execution_error_type, + ) + + def _parse_worker_completion_message(worker_message: WorkerMessage) -> InferenceResult: + return _parse_worker_completion_message_binding( + worker_message=worker_message, + impl=_parse_worker_completion_message_impl, + worker_label="Accurate inference", + error_factory=inference_execution_error_type, + raise_worker_error=_raise_worker_error, + result_type=InferenceResult, + ) + + def _prepare_in_process_accurate_operation( + *, + request: _AccurateRequestLike, + settings: AppConfig, + runtime_config: ProfileRuntimeConfig, + loaded_model: _AccurateLoadedModel | None, + backend: FeatureBackend | None, + expected_backend_id: str, + expected_profile: str, + expected_backend_model_id: str | None, + ) -> PreparedAccurateOperation[_AccurateLoadedModel, FeatureBackend]: + return prepare_in_process_accurate_operation( + request=request, + 
settings=settings, + runtime_config=runtime_config, + loaded_model=loaded_model, + backend=backend, + expected_backend_id=expected_backend_id, + expected_profile=expected_profile, + expected_backend_model_id=expected_backend_model_id, + load_model_fn=load_model, + read_audio_file_fn=read_audio_file, + build_backend_for_profile_fn=_build_backend_for_profile, + logger=logger, + model_unavailable_error_factory=model_unavailable_error_type, + model_load_error_factory=model_load_error_type, + ) + + def _run_accurate_inference_once( + *, + loaded_model: _AccurateLoadedModel, + backend: FeatureBackend, + audio: NDArray[np.float32], + sample_rate: int, + runtime_config: ProfileRuntimeConfig, + ) -> InferenceResult: + return run_accurate_inference_once( + loaded_model=loaded_model, + backend=backend, + audio=audio, + sample_rate=sample_rate, + runtime_config=runtime_config, + logger=logger, + dependency_error_factory=runtime_dependency_error_type, + transient_error_factory=transient_backend_error_type, + ) + + def _build_backend_for_profile( + *, + expected_backend_id: str, + expected_backend_model_id: str | None, + settings: AppConfig, + expected_profile: str | None = None, + ) -> FeatureBackend: + del expected_profile + return build_backend_for_profile( + expected_backend_id=expected_backend_id, + expected_backend_model_id=expected_backend_model_id, + settings=settings, + whisper_backend_factory=WhisperBackend, + emotion2vec_backend_factory=Emotion2VecBackend, + unsupported_backend_error=model_unavailable_error_type, + ) + + def _terminate_worker_process(process: BaseProcess) -> None: + _terminate_worker_process_binding( + process=process, + impl=_terminate_worker_process_impl, + terminate_grace_seconds=_TERMINATE_GRACE_SECONDS, + kill_grace_seconds=_KILL_GRACE_SECONDS, + ) + + def _raise_worker_error(error_type: str, message: str) -> None: + _raise_worker_error_binding( + error_type=error_type, + message=message, + impl=_raise_worker_error_impl, + 
known_error_factories=worker_error_factories, + unknown_error_factory=inference_execution_error_type, + worker_label="Accurate inference", + ) + + runtime_config = _runtime_config_for_profile_impl( + settings=settings, + expected_profile=expected_profile, + unsupported_profile_error=model_unavailable_error_type, + ) + resolved_expected_backend_model_id = expected_backend_model_id + if resolved_expected_backend_model_id is None and expected_backend_id == "hf_whisper": + resolved_expected_backend_model_id = resolve_accurate_model_id(settings) + use_process_isolation = ( + enforce_timeout + and loaded_model is None + and backend is None + and settings.runtime_flags.profile_pipeline + and runtime_config.process_isolation + ) + process_payload: AccurateProcessPayload | None = None + if use_process_isolation: + process_payload = AccurateProcessPayload( + request=request, + settings=_build_process_settings_snapshot_impl(settings), + expected_backend_id=expected_backend_id, + expected_profile=expected_profile, + expected_backend_model_id=resolved_expected_backend_model_id, + ) + cpu_settings = _build_cpu_settings_snapshot_impl(settings) + cpu_backend_builder: Callable[[], FeatureBackend] = partial( + _build_backend_for_profile, + expected_backend_id=expected_backend_id, + expected_backend_model_id=resolved_expected_backend_model_id, + settings=cpu_settings, + expected_profile=expected_profile, + ) + + retry_state, prepared_operation, setup_started_at = _prepare_retry_state_impl( + use_process_isolation=use_process_isolation, + request=request, + settings=settings, + runtime_config=runtime_config, + loaded_model=loaded_model, + backend=backend, + logger=logger, + profile=expected_profile, + setup_phase_name=PHASE_EMOTION_SETUP, + log_phase_started=log_phase_started, + log_phase_failed=log_phase_failed, + process_payload=process_payload, + prepare_in_process_operation=partial( + _prepare_in_process_accurate_operation, + expected_backend_id=expected_backend_id, + 
expected_profile=expected_profile, + expected_backend_model_id=resolved_expected_backend_model_id, + ), + ) + with _SINGLE_FLIGHT_REGISTRY.lock( + profile=expected_profile, + backend_model_id=resolved_expected_backend_model_id, + ): + return execute_accurate_inference_with_retry( + use_process_isolation=use_process_isolation, + retry_state=retry_state, + prepared_operation=prepared_operation, + setup_started_at=setup_started_at, + settings=settings, + backend=backend, + expected_backend_id=expected_backend_id, + expected_profile=expected_profile, + allow_retries=allow_retries, + enforce_timeout=enforce_timeout, + cpu_backend_builder=cpu_backend_builder, + logger=logger, + run_accurate_retryable_operation=lambda **kwargs: run_accurate_retryable_operation( + logger=logger, + run_with_process_timeout=_run_with_process_timeout, + run_accurate_inference_once=_run_accurate_inference_once, + run_with_timeout=_run_with_timeout_impl, + run_inference_operation=_run_inference_operation_impl, + timeout_error_factory=timeout_error_type, + **kwargs, + ), + retry_delay_seconds=retry_delay_seconds, + process_payload_cpu_fallback=payload_with_cpu_settings, + timeout_error_type=timeout_error_type, + runtime_dependency_error_type=runtime_dependency_error_type, + inference_execution_error_type=inference_execution_error_type, + transient_backend_error_type=transient_backend_error_type, + ) + + def run_accurate_retryable_operation( *, enforce_timeout: bool, @@ -430,6 +848,7 @@ def payload_with_cpu_settings(payload: _AccuratePayloadT) -> _AccuratePayloadT: __all__ = [ + "run_accurate_inference_from_public_boundary", "build_backend_for_profile", "execute_accurate_inference_with_retry", "payload_with_cpu_settings", diff --git a/ser/_internal/runtime/fast_public_boundary.py b/ser/_internal/runtime/fast_public_boundary.py new file mode 100644 index 0000000..e9cc2e1 --- /dev/null +++ b/ser/_internal/runtime/fast_public_boundary.py @@ -0,0 +1,437 @@ +"""Internal support owner for fast 
inference public-boundary wrappers.""" + +from __future__ import annotations + +import logging +import multiprocessing as mp +from collections.abc import Callable +from dataclasses import dataclass, replace +from multiprocessing.connection import Connection +from multiprocessing.process import BaseProcess +from typing import Literal + +from ser._internal.runtime.process_timeout import ( + run_with_process_timeout as _run_with_process_timeout_impl, +) +from ser._internal.runtime.single_flight import SingleFlightRegistry +from ser._internal.runtime.worker_bindings import ( + is_setup_complete_message as _is_setup_complete_message_binding, +) +from ser._internal.runtime.worker_bindings import ( + parse_worker_completion_message as _parse_worker_completion_message_binding, +) +from ser._internal.runtime.worker_bindings import raise_worker_error as _raise_worker_error_binding +from ser._internal.runtime.worker_bindings import ( + recv_worker_message as _recv_worker_message_binding, +) +from ser._internal.runtime.worker_bindings import run_worker_entry as _run_worker_entry_binding +from ser._internal.runtime.worker_bindings import ( + terminate_worker_process as _terminate_worker_process_binding, +) +from ser._internal.runtime.worker_lifecycle import ( + is_setup_complete_message as _is_setup_complete_message_impl, +) +from ser._internal.runtime.worker_lifecycle import ( + parse_worker_completion_message as _parse_worker_completion_message_impl, +) +from ser._internal.runtime.worker_lifecycle import raise_worker_error as _raise_worker_error_impl +from ser._internal.runtime.worker_lifecycle import recv_worker_message as _recv_worker_message_impl +from ser._internal.runtime.worker_lifecycle import ( + run_process_setup_compute_handshake as _run_process_setup_compute_handshake_impl, +) +from ser._internal.runtime.worker_lifecycle import run_with_timeout as _run_with_timeout_impl +from ser._internal.runtime.worker_lifecycle import ( + terminate_worker_process as 
_terminate_worker_process_impl, +) +from ser.config import AppConfig +from ser.models.emotion_model import LoadedModel, load_model, predict_emotions_detailed +from ser.runtime.contracts import InferenceRequest +from ser.runtime.phase_contract import PHASE_EMOTION_INFERENCE, PHASE_EMOTION_SETUP +from ser.runtime.phase_timing import ( + log_phase_completed, + log_phase_failed, + log_phase_started, +) +from ser.runtime.policy import run_with_retry_policy +from ser.runtime.schema import InferenceResult + +type WorkerPhaseMessage = tuple[Literal["phase"], Literal["setup_complete"]] +type WorkerSuccessMessage = tuple[Literal["ok"], InferenceResult] +type WorkerErrorMessage = tuple[Literal["err"], str, str] +type WorkerMessage = WorkerPhaseMessage | WorkerSuccessMessage | WorkerErrorMessage + +_TERMINATE_GRACE_SECONDS = 0.5 +_KILL_GRACE_SECONDS = 0.5 +_SINGLE_FLIGHT_REGISTRY = SingleFlightRegistry() + + +@dataclass(frozen=True) +class FastProcessPayload: + """Serializable payload for one process-isolated fast inference attempt.""" + + request: InferenceRequest + settings: AppConfig + + +@dataclass(frozen=True) +class _PreparedFastOperation: + """Holds setup-complete data for one fast worker compute phase.""" + + loaded_model: LoadedModel + request: InferenceRequest + + +class FastModelUnavailableError(FileNotFoundError): + """Spawn-safe fast worker error marker for unavailable model artifacts.""" + + +class FastModelLoadError(RuntimeError): + """Spawn-safe fast worker error marker for model load failures.""" + + +def _load_fast_model_for_worker( + settings: AppConfig, + *, + loaded_model: LoadedModel | None, +) -> LoadedModel: + """Loads one fast model for the spawned worker process.""" + if loaded_model is None: + try: + return load_model( + settings=settings, + expected_backend_id="handcrafted", + expected_profile="fast", + ) + except FileNotFoundError as err: + raise FastModelUnavailableError(str(err)) from err + except ValueError as err: + raise FastModelLoadError( + 
"Failed to load fast-profile model artifact from configured paths." + ) from err + _ensure_fast_compatible_model(loaded_model, FastModelUnavailableError) + return loaded_model + + +def _prepare_fast_process_operation(payload: FastProcessPayload) -> _PreparedFastOperation: + """Builds one fast worker operation using module-level collaborators.""" + active_loaded_model = _load_fast_model_for_worker(payload.settings, loaded_model=None) + return _PreparedFastOperation(loaded_model=active_loaded_model, request=payload.request) + + +def _run_fast_process_operation(prepared: _PreparedFastOperation) -> InferenceResult: + """Runs one fast compute phase inside the spawned worker process.""" + return predict_emotions_detailed( + prepared.request.file_path, + loaded_model=prepared.loaded_model, + ) + + +def _fast_worker_entry(payload: FastProcessPayload, connection: Connection) -> None: + """Executes one spawned fast worker using module-level collaborators only.""" + _run_worker_entry_binding( + payload=payload, + connection=connection, + prepare_process_operation=_prepare_fast_process_operation, + run_process_operation=_run_fast_process_operation, + ) + + +def run_fast_inference_from_public_boundary( + request: InferenceRequest, + settings: AppConfig, + *, + loaded_model: LoadedModel | None = None, + enforce_timeout: bool = True, + allow_retries: bool = True, + logger: logging.Logger, + model_unavailable_error_type: type[Exception], + model_load_error_type: type[Exception], + timeout_error_type: type[Exception], + execution_error_type: type[Exception], + transient_error_type: type[Exception], +) -> InferenceResult: + """Runs fast inference through the internal public-boundary owner.""" + + worker_error_factories: dict[str, Callable[[str], Exception]] = { + "ValueError": ValueError, + model_unavailable_error_type.__name__: model_unavailable_error_type, + model_load_error_type.__name__: model_load_error_type, + timeout_error_type.__name__: timeout_error_type, + 
transient_error_type.__name__: transient_error_type, + "RuntimeError": RuntimeError, + } + + def _retry_delay_seconds(base_delay: float, attempt: int) -> float: + if base_delay <= 0.0: + return 0.0 + return base_delay * float(attempt) + + def _load_fast_model( + settings: AppConfig, + *, + loaded_model: LoadedModel | None, + ) -> LoadedModel: + if loaded_model is None: + try: + return load_model( + settings=settings, + expected_backend_id="handcrafted", + expected_profile="fast", + ) + except FileNotFoundError as err: + raise model_unavailable_error_type(str(err)) from err + except ValueError as err: + raise model_load_error_type( + "Failed to load fast-profile model artifact from configured paths." + ) from err + _ensure_fast_compatible_model(loaded_model, model_unavailable_error_type) + return loaded_model + + def _run_fast_inference_once( + *, + request: InferenceRequest, + loaded_model: LoadedModel | None, + settings: AppConfig, + ) -> InferenceResult: + active_loaded_model = _load_fast_model(settings, loaded_model=loaded_model) + return predict_emotions_detailed( + request.file_path, + loaded_model=active_loaded_model, + ) + + def _recv_worker_message( + connection: Connection, + *, + stage: str, + ) -> WorkerMessage: + return _recv_worker_message_binding( + connection=connection, + stage=stage, + impl=_recv_worker_message_impl, + worker_label="Fast inference", + error_factory=execution_error_type, + ) + + def _is_setup_complete_message(message: WorkerMessage) -> bool: + return _is_setup_complete_message_binding( + message=message, + impl=_is_setup_complete_message_impl, + worker_label="Fast inference", + error_factory=execution_error_type, + ) + + def _raise_worker_error(error_type: str, message: str) -> None: + _raise_worker_error_binding( + error_type=error_type, + message=message, + impl=_raise_worker_error_impl, + known_error_factories=worker_error_factories, + unknown_error_factory=execution_error_type, + worker_label="Fast inference", + ) + + def 
_parse_worker_completion_message(worker_message: WorkerMessage) -> InferenceResult: + return _parse_worker_completion_message_binding( + worker_message=worker_message, + impl=_parse_worker_completion_message_impl, + worker_label="Fast inference", + error_factory=execution_error_type, + raise_worker_error=_raise_worker_error, + result_type=InferenceResult, + ) + + def _terminate_worker_process(process: BaseProcess) -> None: + _terminate_worker_process_binding( + process=process, + impl=_terminate_worker_process_impl, + terminate_grace_seconds=_TERMINATE_GRACE_SECONDS, + kill_grace_seconds=_KILL_GRACE_SECONDS, + ) + + def _run_with_process_timeout( + payload: FastProcessPayload, + *, + timeout_seconds: float, + ) -> InferenceResult: + return _run_with_process_timeout_impl( + payload=payload, + resolve_profile=lambda _payload: "fast", + timeout_seconds=timeout_seconds, + get_context=mp.get_context, + logger=logger, + setup_phase_name=PHASE_EMOTION_SETUP, + inference_phase_name=PHASE_EMOTION_INFERENCE, + log_phase_started=log_phase_started, + log_phase_completed=log_phase_completed, + log_phase_failed=log_phase_failed, + run_process_setup_compute_handshake=_run_process_setup_compute_handshake_impl, + worker_target=_fast_worker_entry, + recv_worker_message=_recv_worker_message, + is_setup_complete_message=_is_setup_complete_message, + terminate_worker_process=_terminate_worker_process, + timeout_error_factory=timeout_error_type, + execution_error_factory=execution_error_type, + worker_label="Fast inference", + process_join_grace_seconds=_TERMINATE_GRACE_SECONDS, + parse_worker_completion_message=_parse_worker_completion_message, + ) + + runtime_config = settings.fast_runtime + use_process_isolation = ( + enforce_timeout + and loaded_model is None + and settings.runtime_flags.profile_pipeline + and runtime_config.process_isolation + ) + + process_payload: FastProcessPayload | None = None + active_loaded_model: LoadedModel | None = None + setup_started_at: float | None = 
None + if use_process_isolation: + process_payload = FastProcessPayload( + request=request, + settings=replace(settings, emotions=dict(settings.emotions)), + ) + else: + setup_started_at = log_phase_started( + logger, + phase_name=PHASE_EMOTION_SETUP, + profile="fast", + ) + try: + active_loaded_model = _load_fast_model(settings, loaded_model=loaded_model) + except Exception: + log_phase_failed( + logger, + phase_name=PHASE_EMOTION_SETUP, + started_at=setup_started_at, + profile="fast", + ) + raise + log_phase_completed( + logger, + phase_name=PHASE_EMOTION_SETUP, + started_at=setup_started_at, + profile="fast", + ) + setup_started_at = None + + with _SINGLE_FLIGHT_REGISTRY.lock(profile="fast", backend_model_id=None): + + def operation() -> InferenceResult: + if enforce_timeout: + if use_process_isolation: + if process_payload is None: + raise RuntimeError( + "Fast process payload is missing for isolated execution." + ) + return _run_with_process_timeout( + process_payload, + timeout_seconds=runtime_config.timeout_seconds, + ) + inference_started_at = log_phase_started( + logger, + phase_name=PHASE_EMOTION_INFERENCE, + profile="fast", + ) + try: + result = _run_with_timeout_impl( + operation=lambda: _run_fast_inference_once( + request=request, + loaded_model=active_loaded_model, + settings=settings, + ), + timeout_seconds=runtime_config.timeout_seconds, + timeout_error_factory=timeout_error_type, + timeout_label="Fast inference", + ) + except Exception: + log_phase_failed( + logger, + phase_name=PHASE_EMOTION_INFERENCE, + started_at=inference_started_at, + profile="fast", + ) + raise + log_phase_completed( + logger, + phase_name=PHASE_EMOTION_INFERENCE, + started_at=inference_started_at, + profile="fast", + ) + return result + inference_started_at = log_phase_started( + logger, + phase_name=PHASE_EMOTION_INFERENCE, + profile="fast", + ) + try: + result = _run_fast_inference_once( + request=request, + loaded_model=active_loaded_model, + settings=settings, + ) + 
except Exception: + log_phase_failed( + logger, + phase_name=PHASE_EMOTION_INFERENCE, + started_at=inference_started_at, + profile="fast", + ) + raise + log_phase_completed( + logger, + phase_name=PHASE_EMOTION_INFERENCE, + started_at=inference_started_at, + profile="fast", + ) + return result + + try: + return run_with_retry_policy( + operation=operation, + runtime_config=runtime_config, + allow_retries=allow_retries, + profile_label="Fast", + timeout_error_type=timeout_error_type, + transient_error_type=transient_error_type, + transient_exhausted_error=lambda _err: execution_error_type( + "Fast inference exhausted retry budget after backend failures." + ), + retry_delay_seconds=_retry_delay_seconds, + logger=logger, + ) + except ValueError: + raise + except execution_error_type: + raise + except RuntimeError as err: + raise execution_error_type( + "Fast inference failed with a non-retryable runtime error." + ) from err + + +def _ensure_fast_compatible_model( + loaded_model: LoadedModel, + unavailable_error_factory: Callable[[str], Exception] | type[Exception], +) -> None: + """Validates that loaded artifact metadata is compatible with fast runtime.""" + metadata = loaded_model.artifact_metadata + if not isinstance(metadata, dict): + raise unavailable_error_factory( + "Fast profile requires a v2 model artifact metadata envelope. " + "Train a fast-profile model before inference." + ) + if metadata.get("backend_id") != "handcrafted": + raise unavailable_error_factory( + "No fast-profile model artifact is available. " + f"Found backend_id={metadata.get('backend_id')!r}; expected 'handcrafted'." + ) + if metadata.get("profile") != "fast": + raise unavailable_error_factory( + "No fast-profile model artifact is available. " + f"Found profile={metadata.get('profile')!r}; expected 'fast'." 
+ ) + + +__all__ = ["run_fast_inference_from_public_boundary"] diff --git a/ser/_internal/runtime/medium_public_boundary.py b/ser/_internal/runtime/medium_public_boundary.py index e4db36a..0e21e82 100644 --- a/ser/_internal/runtime/medium_public_boundary.py +++ b/ser/_internal/runtime/medium_public_boundary.py @@ -3,15 +3,54 @@ from __future__ import annotations import logging +import multiprocessing as mp from collections.abc import Callable -from dataclasses import replace +from dataclasses import dataclass, replace from functools import partial -from typing import Any, Protocol, TypeVar, cast +from multiprocessing.connection import Connection +from multiprocessing.process import BaseProcess +from typing import Any, Literal, Protocol, TypeVar, cast import numpy as np from numpy.typing import NDArray +from ser._internal.runtime.process_timeout import ( + run_with_process_timeout as _run_with_process_timeout_impl, +) +from ser._internal.runtime.single_flight import SingleFlightRegistry +from ser._internal.runtime.worker_bindings import ( + is_setup_complete_message as _is_setup_complete_message_binding, +) +from ser._internal.runtime.worker_bindings import ( + parse_worker_completion_message as _parse_worker_completion_message_binding, +) +from ser._internal.runtime.worker_bindings import raise_worker_error as _raise_worker_error_binding +from ser._internal.runtime.worker_bindings import ( + recv_worker_message as _recv_worker_message_binding, +) +from ser._internal.runtime.worker_bindings import run_worker_entry as _run_worker_entry_binding +from ser._internal.runtime.worker_bindings import ( + terminate_worker_process as _terminate_worker_process_binding, +) +from ser._internal.runtime.worker_lifecycle import ( + is_setup_complete_message as _is_setup_complete_message_impl, +) +from ser._internal.runtime.worker_lifecycle import ( + parse_worker_completion_message as _parse_worker_completion_message_impl, +) +from ser._internal.runtime.worker_lifecycle import 
raise_worker_error as _raise_worker_error_impl +from ser._internal.runtime.worker_lifecycle import recv_worker_message as _recv_worker_message_impl +from ser._internal.runtime.worker_lifecycle import ( + run_process_setup_compute_handshake as _run_process_setup_compute_handshake_impl, +) +from ser._internal.runtime.worker_lifecycle import run_with_timeout as _run_with_timeout_impl +from ser._internal.runtime.worker_lifecycle import ( + terminate_worker_process as _terminate_worker_process_impl, +) from ser.config import AppConfig, MediumRuntimeConfig +from ser.models.emotion_model import LoadedModel as EmotionLoadedModel +from ser.models.emotion_model import load_model +from ser.models.profile_runtime import resolve_medium_model_id from ser.repr import XLSRBackend from ser.repr.runtime_policy import FeatureRuntimePolicy from ser.runtime import medium_execution as medium_execution_helpers @@ -45,6 +84,9 @@ from ser.runtime.medium_process_operation import ( run_process_operation as _run_process_operation_impl, ) +from ser.runtime.medium_retry_operation import ( + run_medium_inference_with_retry_policy as _run_medium_retry_policy_impl, +) from ser.runtime.medium_runtime_support import LoadedModelLike as _MediumRuntimeLoadedModelLike from ser.runtime.medium_runtime_support import ( build_cpu_settings_snapshot as _build_cpu_settings_snapshot_impl, @@ -52,6 +94,9 @@ from ser.runtime.medium_runtime_support import ( build_medium_backend_for_settings as _build_medium_backend_for_settings_impl, ) +from ser.runtime.medium_runtime_support import ( + build_runtime_settings_snapshot as _build_runtime_settings_snapshot_impl, +) from ser.runtime.medium_runtime_support import ( encode_medium_sequence as _encode_medium_sequence_impl, ) @@ -69,6 +114,7 @@ ) from ser.runtime.policy import run_with_retry_policy from ser.runtime.schema import InferenceResult +from ser.utils.audio_utils import read_audio_file class _MediumLoadedModel(_MediumExecutionLoadedModelLike, 
_MediumRuntimeLoadedModelLike, Protocol): @@ -78,6 +124,109 @@ class _MediumLoadedModel(_MediumExecutionLoadedModelLike, _MediumRuntimeLoadedMo _MediumLoadedModelT = TypeVar("_MediumLoadedModelT", bound=_MediumLoadedModel) _MediumPayloadT = TypeVar("_MediumPayloadT", bound=_MediumPayloadLike) +type WorkerPhaseMessage = tuple[Literal["phase"], Literal["setup_complete"]] +type WorkerSuccessMessage = tuple[Literal["ok"], InferenceResult] +type WorkerErrorMessage = tuple[Literal["err"], str, str] +type WorkerMessage = WorkerPhaseMessage | WorkerSuccessMessage | WorkerErrorMessage + +_TERMINATE_GRACE_SECONDS = 0.5 +_KILL_GRACE_SECONDS = 0.5 +_SINGLE_FLIGHT_REGISTRY = SingleFlightRegistry() +_WORKER_LOGGER = logging.getLogger("ser.runtime.medium_inference") + + +@dataclass(frozen=True) +class MediumProcessPayload: + """Serializable payload for one process-isolated medium inference attempt.""" + + request: InferenceRequest + settings: AppConfig + expected_backend_model_id: str + + +class MediumModelUnavailableError(FileNotFoundError): + """Spawn-safe medium worker error marker for unavailable model artifacts.""" + + +class MediumRuntimeDependencyError(RuntimeError): + """Spawn-safe medium worker error marker for missing runtime dependencies.""" + + +class MediumModelLoadError(RuntimeError): + """Spawn-safe medium worker error marker for model load failures.""" + + +class MediumTransientBackendError(RuntimeError): + """Spawn-safe medium worker error marker for transient backend failures.""" + + +def _prepare_medium_process_operation( + payload: MediumProcessPayload, +) -> medium_worker_operation_helpers.PreparedMediumOperation[EmotionLoadedModel, XLSRBackend]: + """Builds one medium worker operation using only module-level collaborators.""" + return prepare_process_operation( + payload, + load_model_fn=load_model, + read_audio_file_fn=read_audio_file, + backend_factory=XLSRBackend, + resolve_runtime_policy=lambda settings: resolve_medium_feature_runtime_policy( + 
settings=settings + ), + logger=_WORKER_LOGGER, + model_unavailable_error_factory=MediumModelUnavailableError, + model_load_error_factory=MediumModelLoadError, + prepare_medium_backend_runtime=lambda active_backend: prepare_medium_backend_runtime( + backend=active_backend, + is_dependency_error=is_dependency_error, + dependency_error_factory=MediumRuntimeDependencyError, + transient_error_factory=MediumTransientBackendError, + ), + ) + + +def _run_medium_process_inference_once( + *, + loaded_model: _MediumLoadedModel, + backend: XLSRBackend, + audio: NDArray[np.float32], + sample_rate: int, + runtime_config: MediumRuntimeConfig, +) -> InferenceResult: + """Runs one medium worker-process inference attempt.""" + return run_medium_inference_once( + loaded_model=loaded_model, + backend=backend, + audio=audio, + sample_rate=sample_rate, + runtime_config=runtime_config, + logger=_WORKER_LOGGER, + is_dependency_error=is_dependency_error, + dependency_error_factory=MediumRuntimeDependencyError, + transient_error_factory=MediumTransientBackendError, + ) + + +def _run_medium_process_operation( + prepared: medium_worker_operation_helpers.PreparedMediumOperation[ + EmotionLoadedModel, XLSRBackend + ], +) -> InferenceResult: + """Runs one medium compute phase inside the spawned worker process.""" + return run_process_operation( + prepared, + run_medium_inference_once=_run_medium_process_inference_once, + ) + + +def _medium_worker_entry(payload: MediumProcessPayload, connection: Connection) -> None: + """Executes one spawned medium worker using module-level collaborators only.""" + _run_worker_entry_binding( + payload=payload, + connection=connection, + prepare_process_operation=_prepare_medium_process_operation, + run_process_operation=_run_medium_process_operation, + ) + def run_medium_inference_once( *, @@ -116,6 +265,252 @@ def run_medium_inference_once( ) +def run_medium_inference_from_public_boundary( + request: InferenceRequest, + settings: AppConfig, + *, + loaded_model: 
_MediumLoadedModel | None = None, + backend: XLSRBackend | None = None, + enforce_timeout: bool = True, + allow_retries: bool = True, + logger: logging.Logger, + model_unavailable_error_type: type[Exception], + runtime_dependency_error_type: type[Exception], + model_load_error_type: type[Exception], + timeout_error_type: type[Exception], + execution_error_type: type[Exception], + transient_error_type: type[Exception], +) -> InferenceResult: + """Runs medium inference through the internal public-boundary owner.""" + + worker_error_factories: dict[str, Callable[[str], Exception]] = { + "ValueError": ValueError, + runtime_dependency_error_type.__name__: runtime_dependency_error_type, + transient_error_type.__name__: transient_error_type, + model_unavailable_error_type.__name__: model_unavailable_error_type, + model_load_error_type.__name__: model_load_error_type, + timeout_error_type.__name__: timeout_error_type, + "RuntimeError": RuntimeError, + } + + def _run_with_timeout( + operation: Callable[[], InferenceResult], + timeout_seconds: float, + ) -> InferenceResult: + return _run_with_timeout_impl( + operation=operation, + timeout_seconds=timeout_seconds, + timeout_error_factory=timeout_error_type, + timeout_label="Medium inference", + ) + + def _run_medium_inference_once( + *, + loaded_model: _MediumLoadedModel, + backend: XLSRBackend, + audio: NDArray[np.float32], + sample_rate: int, + runtime_config: MediumRuntimeConfig, + ) -> InferenceResult: + return run_medium_inference_once( + loaded_model=loaded_model, + backend=backend, + audio=audio, + sample_rate=sample_rate, + runtime_config=runtime_config, + logger=logger, + is_dependency_error=is_dependency_error, + dependency_error_factory=runtime_dependency_error_type, + transient_error_factory=transient_error_type, + ) + + def _run_with_process_timeout( + payload: MediumProcessPayload, + timeout_seconds: float, + ) -> InferenceResult: + return _run_with_process_timeout_impl( + payload=payload, + 
resolve_profile=lambda _payload: "medium", + timeout_seconds=timeout_seconds, + get_context=mp.get_context, + logger=logger, + setup_phase_name=PHASE_EMOTION_SETUP, + inference_phase_name=PHASE_EMOTION_INFERENCE, + log_phase_started=log_phase_started, + log_phase_completed=log_phase_completed, + log_phase_failed=log_phase_failed, + run_process_setup_compute_handshake=_run_process_setup_compute_handshake_impl, + worker_target=_medium_worker_entry, + recv_worker_message=_recv_worker_message, + is_setup_complete_message=_is_setup_complete_message, + terminate_worker_process=_terminate_worker_process, + timeout_error_factory=timeout_error_type, + execution_error_factory=execution_error_type, + worker_label="Medium inference", + process_join_grace_seconds=_TERMINATE_GRACE_SECONDS, + parse_worker_completion_message=_parse_worker_completion_message, + ) + + def _recv_worker_message( + connection: Connection, + *, + stage: str, + ) -> WorkerMessage: + return _recv_worker_message_binding( + connection=connection, + stage=stage, + impl=_recv_worker_message_impl, + worker_label="Medium inference", + error_factory=execution_error_type, + ) + + def _is_setup_complete_message(message: WorkerMessage) -> bool: + return _is_setup_complete_message_binding( + message=message, + impl=_is_setup_complete_message_impl, + worker_label="Medium inference", + error_factory=execution_error_type, + ) + + def _parse_worker_completion_message(worker_message: WorkerMessage) -> InferenceResult: + return _parse_worker_completion_message_binding( + worker_message=worker_message, + impl=_parse_worker_completion_message_impl, + worker_label="Medium inference", + error_factory=execution_error_type, + raise_worker_error=_raise_worker_error, + result_type=InferenceResult, + ) + + def _prepare_in_process_operation( + *, + request: InferenceRequest, + settings: AppConfig, + loaded_model: _MediumLoadedModel | None, + backend: XLSRBackend | None, + expected_backend_model_id: str, + runtime_device: str, + 
runtime_dtype: str, + ) -> medium_worker_operation_helpers.PreparedMediumOperation[_MediumLoadedModel, XLSRBackend]: + return prepare_in_process_operation( + request=request, + settings=settings, + loaded_model=loaded_model, + backend=backend, + expected_backend_model_id=expected_backend_model_id, + runtime_device=runtime_device, + runtime_dtype=runtime_dtype, + load_model_fn=load_model, + read_audio_file_fn=read_audio_file, + backend_factory=XLSRBackend, + logger=logger, + model_unavailable_error_factory=model_unavailable_error_type, + model_load_error_factory=model_load_error_type, + ) + + def _prepare_execution_context( + *, + request: InferenceRequest, + settings: AppConfig, + loaded_model: _MediumLoadedModel | None, + backend: XLSRBackend | None, + enforce_timeout: bool, + ) -> MediumExecutionContext[MediumProcessPayload, _MediumLoadedModel, XLSRBackend]: + return prepare_execution_context( + request=request, + settings=settings, + loaded_model=loaded_model, + backend=backend, + enforce_timeout=enforce_timeout, + resolve_medium_model_id=resolve_medium_model_id, + resolve_runtime_policy=lambda active_settings: resolve_medium_feature_runtime_policy( + settings=active_settings + ), + prepare_retry_state=medium_worker_operation_helpers.prepare_retry_state, + prepare_in_process_operation=_prepare_in_process_operation, + build_process_payload=lambda backend_model_id, policy_device, policy_dtype: ( + MediumProcessPayload( + request=request, + settings=_build_runtime_settings_snapshot_impl( + settings, + runtime_device=policy_device, + runtime_dtype=policy_dtype, + ), + expected_backend_model_id=backend_model_id, + ) + ), + logger=logger, + ) + + def _terminate_worker_process(process: BaseProcess) -> None: + _terminate_worker_process_binding( + process=process, + impl=_terminate_worker_process_impl, + terminate_grace_seconds=_TERMINATE_GRACE_SECONDS, + kill_grace_seconds=_KILL_GRACE_SECONDS, + ) + + def _raise_worker_error(error_type: str, message: str) -> None: + 
_raise_worker_error_binding( + error_type=error_type, + message=message, + impl=_raise_worker_error_impl, + known_error_factories=worker_error_factories, + unknown_error_factory=execution_error_type, + worker_label="Medium inference", + ) + + execution_context = _prepare_execution_context( + request=request, + settings=settings, + loaded_model=loaded_model, + backend=backend, + enforce_timeout=enforce_timeout, + ) + expected_backend_model_id = execution_context.expected_backend_model_id + + with _SINGLE_FLIGHT_REGISTRY.lock( + profile="medium", + backend_model_id=expected_backend_model_id, + ): + return execute_medium_inference_with_retry( + execution_context=execution_context, + settings=settings, + injected_backend=backend, + enforce_timeout=enforce_timeout, + allow_retries=allow_retries, + expected_backend_model_id=expected_backend_model_id, + logger=logger, + run_with_process_timeout=_run_with_process_timeout, + run_process_operation=lambda prepared: run_process_operation( + prepared, + run_medium_inference_once=_run_medium_inference_once, + ), + run_with_timeout=_run_with_timeout, + prepare_medium_backend_runtime=lambda active_backend: prepare_medium_backend_runtime( + backend=active_backend, + is_dependency_error=is_dependency_error, + dependency_error_factory=runtime_dependency_error_type, + transient_error_factory=transient_error_type, + ), + cpu_backend_builder=lambda: _build_medium_backend_for_settings_impl( + settings=_build_cpu_settings_snapshot_impl(settings), + expected_backend_model_id=expected_backend_model_id, + runtime_device="cpu", + runtime_dtype="float32", + backend_factory=XLSRBackend, + ), + timeout_error_type=timeout_error_type, + transient_error_type=transient_error_type, + runtime_dependency_error_type=runtime_dependency_error_type, + execution_error_type=execution_error_type, + run_retry_policy_impl=_run_medium_retry_policy_impl, + retry_delay_seconds=retry_delay_seconds, + 
should_retry_on_cpu_after_transient_failure=should_retry_on_cpu_after_transient_failure, + summarize_transient_failure=summarize_transient_failure, + ) + + def prepare_in_process_operation( *, request: InferenceRequest, @@ -445,6 +840,7 @@ def prepare_medium_backend_runtime( __all__ = [ + "run_medium_inference_from_public_boundary", "execute_medium_inference_with_retry", "is_dependency_error", "prepare_execution_context", diff --git a/ser/_internal/transcription/compatibility.py b/ser/_internal/transcription/compatibility.py index 47671e4..1e57c37 100644 --- a/ser/_internal/transcription/compatibility.py +++ b/ser/_internal/transcription/compatibility.py @@ -16,6 +16,7 @@ type _CompatibilityIssueKind = Literal["noise", "operational"] type _EmittedIssueKeySet = set[tuple[str, str, str]] +_EMITTED_COMPATIBILITY_ISSUE_KEYS: _EmittedIssueKeySet = set() class _CompatibilityAdapter(Protocol): @@ -45,6 +46,13 @@ def backend_id(self) -> TranscriptionBackendId: type _ErrorFactory = Callable[[str], Exception] +def _resolve_emitted_issue_keys( + emitted_issue_keys: _EmittedIssueKeySet | None, +) -> _EmittedIssueKeySet: + """Returns the shared compatibility registry when no explicit registry is provided.""" + return _EMITTED_COMPATIBILITY_ISSUE_KEYS if emitted_issue_keys is None else emitted_issue_keys + + def _summarize_operational_issue_message(issue_message: str, *, max_chars: int) -> str: """Returns one concise operational issue message for CLI log hygiene.""" normalized = " ".join(issue_message.split()) @@ -92,13 +100,14 @@ def mark_compatibility_issues_as_emitted( backend_id: TranscriptionBackendId, issue_kind: _CompatibilityIssueKind, issue_codes: tuple[str, ...], - emitted_issue_keys: _EmittedIssueKeySet, + emitted_issue_keys: _EmittedIssueKeySet | None = None, ) -> None: """Marks compatibility issues as already emitted to prevent duplicate logs.""" + active_emitted_issue_keys = _resolve_emitted_issue_keys(emitted_issue_keys) for issue_code in issue_codes: if not 
issue_code: continue - emitted_issue_keys.add((backend_id, issue_kind, issue_code)) + active_emitted_issue_keys.add((backend_id, issue_kind, issue_code)) def check_adapter_compatibility( @@ -109,10 +118,11 @@ def check_adapter_compatibility( runtime_request_resolver: Callable[[_TProfile, AppConfig], BackendRuntimeRequest], adapter_resolver: _AdapterResolver, error_factory: _ErrorFactory, - emitted_issue_keys: _EmittedIssueKeySet, + emitted_issue_keys: _EmittedIssueKeySet | None = None, logger: logging.Logger, ) -> CompatibilityReport: """Validates backend compatibility and logs non-blocking issues once.""" + active_emitted_issue_keys = _resolve_emitted_issue_keys(emitted_issue_keys) backend_id = active_profile.backend_id adapter = adapter_resolver(backend_id) resolved_runtime_request = ( @@ -131,7 +141,7 @@ def check_adapter_compatibility( issue_code=issue.code, issue_message=issue.message, issue_impact=issue.impact, - emitted_issue_keys=emitted_issue_keys, + emitted_issue_keys=active_emitted_issue_keys, logger=logger, ) for issue in report.operational_issues: @@ -141,7 +151,7 @@ def check_adapter_compatibility( issue_code=issue.code, issue_message=issue.message, issue_impact=issue.impact, - emitted_issue_keys=emitted_issue_keys, + emitted_issue_keys=active_emitted_issue_keys, logger=logger, ) if report.has_blocking_issues: diff --git a/ser/_internal/transcription/public_boundary_runtime.py b/ser/_internal/transcription/public_boundary_runtime.py index e86bf6f..7a69056 100644 --- a/ser/_internal/transcription/public_boundary_runtime.py +++ b/ser/_internal/transcription/public_boundary_runtime.py @@ -50,7 +50,7 @@ def check_adapter_compatibility_from_public_boundary( check_adapter_compatibility_impl: Callable[..., CompatibilityReport], runtime_request_resolver: _RuntimeRequestResolver, adapter_resolver: _AdapterResolver, - emitted_issue_keys: set[tuple[str, str, str]], + emitted_issue_keys: set[tuple[str, str, str]] | None, logger: logging.Logger, error_factory: 
_ErrorFactory, ) -> CompatibilityReport: diff --git a/ser/_internal/transcription/public_boundary_support.py b/ser/_internal/transcription/public_boundary_support.py index 45a5d11..75ba522 100644 --- a/ser/_internal/transcription/public_boundary_support.py +++ b/ser/_internal/transcription/public_boundary_support.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +import multiprocessing as mp from collections.abc import Callable from typing import Literal, Never, Protocol, cast @@ -34,6 +35,12 @@ transcription_setup_required as _transcription_setup_required_impl, ) from ser._internal.transcription.process_isolation import WorkerMessage as _WorkerMessage +from ser._internal.transcription.process_isolation import ( + raise_worker_error as _raise_worker_error_impl, +) +from ser._internal.transcription.process_isolation import ( + recv_worker_message as _recv_worker_message_impl, +) from ser._internal.transcription.process_isolation import ( run_faster_whisper_process_isolated as _run_faster_whisper_process_isolated_impl, ) @@ -43,12 +50,36 @@ from ser._internal.transcription.process_isolation import ( should_use_process_isolated_path as _should_use_process_isolated_path_impl, ) +from ser._internal.transcription.process_isolation import ( + terminate_worker_process as _terminate_worker_process_impl, +) +from ser._internal.transcription.process_isolation import ( + transcription_worker_entry as _transcription_worker_entry_impl, +) from ser._internal.transcription.process_worker import ( build_transcription_process_payload as _build_transcription_process_payload, ) +from ser._internal.transcription.public_boundary_process import ( + raise_worker_error_from_public_boundary as _raise_worker_error_boundary_impl, +) +from ser._internal.transcription.public_boundary_process import ( + recv_worker_message_from_public_boundary as _recv_worker_message_boundary_impl, +) +from ser._internal.transcription.public_boundary_process import ( + 
resolve_transcription_adapter_from_public_boundary as _resolve_transcription_adapter_boundary_impl, +) from ser._internal.transcription.public_boundary_process import ( run_faster_whisper_process_isolated_from_public_boundary as _run_faster_whisper_process_isolated_boundary_impl, ) +from ser._internal.transcription.public_boundary_process import ( + spawn_context_for_public_boundary as _spawn_context_boundary_impl, +) +from ser._internal.transcription.public_boundary_process import ( + terminate_worker_process_from_public_boundary as _terminate_worker_process_boundary_impl, +) +from ser._internal.transcription.public_boundary_process import ( + transcription_worker_entry_from_public_boundary as _transcription_worker_entry_boundary_impl, +) from ser._internal.transcription.public_boundary_runtime import ( check_adapter_compatibility_from_public_boundary as _check_adapter_compatibility_boundary_impl, ) @@ -128,6 +159,9 @@ def use_vad(self) -> bool: type _WorkerTerminator = Callable[[object], None] type _WorkerEntry = Callable[[object, object], None] +_TERMINATE_GRACE_SECONDS = 5.0 +_KILL_GRACE_SECONDS = 2.0 + def _runtime_request_from_profile( active_profile: _BackendProfile, @@ -168,7 +202,7 @@ def check_adapter_compatibility( active_profile: _BackendProfile, settings: AppConfig, runtime_request: BackendRuntimeRequest | None = None, - emitted_issue_keys: set[tuple[str, str, str]], + emitted_issue_keys: set[tuple[str, str, str]] | None = None, logger: logging.Logger, error_factory: _ErrorFactory, ) -> CompatibilityReport: @@ -191,7 +225,7 @@ def mark_compatibility_issues_as_emitted( backend_id: TranscriptionBackendId, issue_kind: _CompatibilityIssueKind, issue_codes: tuple[str, ...], - emitted_issue_keys: set[tuple[str, str, str]], + emitted_issue_keys: set[tuple[str, str, str]] | None = None, ) -> None: """Marks compatibility issues as emitted so they are logged only once.""" _mark_compatibility_issues_as_emitted_impl( @@ -206,7 +240,7 @@ def 
transcription_setup_required( *, active_profile: _BackendProfile, settings: AppConfig, - emitted_issue_keys: set[tuple[str, str, str]], + emitted_issue_keys: set[tuple[str, str, str]] | None = None, logger: logging.Logger, error_factory: _ErrorFactory, ) -> bool: @@ -230,7 +264,7 @@ def prepare_transcription_assets( *, active_profile: _BackendProfile, settings: AppConfig, - emitted_issue_keys: set[tuple[str, str, str]], + emitted_issue_keys: set[tuple[str, str, str]] | None = None, logger: logging.Logger, error_factory: _ErrorFactory, ) -> None: @@ -256,7 +290,7 @@ def load_whisper_model_for_settings( settings: AppConfig, profile_factory: _ProfileFactory, logger: logging.Logger, - emitted_issue_keys: set[tuple[str, str, str]], + emitted_issue_keys: set[tuple[str, str, str]] | None = None, error_factory: _ErrorFactory, ) -> object: """Loads one transcription model using an explicit settings snapshot.""" @@ -301,6 +335,34 @@ def _runtime_request_for_isolated_faster_whisper( ) +def _spawn_context_for_public_boundary() -> object: + """Returns the multiprocessing spawn context used by the public boundary.""" + return _spawn_context_boundary_impl(get_context=mp.get_context) + + +def _resolve_transcription_adapter_for_public_boundary( + backend_id: TranscriptionBackendId, +) -> object: + """Resolves one backend adapter for process-isolated public-boundary workers.""" + return _resolve_transcription_adapter_boundary_impl( + backend_id, + adapter_resolver=resolve_transcription_backend_adapter, + ) + + +def _transcription_worker_entry_for_public_boundary( + payload: object, + connection: object, +) -> None: + """Runs one process-isolated transcription worker with module-level collaborators.""" + _transcription_worker_entry_boundary_impl( + payload, + connection, + transcription_worker_entry_impl=_transcription_worker_entry_impl, + adapter_resolver=_resolve_transcription_adapter_for_public_boundary, + ) + + def run_faster_whisper_process_isolated( *, file_path: str, @@ 
-309,15 +371,32 @@ def run_faster_whisper_process_isolated( settings: AppConfig, logger: logging.Logger, error_factory: _ErrorFactory, - terminate_grace_seconds: float, - transcript_word_factory: _TranscriptWordFactory, - spawn_context_resolver: _ResolveSpawnContext, - worker_entry: _WorkerEntry, - recv_worker_message_fn: _WorkerMessageReceiver, - raise_worker_error_fn: _WorkerErrorRaiser, - terminate_worker_process_fn: _WorkerTerminator, ) -> list[TranscriptWord]: """Runs faster-whisper transcription inside a spawned worker process.""" + + def _recv_worker_message(connection: object, *, stage: str) -> _WorkerMessage: + return _recv_worker_message_boundary_impl( + connection, + recv_worker_message_impl=_recv_worker_message_impl, + stage=stage, + error_factory=error_factory, + ) + + def _raise_worker_error(message: object) -> Never: + _raise_worker_error_boundary_impl( + cast(_WorkerMessage, message), + raise_worker_error_impl=_raise_worker_error_impl, + error_factory=error_factory, + ) + + def _terminate_worker_process(process: object) -> None: + _terminate_worker_process_boundary_impl( + process, + terminate_worker_process_impl=_terminate_worker_process_impl, + terminate_grace_seconds=_TERMINATE_GRACE_SECONDS, + kill_grace_seconds=_KILL_GRACE_SECONDS, + ) + return _run_faster_whisper_process_isolated_boundary_impl( file_path=file_path, language=language, @@ -333,14 +412,14 @@ def run_faster_whisper_process_isolated( ) ), payload_factory=_build_transcription_process_payload, - spawn_context_resolver=spawn_context_resolver, - worker_entry=worker_entry, - recv_worker_message_fn=recv_worker_message_fn, - raise_worker_error_fn=raise_worker_error_fn, - terminate_worker_process_fn=terminate_worker_process_fn, + spawn_context_resolver=_spawn_context_for_public_boundary, + worker_entry=_transcription_worker_entry_for_public_boundary, + recv_worker_message_fn=_recv_worker_message, + raise_worker_error_fn=_raise_worker_error, + 
terminate_worker_process_fn=_terminate_worker_process, logger=logger, error_factory=error_factory, - terminate_grace_seconds=terminate_grace_seconds, + terminate_grace_seconds=_TERMINATE_GRACE_SECONDS, ) @@ -352,7 +431,7 @@ def extract_transcript_in_process( settings: AppConfig, profile_factory: _ProfileFactory, logger: logging.Logger, - emitted_issue_keys: set[tuple[str, str, str]], + emitted_issue_keys: set[tuple[str, str, str]] | None = None, error_factory: _ErrorFactory, release_memory_fn: Callable[..., None], phase_started_fn: Callable[..., float], @@ -419,7 +498,7 @@ def transcribe_with_profile( settings: AppConfig, profile_factory: _ProfileFactory, logger: logging.Logger, - emitted_issue_keys: set[tuple[str, str, str]], + emitted_issue_keys: set[tuple[str, str, str]] | None = None, error_factory: _ErrorFactory, passthrough_error_cls: type[Exception], ) -> list[TranscriptWord]: @@ -460,16 +539,8 @@ def extract_transcript( *, settings: AppConfig, profile_factory: _ProfileFactory, - transcript_word_factory: _TranscriptWordFactory, logger: logging.Logger, - emitted_issue_keys: set[tuple[str, str, str]], error_factory: _ErrorFactory, - terminate_grace_seconds: float, - spawn_context_resolver: _ResolveSpawnContext, - worker_entry: _WorkerEntry, - recv_worker_message_fn: _WorkerMessageReceiver, - raise_worker_error_fn: _WorkerErrorRaiser, - terminate_worker_process_fn: _WorkerTerminator, release_memory_fn: Callable[..., None], phase_started_fn: Callable[..., float], phase_completed_fn: Callable[..., float | None], @@ -497,19 +568,11 @@ def extract_transcript( run_process_isolated_fn=lambda **kwargs: run_faster_whisper_process_isolated( logger=logger, error_factory=error_factory, - terminate_grace_seconds=terminate_grace_seconds, - transcript_word_factory=transcript_word_factory, - spawn_context_resolver=spawn_context_resolver, - worker_entry=worker_entry, - recv_worker_message_fn=recv_worker_message_fn, - raise_worker_error_fn=raise_worker_error_fn, - 
terminate_worker_process_fn=terminate_worker_process_fn, **kwargs, ), run_in_process_fn=lambda **kwargs: extract_transcript_in_process( profile_factory=profile_factory, logger=logger, - emitted_issue_keys=emitted_issue_keys, error_factory=error_factory, release_memory_fn=release_memory_fn, phase_started_fn=phase_started_fn, diff --git a/ser/runtime/accurate_inference.py b/ser/runtime/accurate_inference.py index d496be1..49c7c7a 100644 --- a/ser/runtime/accurate_inference.py +++ b/ser/runtime/accurate_inference.py @@ -1,116 +1,17 @@ -"""Accurate-profile inference runner with bounded retries and timeout guards.""" +"""Accurate-profile public inference boundary.""" from __future__ import annotations -import multiprocessing as mp -from collections.abc import Callable -from dataclasses import dataclass -from functools import partial -from multiprocessing.connection import Connection -from multiprocessing.process import BaseProcess -from typing import Literal, cast - -import numpy as np -from numpy.typing import NDArray - from ser._internal.runtime import accurate_public_boundary as _boundary_support -from ser._internal.runtime.single_flight import SingleFlightRegistry -from ser._internal.runtime.worker_lifecycle import ( - is_setup_complete_message as _is_setup_complete_message_impl, -) -from ser._internal.runtime.worker_lifecycle import ( - parse_worker_completion_message as _parse_worker_completion_message_impl, -) -from ser._internal.runtime.worker_lifecycle import raise_worker_error as _raise_worker_error_impl -from ser._internal.runtime.worker_lifecycle import recv_worker_message as _recv_worker_message_impl -from ser._internal.runtime.worker_lifecycle import ( - run_process_setup_compute_handshake as _run_process_setup_compute_handshake_impl, -) -from ser._internal.runtime.worker_lifecycle import run_with_timeout as _run_with_timeout_impl -from ser._internal.runtime.worker_lifecycle import ( - terminate_worker_process as _terminate_worker_process_impl, -) -from 
ser.config import AppConfig, ProfileRuntimeConfig -from ser.models.emotion_model import LoadedModel, load_model -from ser.models.profile_runtime import resolve_accurate_model_id -from ser.repr import ( - Emotion2VecBackend, - FeatureBackend, - WhisperBackend, -) -from ser.runtime.accurate_backend_runtime import ( - runtime_config_for_profile as _runtime_config_for_profile_impl, -) -from ser.runtime.accurate_runtime_support import ( - build_cpu_settings_snapshot as _build_cpu_settings_snapshot_impl, -) -from ser.runtime.accurate_runtime_support import ( - build_process_settings_snapshot as _build_process_settings_snapshot_impl, -) -from ser.runtime.accurate_worker_lifecycle import ( - is_setup_complete_message as _is_setup_complete_message_orchestration, -) -from ser.runtime.accurate_worker_lifecycle import ( - parse_worker_completion_message as _parse_worker_completion_message_orchestration, -) -from ser.runtime.accurate_worker_lifecycle import ( - raise_worker_error as _raise_worker_error_orchestration, -) -from ser.runtime.accurate_worker_lifecycle import ( - recv_worker_message as _recv_worker_message_orchestration, -) -from ser.runtime.accurate_worker_lifecycle import ( - run_with_process_timeout as _run_with_process_timeout_orchestration_impl, -) -from ser.runtime.accurate_worker_lifecycle import ( - run_worker_entry as _run_worker_entry_orchestration, -) -from ser.runtime.accurate_worker_lifecycle import ( - terminate_worker_process as _terminate_worker_process_orchestration, -) -from ser.runtime.accurate_worker_operation import ( - AccurateRetryOperationState, - PreparedAccurateOperation, -) -from ser.runtime.accurate_worker_operation import prepare_retry_state as _prepare_retry_state_impl -from ser.runtime.accurate_worker_operation import ( - run_inference_operation as _run_inference_operation_impl, -) +from ser.config import AppConfig +from ser.models.emotion_model import LoadedModel +from ser.repr import FeatureBackend from ser.runtime.contracts import 
InferenceRequest -from ser.runtime.phase_contract import PHASE_EMOTION_INFERENCE, PHASE_EMOTION_SETUP -from ser.runtime.phase_timing import ( - log_phase_completed, - log_phase_failed, - log_phase_started, -) from ser.runtime.schema import InferenceResult -from ser.utils.audio_utils import read_audio_file from ser.utils.logger import get_logger logger = get_logger(__name__) -type FeatureMatrix = NDArray[np.float64] -type WorkerPhaseMessage = tuple[Literal["phase"], Literal["setup_complete"]] -type WorkerSuccessMessage = tuple[Literal["ok"], InferenceResult] -type WorkerErrorMessage = tuple[Literal["err"], str, str] -type WorkerMessage = WorkerPhaseMessage | WorkerSuccessMessage | WorkerErrorMessage -type _PreparedAccurateOperation = PreparedAccurateOperation[LoadedModel, FeatureBackend] - -_TERMINATE_GRACE_SECONDS = 0.5 -_KILL_GRACE_SECONDS = 0.5 -_SINGLE_FLIGHT_REGISTRY = SingleFlightRegistry() - - -@dataclass(frozen=True) -class AccurateProcessPayload: - """Serializable payload for one process-isolated accurate inference attempt.""" - - request: InferenceRequest - settings: AppConfig - expected_backend_id: str - expected_profile: str - expected_backend_model_id: str | None - class AccurateModelUnavailableError(FileNotFoundError): """Raised when a compatible accurate-profile model artifact is unavailable.""" @@ -136,17 +37,6 @@ class AccurateTransientBackendError(RuntimeError): """Raised for retryable accurate backend encoding failures.""" -_WORKER_ERROR_FACTORIES: dict[str, Callable[[str], Exception]] = { - "ValueError": ValueError, - "AccurateRuntimeDependencyError": AccurateRuntimeDependencyError, - "AccurateTransientBackendError": AccurateTransientBackendError, - "AccurateModelUnavailableError": AccurateModelUnavailableError, - "AccurateModelLoadError": AccurateModelLoadError, - "AccurateInferenceTimeoutError": AccurateInferenceTimeoutError, - "RuntimeError": RuntimeError, -} - - def run_accurate_inference( request: InferenceRequest, settings: AppConfig, @@ 
-159,377 +49,33 @@ def run_accurate_inference( expected_profile: str = "accurate", expected_backend_model_id: str | None = None, ) -> InferenceResult: - """Runs accurate-profile inference with bounded retries and timeout budget. - - Args: - request: Runtime inference request payload. - settings: Active application settings. - loaded_model: Optional preloaded model artifact for repeated inference calls. - backend: Optional preinitialized feature backend for repeated inference calls. - enforce_timeout: Whether to apply timeout wrapper for each inference attempt. - allow_retries: Whether to apply configured retry budget for retryable errors. - expected_backend_id: Expected backend id in model artifact metadata. - expected_profile: Expected profile identifier in model artifact metadata. - expected_backend_model_id: Expected backend model id in model artifact metadata. - - Returns: - Detailed inference result with frame and segment predictions. - - Raises: - AccurateModelUnavailableError: If no compatible accurate artifact exists. - AccurateRuntimeDependencyError: If required accurate dependencies are missing. - AccurateModelLoadError: If model loading fails for non-compatibility reasons. - AccurateInferenceTimeoutError: If all attempts timed out. - AccurateInferenceExecutionError: If transient backend failures exhaust retries. - ValueError: If feature dimensions are incompatible with loaded artifact. 
- """ - runtime_config = _runtime_config_for_profile_impl( - settings=settings, - expected_profile=expected_profile, - unsupported_profile_error=AccurateModelUnavailableError, - ) - resolved_expected_backend_model_id = expected_backend_model_id - if resolved_expected_backend_model_id is None and expected_backend_id == "hf_whisper": - resolved_expected_backend_model_id = resolve_accurate_model_id(settings) - use_process_isolation = ( - enforce_timeout - and loaded_model is None - and backend is None - and settings.runtime_flags.profile_pipeline - and runtime_config.process_isolation - ) - process_payload: AccurateProcessPayload | None = None - if use_process_isolation: - process_payload = AccurateProcessPayload( - request=request, - settings=_build_process_settings_snapshot_impl(settings), - expected_backend_id=expected_backend_id, - expected_profile=expected_profile, - expected_backend_model_id=resolved_expected_backend_model_id, - ) - cpu_settings = _build_cpu_settings_snapshot_impl(settings) - cpu_backend_builder: Callable[[], FeatureBackend] = partial( - _build_backend_for_profile, - expected_backend_id=expected_backend_id, - expected_backend_model_id=resolved_expected_backend_model_id, - settings=cpu_settings, - ) - - retry_state, prepared_operation, setup_started_at = _prepare_retry_state_impl( - use_process_isolation=use_process_isolation, - request=request, - settings=settings, - runtime_config=runtime_config, + """Runs accurate-profile inference with bounded retries and timeout budget.""" + return _boundary_support.run_accurate_inference_from_public_boundary( + request, + settings, loaded_model=loaded_model, backend=backend, - logger=logger, - profile=expected_profile, - setup_phase_name=PHASE_EMOTION_SETUP, - log_phase_started=log_phase_started, - log_phase_failed=log_phase_failed, - process_payload=process_payload, - prepare_in_process_operation=partial( - _prepare_in_process_accurate_operation, - expected_backend_id=expected_backend_id, - 
expected_profile=expected_profile, - expected_backend_model_id=resolved_expected_backend_model_id, - ), - ) - with _SINGLE_FLIGHT_REGISTRY.lock( - profile=expected_profile, - backend_model_id=resolved_expected_backend_model_id, - ): - return _execute_accurate_inference_with_retry( - use_process_isolation=use_process_isolation, - retry_state=retry_state, - prepared_operation=prepared_operation, - setup_started_at=setup_started_at, - settings=settings, - runtime_config=runtime_config, - backend=backend, - expected_backend_id=expected_backend_id, - expected_profile=expected_profile, - allow_retries=allow_retries, - enforce_timeout=enforce_timeout, - cpu_backend_builder=cpu_backend_builder, - ) - - -def _execute_accurate_inference_with_retry( - *, - use_process_isolation: bool, - retry_state: AccurateRetryOperationState[AccurateProcessPayload, FeatureBackend], - prepared_operation: _PreparedAccurateOperation | None, - setup_started_at: float | None, - settings: AppConfig, - runtime_config: ProfileRuntimeConfig, - backend: FeatureBackend | None, - expected_backend_id: str, - expected_profile: str, - allow_retries: bool, - enforce_timeout: bool, - cpu_backend_builder: Callable[[], FeatureBackend], -) -> InferenceResult: - """Finalizes setup and executes accurate inference under retry policy.""" - return _boundary_support.execute_accurate_inference_with_retry( - use_process_isolation=use_process_isolation, - retry_state=retry_state, - prepared_operation=prepared_operation, - setup_started_at=setup_started_at, - settings=settings, - backend=backend, + enforce_timeout=enforce_timeout, + allow_retries=allow_retries, expected_backend_id=expected_backend_id, expected_profile=expected_profile, - allow_retries=allow_retries, - enforce_timeout=enforce_timeout, - cpu_backend_builder=cpu_backend_builder, + expected_backend_model_id=expected_backend_model_id, logger=logger, - run_accurate_retryable_operation=_run_accurate_retryable_operation, - 
retry_delay_seconds=_retry_delay_seconds, - process_payload_cpu_fallback=_payload_with_cpu_settings, - timeout_error_type=AccurateInferenceTimeoutError, + model_unavailable_error_type=AccurateModelUnavailableError, runtime_dependency_error_type=AccurateRuntimeDependencyError, + model_load_error_type=AccurateModelLoadError, + timeout_error_type=AccurateInferenceTimeoutError, inference_execution_error_type=AccurateInferenceExecutionError, transient_backend_error_type=AccurateTransientBackendError, ) -def _run_accurate_retryable_operation( - *, - enforce_timeout: bool, - use_process_isolation: bool, - retry_state: AccurateRetryOperationState[AccurateProcessPayload, FeatureBackend], - prepared_operation: _PreparedAccurateOperation | None, - timeout_seconds: float, - expected_profile: str, -) -> InferenceResult: - """Runs one accurate inference attempt using the current retry state.""" - return _boundary_support.run_accurate_retryable_operation( - enforce_timeout=enforce_timeout, - use_process_isolation=use_process_isolation, - retry_state=retry_state, - prepared_operation=prepared_operation, - timeout_seconds=timeout_seconds, - expected_profile=expected_profile, - logger=logger, - run_with_process_timeout=_run_with_process_timeout, - run_accurate_inference_once=_run_accurate_inference_once, - run_with_timeout=_run_with_timeout_impl, - run_inference_operation=_run_inference_operation_impl, - timeout_error_factory=AccurateInferenceTimeoutError, - ) - - -def _payload_with_cpu_settings( - payload: AccurateProcessPayload, -) -> AccurateProcessPayload: - """Returns one process payload updated to use CPU torch selectors.""" - return _boundary_support.payload_with_cpu_settings(payload) - - -def _run_accurate_inference_once( - *, - loaded_model: LoadedModel, - backend: FeatureBackend, - audio: NDArray[np.float32], - sample_rate: int, - runtime_config: ProfileRuntimeConfig, -) -> InferenceResult: - """Runs one accurate inference attempt without retry control.""" - return 
_boundary_support.run_accurate_inference_once( - loaded_model=loaded_model, - backend=backend, - audio=audio, - sample_rate=sample_rate, - runtime_config=runtime_config, - logger=logger, - dependency_error_factory=AccurateRuntimeDependencyError, - transient_error_factory=AccurateTransientBackendError, - ) - - -def _run_with_process_timeout( - payload: AccurateProcessPayload, - *, - timeout_seconds: float, -) -> InferenceResult: - """Runs one process-isolated attempt with timeout applied only to compute.""" - return _run_with_process_timeout_orchestration_impl( - payload=payload, - timeout_seconds=timeout_seconds, - get_context=mp.get_context, - logger=logger, - setup_phase_name=PHASE_EMOTION_SETUP, - inference_phase_name=PHASE_EMOTION_INFERENCE, - log_phase_started=log_phase_started, - log_phase_completed=log_phase_completed, - log_phase_failed=log_phase_failed, - run_process_setup_compute_handshake=_run_process_setup_compute_handshake_impl, - worker_target=_worker_entry, - recv_worker_message=_recv_worker_message, - is_setup_complete_message=_is_setup_complete_message, - terminate_worker_process=_terminate_worker_process, - timeout_error_factory=AccurateInferenceTimeoutError, - execution_error_factory=AccurateInferenceExecutionError, - worker_label="Accurate inference", - process_join_grace_seconds=_TERMINATE_GRACE_SECONDS, - parse_worker_completion_message=_parse_worker_completion_message, - ) - - -def _recv_worker_message( - connection: Connection, - *, - stage: str, -) -> tuple[object, ...]: - """Receives one worker message and validates tuple envelope shape.""" - return _recv_worker_message_orchestration( - connection=connection, - stage=stage, - impl=_recv_worker_message_impl, - worker_label="Accurate inference", - error_factory=AccurateInferenceExecutionError, - ) - - -def _is_setup_complete_message(message: tuple[object, ...]) -> bool: - """Returns whether one worker message marks setup completion.""" - return _is_setup_complete_message_orchestration( - 
message=message, - impl=_is_setup_complete_message_impl, - worker_label="Accurate inference", - error_factory=AccurateInferenceExecutionError, - ) - - -def _parse_worker_completion_message(worker_message: tuple[object, ...]) -> InferenceResult: - """Parses one worker completion message and returns inference result.""" - return _parse_worker_completion_message_orchestration( - worker_message=worker_message, - impl=_parse_worker_completion_message_impl, - worker_label="Accurate inference", - error_factory=AccurateInferenceExecutionError, - raise_worker_error=_raise_worker_error, - result_type=InferenceResult, - ) - - -def _worker_entry( - payload: AccurateProcessPayload, - connection: Connection, -) -> None: - """Executes one inference operation inside child process.""" - _run_worker_entry_orchestration( - payload=payload, - connection=connection, - prepare_process_operation=_prepare_process_operation, - run_process_operation=_run_process_operation, - ) - - -def _prepare_in_process_accurate_operation( - *, - request: InferenceRequest, - settings: AppConfig, - runtime_config: ProfileRuntimeConfig, - loaded_model: LoadedModel | None, - backend: FeatureBackend | None, - expected_backend_id: str, - expected_profile: str, - expected_backend_model_id: str | None, -) -> _PreparedAccurateOperation: - """Prepares one in-process accurate operation using runtime-specific contracts.""" - return cast( - _PreparedAccurateOperation, - _boundary_support.prepare_in_process_accurate_operation( - request=request, - settings=settings, - runtime_config=runtime_config, - loaded_model=loaded_model, - backend=backend, - expected_backend_id=expected_backend_id, - expected_profile=expected_profile, - expected_backend_model_id=expected_backend_model_id, - load_model_fn=load_model, - read_audio_file_fn=read_audio_file, - build_backend_for_profile_fn=_build_backend_for_profile, - logger=logger, - model_unavailable_error_factory=AccurateModelUnavailableError, - 
model_load_error_factory=AccurateModelLoadError, - ), - ) - - -def _prepare_process_operation( - payload: AccurateProcessPayload, -) -> _PreparedAccurateOperation: - """Performs untimed setup for one process-isolated accurate operation.""" - return cast( - _PreparedAccurateOperation, - _boundary_support.prepare_process_operation( - payload, - load_model_fn=load_model, - read_audio_file_fn=read_audio_file, - build_backend_for_profile_fn=_build_backend_for_profile, - logger=logger, - model_unavailable_error_factory=AccurateModelUnavailableError, - model_load_error_factory=AccurateModelLoadError, - runtime_dependency_error_factory=AccurateRuntimeDependencyError, - transient_error_factory=AccurateTransientBackendError, - ), - ) - - -def _run_process_operation(prepared: _PreparedAccurateOperation) -> InferenceResult: - """Runs one accurate compute phase inside isolated worker process.""" - return _boundary_support.run_process_operation( - prepared, - run_accurate_inference_once=lambda **kwargs: _run_accurate_inference_once(**kwargs), - ) - - -def _build_backend_for_profile( - *, - expected_backend_id: str, - expected_backend_model_id: str | None, - settings: AppConfig, -) -> FeatureBackend: - """Builds a feature backend aligned with profile/backend runtime expectations.""" - return _boundary_support.build_backend_for_profile( - expected_backend_id=expected_backend_id, - expected_backend_model_id=expected_backend_model_id, - settings=settings, - whisper_backend_factory=WhisperBackend, - emotion2vec_backend_factory=Emotion2VecBackend, - unsupported_backend_error=AccurateModelUnavailableError, - ) - - -def _terminate_worker_process(process: BaseProcess) -> None: - """Terminates a timed-out worker process with kill fallback.""" - _terminate_worker_process_orchestration( - process=process, - impl=_terminate_worker_process_impl, - terminate_grace_seconds=_TERMINATE_GRACE_SECONDS, - kill_grace_seconds=_KILL_GRACE_SECONDS, - ) - - -def _raise_worker_error(error_type: str, 
message: str) -> None: - """Rehydrates child-process errors into runtime-domain exceptions.""" - _raise_worker_error_orchestration( - error_type=error_type, - message=message, - impl=_raise_worker_error_impl, - known_error_factories=_WORKER_ERROR_FACTORIES, - unknown_error_factory=AccurateInferenceExecutionError, - worker_label="Accurate inference", - ) - - -def _retry_delay_seconds(*, base_delay: float, attempt: int) -> float: - """Returns bounded retry delay with small jitter.""" - return _boundary_support.retry_delay_seconds( - base_delay=base_delay, - attempt=attempt, - ) +__all__ = [ + "AccurateInferenceExecutionError", + "AccurateInferenceTimeoutError", + "AccurateModelLoadError", + "AccurateModelUnavailableError", + "AccurateRuntimeDependencyError", + "AccurateTransientBackendError", + "run_accurate_inference", +] diff --git a/ser/runtime/fast_inference.py b/ser/runtime/fast_inference.py index 5f7b592..fe63f8a 100644 --- a/ser/runtime/fast_inference.py +++ b/ser/runtime/fast_inference.py @@ -1,87 +1,16 @@ -"""Fast-profile inference runner with shared runtime policy semantics.""" +"""Fast-profile public inference boundary.""" from __future__ import annotations -import multiprocessing as mp -from collections.abc import Callable -from dataclasses import dataclass, replace -from multiprocessing.connection import Connection -from multiprocessing.process import BaseProcess -from typing import Literal - -from ser._internal.runtime.process_timeout import ( - run_with_process_timeout as _run_with_process_timeout_orchestration, -) -from ser._internal.runtime.single_flight import SingleFlightRegistry -from ser._internal.runtime.worker_bindings import ( - is_setup_complete_message as _is_setup_complete_message_binding, -) -from ser._internal.runtime.worker_bindings import ( - parse_worker_completion_message as _parse_worker_completion_message_binding, -) -from ser._internal.runtime.worker_bindings import raise_worker_error as _raise_worker_error_binding -from 
ser._internal.runtime.worker_bindings import ( - recv_worker_message as _recv_worker_message_binding, -) -from ser._internal.runtime.worker_bindings import run_worker_entry as _run_worker_entry_binding -from ser._internal.runtime.worker_bindings import ( - terminate_worker_process as _terminate_worker_process_binding, -) -from ser._internal.runtime.worker_lifecycle import ( - is_setup_complete_message as _is_setup_complete_message_impl, -) -from ser._internal.runtime.worker_lifecycle import ( - parse_worker_completion_message as _parse_worker_completion_message_impl, -) -from ser._internal.runtime.worker_lifecycle import raise_worker_error as _raise_worker_error_impl -from ser._internal.runtime.worker_lifecycle import recv_worker_message as _recv_worker_message_impl -from ser._internal.runtime.worker_lifecycle import ( - run_process_setup_compute_handshake as _run_process_setup_compute_handshake_impl, -) -from ser._internal.runtime.worker_lifecycle import run_with_timeout as _run_with_timeout_impl -from ser._internal.runtime.worker_lifecycle import ( - terminate_worker_process as _terminate_worker_process_impl, -) +from ser._internal.runtime import fast_public_boundary as _boundary_support from ser.config import AppConfig -from ser.models.emotion_model import LoadedModel, load_model, predict_emotions_detailed +from ser.models.emotion_model import LoadedModel from ser.runtime.contracts import InferenceRequest -from ser.runtime.phase_contract import PHASE_EMOTION_INFERENCE, PHASE_EMOTION_SETUP -from ser.runtime.phase_timing import ( - log_phase_completed, - log_phase_failed, - log_phase_started, -) -from ser.runtime.policy import run_with_retry_policy from ser.runtime.schema import InferenceResult from ser.utils.logger import get_logger logger = get_logger(__name__) -type WorkerPhaseMessage = tuple[Literal["phase"], Literal["setup_complete"]] -type WorkerSuccessMessage = tuple[Literal["ok"], InferenceResult] -type WorkerErrorMessage = tuple[Literal["err"], str, str] 
-type WorkerMessage = WorkerPhaseMessage | WorkerSuccessMessage | WorkerErrorMessage - -_TERMINATE_GRACE_SECONDS = 0.5 -_KILL_GRACE_SECONDS = 0.5 -_SINGLE_FLIGHT_REGISTRY = SingleFlightRegistry() - - -@dataclass(frozen=True) -class FastProcessPayload: - """Serializable payload for one process-isolated fast inference attempt.""" - - request: InferenceRequest - settings: AppConfig - - -@dataclass(frozen=True) -class _PreparedFastOperation: - """Holds setup-complete data for one fast worker compute phase.""" - - loaded_model: LoadedModel - request: InferenceRequest - class FastModelUnavailableError(FileNotFoundError): """Raised when a compatible fast-profile model artifact is unavailable.""" @@ -103,16 +32,6 @@ class FastTransientBackendError(RuntimeError): """Raised for retryable fast backend failures.""" -_WORKER_ERROR_FACTORIES: dict[str, Callable[[str], Exception]] = { - "ValueError": ValueError, - "FastModelUnavailableError": FastModelUnavailableError, - "FastModelLoadError": FastModelLoadError, - "FastInferenceTimeoutError": FastInferenceTimeoutError, - "FastTransientBackendError": FastTransientBackendError, - "RuntimeError": RuntimeError, -} - - def run_fast_inference( request: InferenceRequest, settings: AppConfig, @@ -122,318 +41,26 @@ def run_fast_inference( allow_retries: bool = True, ) -> InferenceResult: """Runs fast-profile inference with shared runtime timeout/retry policy.""" - runtime_config = settings.fast_runtime - use_process_isolation = ( - enforce_timeout - and loaded_model is None - and settings.runtime_flags.profile_pipeline - and runtime_config.process_isolation - ) - - process_payload: FastProcessPayload | None = None - active_loaded_model: LoadedModel | None = None - setup_started_at: float | None = None - if use_process_isolation: - process_payload = FastProcessPayload( - request=request, - settings=_build_process_settings_snapshot(settings), - ) - else: - setup_started_at = log_phase_started( - logger, - phase_name=PHASE_EMOTION_SETUP, - 
profile="fast", - ) - try: - active_loaded_model = _load_fast_model(settings, loaded_model=loaded_model) - except Exception: - log_phase_failed( - logger, - phase_name=PHASE_EMOTION_SETUP, - started_at=setup_started_at, - profile="fast", - ) - raise - log_phase_completed( - logger, - phase_name=PHASE_EMOTION_SETUP, - started_at=setup_started_at, - profile="fast", - ) - setup_started_at = None - - with _SINGLE_FLIGHT_REGISTRY.lock(profile="fast", backend_model_id=None): - - def operation() -> InferenceResult: - if enforce_timeout: - if use_process_isolation: - if process_payload is None: - raise RuntimeError( - "Fast process payload is missing for isolated execution." - ) - return _run_with_process_timeout( - process_payload, - timeout_seconds=runtime_config.timeout_seconds, - ) - inference_started_at = log_phase_started( - logger, - phase_name=PHASE_EMOTION_INFERENCE, - profile="fast", - ) - try: - result = _run_with_timeout_impl( - operation=lambda: _run_fast_inference_once( - request=request, - loaded_model=active_loaded_model, - settings=settings, - ), - timeout_seconds=runtime_config.timeout_seconds, - timeout_error_factory=FastInferenceTimeoutError, - timeout_label="Fast inference", - ) - except Exception: - log_phase_failed( - logger, - phase_name=PHASE_EMOTION_INFERENCE, - started_at=inference_started_at, - profile="fast", - ) - raise - log_phase_completed( - logger, - phase_name=PHASE_EMOTION_INFERENCE, - started_at=inference_started_at, - profile="fast", - ) - return result - inference_started_at = log_phase_started( - logger, - phase_name=PHASE_EMOTION_INFERENCE, - profile="fast", - ) - try: - result = _run_fast_inference_once( - request=request, - loaded_model=active_loaded_model, - settings=settings, - ) - except Exception: - log_phase_failed( - logger, - phase_name=PHASE_EMOTION_INFERENCE, - started_at=inference_started_at, - profile="fast", - ) - raise - log_phase_completed( - logger, - phase_name=PHASE_EMOTION_INFERENCE, - 
started_at=inference_started_at, - profile="fast", - ) - return result - - try: - return run_with_retry_policy( - operation=operation, - runtime_config=runtime_config, - allow_retries=allow_retries, - profile_label="Fast", - timeout_error_type=FastInferenceTimeoutError, - transient_error_type=FastTransientBackendError, - transient_exhausted_error=lambda _err: FastInferenceExecutionError( - "Fast inference exhausted retry budget after backend failures." - ), - retry_delay_seconds=_retry_delay_seconds, - logger=logger, - ) - except ValueError: - raise - except FastInferenceExecutionError: - raise - except RuntimeError as err: - raise FastInferenceExecutionError( - "Fast inference failed with a non-retryable runtime error." - ) from err - - -def _load_fast_model( - settings: AppConfig, - *, - loaded_model: LoadedModel | None, -) -> LoadedModel: - """Loads and validates fast model metadata when model is not injected.""" - if loaded_model is None: - try: - return load_model( - settings=settings, - expected_backend_id="handcrafted", - expected_profile="fast", - ) - except FileNotFoundError as err: - raise FastModelUnavailableError(str(err)) from err - except ValueError as err: - raise FastModelLoadError( - "Failed to load fast-profile model artifact from configured paths." 
- ) from err - _ensure_fast_compatible_model(loaded_model) - return loaded_model - - -def _run_fast_inference_once( - *, - request: InferenceRequest, - loaded_model: LoadedModel | None, - settings: AppConfig, -) -> InferenceResult: - """Runs one fast inference attempt without retry control.""" - active_loaded_model = _load_fast_model(settings, loaded_model=loaded_model) - return predict_emotions_detailed( - request.file_path, - loaded_model=active_loaded_model, - ) - - -def _run_with_process_timeout( - payload: FastProcessPayload, - *, - timeout_seconds: float, -) -> InferenceResult: - """Runs one process-isolated attempt with timeout applied only to compute.""" - return _run_with_process_timeout_orchestration( - payload=payload, - resolve_profile=lambda _payload: "fast", - timeout_seconds=timeout_seconds, - get_context=mp.get_context, + return _boundary_support.run_fast_inference_from_public_boundary( + request, + settings, + loaded_model=loaded_model, + enforce_timeout=enforce_timeout, + allow_retries=allow_retries, logger=logger, - setup_phase_name=PHASE_EMOTION_SETUP, - inference_phase_name=PHASE_EMOTION_INFERENCE, - log_phase_started=log_phase_started, - log_phase_completed=log_phase_completed, - log_phase_failed=log_phase_failed, - run_process_setup_compute_handshake=_run_process_setup_compute_handshake_impl, - worker_target=_worker_entry, - recv_worker_message=_recv_worker_message, - is_setup_complete_message=_is_setup_complete_message, - terminate_worker_process=_terminate_worker_process, - timeout_error_factory=FastInferenceTimeoutError, - execution_error_factory=FastInferenceExecutionError, - worker_label="Fast inference", - process_join_grace_seconds=_TERMINATE_GRACE_SECONDS, - parse_worker_completion_message=_parse_worker_completion_message, - ) - - -def _recv_worker_message( - connection: Connection, - *, - stage: str, -) -> WorkerMessage: - """Receives one worker message and validates tuple envelope shape.""" - return _recv_worker_message_binding( - 
connection=connection, - stage=stage, - impl=_recv_worker_message_impl, - worker_label="Fast inference", - error_factory=FastInferenceExecutionError, - ) - - -def _is_setup_complete_message(message: WorkerMessage) -> bool: - """Returns whether one worker message marks setup completion.""" - return _is_setup_complete_message_binding( - message=message, - impl=_is_setup_complete_message_impl, - worker_label="Fast inference", - error_factory=FastInferenceExecutionError, - ) - - -def _parse_worker_completion_message(worker_message: WorkerMessage) -> InferenceResult: - """Parses one worker completion message and returns inference result.""" - return _parse_worker_completion_message_binding( - worker_message=worker_message, - impl=_parse_worker_completion_message_impl, - worker_label="Fast inference", - error_factory=FastInferenceExecutionError, - raise_worker_error=_raise_worker_error, - result_type=InferenceResult, - ) - - -def _worker_entry(payload: FastProcessPayload, connection: Connection) -> None: - """Executes one fast inference operation inside child process.""" - _run_worker_entry_binding( - payload=payload, - connection=connection, - prepare_process_operation=_prepare_process_operation, - run_process_operation=_run_process_operation, - ) - - -def _prepare_process_operation(payload: FastProcessPayload) -> _PreparedFastOperation: - """Performs untimed setup for one process-isolated fast operation.""" - loaded_model = _load_fast_model(payload.settings, loaded_model=None) - return _PreparedFastOperation(loaded_model=loaded_model, request=payload.request) - - -def _run_process_operation(prepared: _PreparedFastOperation) -> InferenceResult: - """Runs one fast compute phase inside isolated worker process.""" - return predict_emotions_detailed( - prepared.request.file_path, - loaded_model=prepared.loaded_model, - ) - - -def _build_process_settings_snapshot(settings: AppConfig) -> AppConfig: - """Builds a process-safe settings snapshot for spawn-based workers.""" - 
return replace(settings, emotions=dict(settings.emotions)) - - -def _ensure_fast_compatible_model(loaded_model: LoadedModel) -> None: - """Validates that loaded artifact metadata is compatible with fast runtime.""" - metadata = loaded_model.artifact_metadata - if not isinstance(metadata, dict): - raise FastModelUnavailableError( - "Fast profile requires a v2 model artifact metadata envelope. " - "Train a fast-profile model before inference." - ) - if metadata.get("backend_id") != "handcrafted": - raise FastModelUnavailableError( - "No fast-profile model artifact is available. " - f"Found backend_id={metadata.get('backend_id')!r}; expected 'handcrafted'." - ) - if metadata.get("profile") != "fast": - raise FastModelUnavailableError( - "No fast-profile model artifact is available. " - f"Found profile={metadata.get('profile')!r}; expected 'fast'." - ) - - -def _terminate_worker_process(process: BaseProcess) -> None: - """Terminates a timed-out worker process with kill fallback.""" - _terminate_worker_process_binding( - process=process, - impl=_terminate_worker_process_impl, - terminate_grace_seconds=_TERMINATE_GRACE_SECONDS, - kill_grace_seconds=_KILL_GRACE_SECONDS, - ) - - -def _raise_worker_error(error_type: str, message: str) -> None: - """Rehydrates child-process errors into runtime-domain exceptions.""" - _raise_worker_error_binding( - error_type=error_type, - message=message, - impl=_raise_worker_error_impl, - known_error_factories=_WORKER_ERROR_FACTORIES, - unknown_error_factory=FastInferenceExecutionError, - worker_label="Fast inference", + model_unavailable_error_type=FastModelUnavailableError, + model_load_error_type=FastModelLoadError, + timeout_error_type=FastInferenceTimeoutError, + execution_error_type=FastInferenceExecutionError, + transient_error_type=FastTransientBackendError, ) -def _retry_delay_seconds(base_delay: float, attempt: int) -> float: - """Returns bounded retry delay for one retry attempt.""" - if base_delay <= 0.0: - return 0.0 - return 
base_delay * float(attempt) +__all__ = [ + "FastInferenceExecutionError", + "FastInferenceTimeoutError", + "FastModelLoadError", + "FastModelUnavailableError", + "FastTransientBackendError", + "run_fast_inference", +] diff --git a/ser/runtime/medium_inference.py b/ser/runtime/medium_inference.py index d3166b9..13464fa 100644 --- a/ser/runtime/medium_inference.py +++ b/ser/runtime/medium_inference.py @@ -1,118 +1,17 @@ -"""Medium-profile inference runner with encode-once/pool-many semantics.""" +"""Medium-profile public inference boundary.""" from __future__ import annotations -import multiprocessing as mp -from collections.abc import Callable -from dataclasses import dataclass -from multiprocessing.connection import Connection -from multiprocessing.process import BaseProcess -from typing import Literal, cast - -import numpy as np -from numpy.typing import NDArray - from ser._internal.runtime import medium_public_boundary as _boundary_support -from ser._internal.runtime.single_flight import SingleFlightRegistry -from ser._internal.runtime.worker_lifecycle import ( - is_setup_complete_message as _is_setup_complete_message_impl, -) -from ser._internal.runtime.worker_lifecycle import ( - parse_worker_completion_message as _parse_worker_completion_message_impl, -) -from ser._internal.runtime.worker_lifecycle import raise_worker_error as _raise_worker_error_impl -from ser._internal.runtime.worker_lifecycle import recv_worker_message as _recv_worker_message_impl -from ser._internal.runtime.worker_lifecycle import ( - run_process_setup_compute_handshake as _run_process_setup_compute_handshake_impl, -) -from ser._internal.runtime.worker_lifecycle import run_with_timeout as _run_with_timeout_impl -from ser._internal.runtime.worker_lifecycle import ( - terminate_worker_process as _terminate_worker_process_impl, -) -from ser.config import AppConfig, MediumRuntimeConfig -from ser.models.emotion_model import LoadedModel, load_model -from ser.models.profile_runtime import 
resolve_medium_model_id +from ser.config import AppConfig +from ser.models.emotion_model import LoadedModel from ser.repr import XLSRBackend -from ser.repr.runtime_policy import FeatureRuntimePolicy -from ser.runtime import medium_worker_operation as medium_worker_operation_helpers from ser.runtime.contracts import InferenceRequest -from ser.runtime.medium_execution_context import MediumExecutionContext as _MediumExecutionContext -from ser.runtime.medium_retry_operation import ( - run_medium_inference_with_retry_policy as _run_medium_retry_policy_impl, -) -from ser.runtime.medium_runtime_support import ( - build_cpu_medium_backend_for_settings as _build_cpu_medium_backend_for_settings_impl, -) -from ser.runtime.medium_runtime_support import ( - build_runtime_settings_snapshot as _build_runtime_settings_snapshot_impl, -) -from ser.runtime.medium_worker_lifecycle import ( - is_setup_complete_message as _is_setup_complete_message_orchestration, -) -from ser.runtime.medium_worker_lifecycle import ( - parse_worker_completion_message as _parse_worker_completion_message_orchestration, -) -from ser.runtime.medium_worker_lifecycle import ( - raise_worker_error as _raise_worker_error_orchestration, -) -from ser.runtime.medium_worker_lifecycle import ( - recv_worker_message as _recv_worker_message_orchestration, -) -from ser.runtime.medium_worker_lifecycle import ( - run_with_process_timeout as _run_with_process_timeout_orchestration, -) -from ser.runtime.medium_worker_lifecycle import run_worker_entry as _run_worker_entry_orchestration -from ser.runtime.medium_worker_lifecycle import ( - terminate_worker_process as _terminate_worker_process_orchestration, -) -from ser.runtime.phase_contract import PHASE_EMOTION_INFERENCE, PHASE_EMOTION_SETUP -from ser.runtime.phase_timing import ( - log_phase_completed, - log_phase_failed, - log_phase_started, -) from ser.runtime.schema import InferenceResult -from ser.utils.audio_utils import read_audio_file from ser.utils.logger import 
get_logger logger = get_logger(__name__) -type WorkerPhaseMessage = tuple[Literal["phase"], Literal["setup_complete"]] -type WorkerSuccessMessage = tuple[Literal["ok"], InferenceResult] -type WorkerErrorMessage = tuple[Literal["err"], str, str] -type WorkerMessage = WorkerPhaseMessage | WorkerSuccessMessage | WorkerErrorMessage - -_TERMINATE_GRACE_SECONDS = 0.5 -_KILL_GRACE_SECONDS = 0.5 -_SINGLE_FLIGHT_REGISTRY = SingleFlightRegistry() - - -@dataclass(frozen=True) -class MediumProcessPayload: - """Serializable payload for one process-isolated medium inference attempt.""" - - request: InferenceRequest - settings: AppConfig - expected_backend_model_id: str - - -type _PreparedMediumOperation = medium_worker_operation_helpers.PreparedMediumOperation[ - LoadedModel, - XLSRBackend, -] -type _MediumRetryOperationState = ( - medium_worker_operation_helpers.MediumRetryOperationState[ - MediumProcessPayload, - LoadedModel, - XLSRBackend, - ] -) -type _PreparedMediumExecutionContext = _MediumExecutionContext[ - MediumProcessPayload, - LoadedModel, - XLSRBackend, -] - class MediumModelUnavailableError(FileNotFoundError): """Raised when a compatible medium-profile model artifact is unavailable.""" @@ -138,17 +37,6 @@ class MediumTransientBackendError(RuntimeError): """Raised for retryable medium backend encoding failures.""" -_WORKER_ERROR_FACTORIES: dict[str, Callable[[str], Exception]] = { - "ValueError": ValueError, - "MediumRuntimeDependencyError": MediumRuntimeDependencyError, - "MediumTransientBackendError": MediumTransientBackendError, - "MediumModelUnavailableError": MediumModelUnavailableError, - "MediumModelLoadError": MediumModelLoadError, - "MediumInferenceTimeoutError": MediumInferenceTimeoutError, - "RuntimeError": RuntimeError, -} - - def run_medium_inference( request: InferenceRequest, settings: AppConfig, @@ -158,374 +46,30 @@ def run_medium_inference( enforce_timeout: bool = True, allow_retries: bool = True, ) -> InferenceResult: - """Runs medium-profile 
inference with bounded retries and timeout budget. - - Args: - request: Runtime inference request payload. - settings: Active application settings. - loaded_model: Optional preloaded model artifact for repeated inference calls. - backend: Optional preinitialized XLSR backend for repeated inference calls. - enforce_timeout: Whether to apply timeout wrapper for each inference attempt. - allow_retries: Whether to apply configured retry budget for retryable errors. - - Returns: - Detailed inference result with frame and segment predictions. - - Raises: - MediumModelUnavailableError: If no compatible medium artifact exists. - MediumRuntimeDependencyError: If required medium dependencies are missing. - MediumModelLoadError: If model loading fails for non-compatibility reasons. - MediumInferenceTimeoutError: If all attempts timed out. - MediumInferenceExecutionError: If transient backend failures exhaust retries. - ValueError: If feature dimensions are incompatible with loaded artifact. - """ - execution_context = _prepare_execution_context( - request=request, - settings=settings, + """Runs medium-profile inference with bounded retries and timeout budget.""" + return _boundary_support.run_medium_inference_from_public_boundary( + request, + settings, loaded_model=loaded_model, backend=backend, enforce_timeout=enforce_timeout, - ) - expected_backend_model_id = execution_context.expected_backend_model_id - - with _SINGLE_FLIGHT_REGISTRY.lock( - profile="medium", - backend_model_id=expected_backend_model_id, - ): - return _execute_medium_inference_with_retry( - execution_context=execution_context, - settings=settings, - injected_backend=backend, - enforce_timeout=enforce_timeout, - allow_retries=allow_retries, - expected_backend_model_id=expected_backend_model_id, - ) - - -def _run_with_timeout( - *, - operation: Callable[[], InferenceResult], - timeout_seconds: float, -) -> InferenceResult: - """Runs one in-process medium inference attempt under timeout budget.""" - return 
_run_with_timeout_impl( - operation=operation, - timeout_seconds=timeout_seconds, - timeout_error_factory=MediumInferenceTimeoutError, - timeout_label="Medium inference", - ) - - -def _run_medium_inference_once( - *, - loaded_model: LoadedModel, - backend: XLSRBackend, - audio: NDArray[np.float32], - sample_rate: int, - runtime_config: MediumRuntimeConfig, -) -> InferenceResult: - """Runs one medium inference attempt without retry control.""" - return _boundary_support.run_medium_inference_once( - loaded_model=loaded_model, - backend=backend, - audio=audio, - sample_rate=sample_rate, - runtime_config=runtime_config, - logger=logger, - is_dependency_error=_is_dependency_error, - dependency_error_factory=MediumRuntimeDependencyError, - transient_error_factory=MediumTransientBackendError, - ) - - -def _run_with_process_timeout( - payload: MediumProcessPayload, - *, - timeout_seconds: float, -) -> InferenceResult: - """Runs one process-isolated attempt with timeout applied only to compute.""" - return _run_with_process_timeout_orchestration( - payload=payload, - profile="medium", - timeout_seconds=timeout_seconds, - get_context=mp.get_context, - logger=logger, - setup_phase_name=PHASE_EMOTION_SETUP, - inference_phase_name=PHASE_EMOTION_INFERENCE, - log_phase_started=log_phase_started, - log_phase_completed=log_phase_completed, - log_phase_failed=log_phase_failed, - run_process_setup_compute_handshake=_run_process_setup_compute_handshake_impl, - worker_target=_worker_entry, - recv_worker_message=_recv_worker_message, - is_setup_complete_message=_is_setup_complete_message, - terminate_worker_process=_terminate_worker_process, - timeout_error_factory=MediumInferenceTimeoutError, - execution_error_factory=MediumInferenceExecutionError, - worker_label="Medium inference", - process_join_grace_seconds=_TERMINATE_GRACE_SECONDS, - parse_worker_completion_message=_parse_worker_completion_message, - ) - - -def _recv_worker_message( - connection: Connection, - *, - stage: str, -) 
-> tuple[object, ...]: - """Receives one worker message and validates tuple envelope shape.""" - return _recv_worker_message_orchestration( - connection=connection, - stage=stage, - impl=_recv_worker_message_impl, - worker_label="Medium inference", - error_factory=MediumInferenceExecutionError, - ) - - -def _is_setup_complete_message(message: tuple[object, ...]) -> bool: - """Returns whether one worker message marks setup completion.""" - return _is_setup_complete_message_orchestration( - message=message, - impl=_is_setup_complete_message_impl, - worker_label="Medium inference", - error_factory=MediumInferenceExecutionError, - ) - - -def _parse_worker_completion_message(worker_message: tuple[object, ...]) -> InferenceResult: - """Parses one worker completion message and returns inference result.""" - return _parse_worker_completion_message_orchestration( - worker_message=worker_message, - impl=_parse_worker_completion_message_impl, - worker_label="Medium inference", - error_factory=MediumInferenceExecutionError, - raise_worker_error=_raise_worker_error, - result_type=InferenceResult, - ) - - -def _worker_entry( - payload: MediumProcessPayload, - connection: Connection, -) -> None: - """Executes one medium inference operation inside a child process.""" - _run_worker_entry_orchestration( - payload=payload, - connection=connection, - prepare_process_operation=_prepare_process_operation, - run_process_operation=_run_process_operation, - ) - - -def _prepare_in_process_operation( - *, - request: InferenceRequest, - settings: AppConfig, - loaded_model: LoadedModel | None, - backend: XLSRBackend | None, - expected_backend_model_id: str, - runtime_device: str, - runtime_dtype: str, -) -> _PreparedMediumOperation: - """Performs untimed setup for one in-process medium operation.""" - return cast( - _PreparedMediumOperation, - _boundary_support.prepare_in_process_operation( - request=request, - settings=settings, - loaded_model=loaded_model, - backend=backend, - 
expected_backend_model_id=expected_backend_model_id, - runtime_device=runtime_device, - runtime_dtype=runtime_dtype, - load_model_fn=load_model, - read_audio_file_fn=read_audio_file, - backend_factory=XLSRBackend, - logger=logger, - model_unavailable_error_factory=MediumModelUnavailableError, - model_load_error_factory=MediumModelLoadError, - ), - ) - - -def _prepare_process_operation( - payload: MediumProcessPayload, -) -> _PreparedMediumOperation: - """Performs untimed setup for one process-isolated medium operation.""" - return cast( - _PreparedMediumOperation, - _boundary_support.prepare_process_operation( - payload, - load_model_fn=load_model, - read_audio_file_fn=read_audio_file, - backend_factory=XLSRBackend, - resolve_runtime_policy=lambda settings: _resolve_medium_feature_runtime_policy( - settings=settings - ), - logger=logger, - model_unavailable_error_factory=MediumModelUnavailableError, - model_load_error_factory=MediumModelLoadError, - prepare_medium_backend_runtime=lambda active_backend: _prepare_medium_backend_runtime( - backend=active_backend - ), - ), - ) - - -def _run_process_operation(prepared: _PreparedMediumOperation) -> InferenceResult: - """Runs one medium compute phase inside isolated worker process.""" - return _boundary_support.run_process_operation( - prepared, - run_medium_inference_once=lambda **kwargs: _run_medium_inference_once(**kwargs), - ) - - -def _prepare_execution_context( - *, - request: InferenceRequest, - settings: AppConfig, - loaded_model: LoadedModel | None, - backend: XLSRBackend | None, - enforce_timeout: bool, -) -> _PreparedMediumExecutionContext: - """Resolves pre-lock runtime context for medium inference execution.""" - return cast( - _PreparedMediumExecutionContext, - _boundary_support.prepare_execution_context( - request=request, - settings=settings, - loaded_model=loaded_model, - backend=backend, - enforce_timeout=enforce_timeout, - resolve_medium_model_id=resolve_medium_model_id, - resolve_runtime_policy=lambda 
active_settings: _resolve_medium_feature_runtime_policy( - settings=active_settings - ), - prepare_retry_state=medium_worker_operation_helpers.prepare_retry_state, - prepare_in_process_operation=_prepare_in_process_operation, - build_process_payload=lambda backend_model_id, policy_device, policy_dtype: ( - MediumProcessPayload( - request=request, - settings=_build_runtime_settings_snapshot_impl( - settings, - runtime_device=policy_device, - runtime_dtype=policy_dtype, - ), - expected_backend_model_id=backend_model_id, - ) - ), - logger=logger, - ), - ) - - -def _execute_medium_inference_with_retry( - *, - execution_context: _PreparedMediumExecutionContext, - settings: AppConfig, - injected_backend: XLSRBackend | None, - enforce_timeout: bool, - allow_retries: bool, - expected_backend_model_id: str, -) -> InferenceResult: - """Executes medium inference inside the single-flight lock.""" - return _boundary_support.execute_medium_inference_with_retry( - execution_context=execution_context, - settings=settings, - injected_backend=injected_backend, - enforce_timeout=enforce_timeout, allow_retries=allow_retries, - expected_backend_model_id=expected_backend_model_id, logger=logger, - run_with_process_timeout=lambda payload, timeout_seconds: _run_with_process_timeout( - payload, - timeout_seconds=timeout_seconds, - ), - run_process_operation=_run_process_operation, - run_with_timeout=lambda operation, timeout_seconds: _run_with_timeout( - operation=operation, - timeout_seconds=timeout_seconds, - ), - prepare_medium_backend_runtime=lambda active_backend: _prepare_medium_backend_runtime( - backend=active_backend - ), - cpu_backend_builder=lambda: _build_cpu_medium_backend_for_settings_impl( - settings=settings, - expected_backend_model_id=expected_backend_model_id, - backend_factory=XLSRBackend, - ), - timeout_error_type=MediumInferenceTimeoutError, - transient_error_type=MediumTransientBackendError, + model_unavailable_error_type=MediumModelUnavailableError, 
runtime_dependency_error_type=MediumRuntimeDependencyError, + model_load_error_type=MediumModelLoadError, + timeout_error_type=MediumInferenceTimeoutError, execution_error_type=MediumInferenceExecutionError, - run_retry_policy_impl=_run_medium_retry_policy_impl, - retry_delay_seconds=_retry_delay_seconds, - should_retry_on_cpu_after_transient_failure=_should_retry_on_cpu_after_transient_failure, - summarize_transient_failure=_summarize_transient_failure, - ) - - -def _resolve_medium_feature_runtime_policy( - *, - settings: AppConfig, -) -> FeatureRuntimePolicy: - """Resolves backend-aware feature runtime selectors for medium profile.""" - return _boundary_support.resolve_medium_feature_runtime_policy( - settings=settings, - ) - - -def _terminate_worker_process(process: BaseProcess) -> None: - """Terminates a timed-out worker process with kill fallback.""" - _terminate_worker_process_orchestration( - process=process, - impl=_terminate_worker_process_impl, - terminate_grace_seconds=_TERMINATE_GRACE_SECONDS, - kill_grace_seconds=_KILL_GRACE_SECONDS, - ) - - -def _raise_worker_error(error_type: str, message: str) -> None: - """Rehydrates child-process errors into runtime-domain exceptions.""" - _raise_worker_error_orchestration( - error_type=error_type, - message=message, - impl=_raise_worker_error_impl, - known_error_factories=_WORKER_ERROR_FACTORIES, - unknown_error_factory=MediumInferenceExecutionError, - worker_label="Medium inference", - ) - - -def _retry_delay_seconds(*, base_delay: float, attempt: int) -> float: - """Returns bounded retry delay with small jitter.""" - return _boundary_support.retry_delay_seconds( - base_delay=base_delay, - attempt=attempt, + transient_error_type=MediumTransientBackendError, ) -def _should_retry_on_cpu_after_transient_failure(err: Exception) -> bool: - """Returns whether one transient failure should trigger CPU fallback retry.""" - return _boundary_support.should_retry_on_cpu_after_transient_failure(err) - - -def 
_summarize_transient_failure(err: Exception) -> str: - """Builds one compact summary line for medium transient fallback logs.""" - return _boundary_support.summarize_transient_failure(err) - - -def _is_dependency_error(err: RuntimeError) -> bool: - """Returns whether runtime error indicates missing optional modules.""" - return _boundary_support.is_dependency_error(err) - - -def _prepare_medium_backend_runtime(*, backend: XLSRBackend) -> None: - """Warms backend runtime components so setup work is outside timeout budgets.""" - _boundary_support.prepare_medium_backend_runtime( - backend=backend, - is_dependency_error=_is_dependency_error, - dependency_error_factory=MediumRuntimeDependencyError, - transient_error_factory=MediumTransientBackendError, - ) +__all__ = [ + "MediumInferenceExecutionError", + "MediumInferenceTimeoutError", + "MediumModelLoadError", + "MediumModelUnavailableError", + "MediumRuntimeDependencyError", + "MediumTransientBackendError", + "run_medium_inference", +] diff --git a/ser/transcript/backends/stable_whisper_mps_compat.py b/ser/transcript/backends/stable_whisper_mps_compat.py index 94870a9..0a1b6fd 100644 --- a/ser/transcript/backends/stable_whisper_mps_compat.py +++ b/ser/transcript/backends/stable_whisper_mps_compat.py @@ -5,8 +5,12 @@ import importlib import inspect import logging +import sys import threading +import warnings from contextlib import contextmanager +from dataclasses import dataclass +from functools import lru_cache from typing import Any, cast import numpy as np @@ -21,10 +25,21 @@ type _DtwCallable = Any type _ComputeQksCallable = Any type _DisableSdpaCallable = Any +type _LogMelSpectrogramCallable = Any _logger = logging.getLogger(__name__) +@dataclass(frozen=True, slots=True) +class _MpsLogMelCompatibilityDecision: + """Cached decision for whether MPS log-mel must be computed on CPU.""" + + enable_cpu_offload: bool + reason_code: str + python_version: str + torch_version: str + + def 
enable_stable_whisper_mps_compatibility(model: object) -> object: """Moves one stable-whisper model to MPS with sparse-buffer safeguards.""" moved_model = move_model_to_mps_with_alignment_placeholder(model) @@ -127,14 +142,149 @@ def get_stable_whisper_runtime_device( return default_device_type +def _mps_backend_available() -> bool: + """Returns whether the current torch runtime can execute MPS probes.""" + backends = getattr(torch, "backends", None) + mps_backend = getattr(backends, "mps", None) + is_built = getattr(mps_backend, "is_built", None) + is_available = getattr(mps_backend, "is_available", None) + return bool(callable(is_built) and callable(is_available) and is_built() and is_available()) + + +@lru_cache(maxsize=1) +def _resolve_mps_log_mel_compatibility_decision() -> _MpsLogMelCompatibilityDecision: + """Returns whether this runtime needs CPU log-mel offload on MPS.""" + decision_kwargs = { + "python_version": sys.version.split()[0], + "torch_version": str(getattr(torch, "__version__", "unknown")), + } + if not _mps_backend_available(): + return _MpsLogMelCompatibilityDecision( + enable_cpu_offload=False, + reason_code="mps_unavailable", + **decision_kwargs, + ) + try: + _probe_mps_log_mel_frontend() + except Exception as err: + if _is_mps_log_mel_compatibility_error(err): + return _MpsLogMelCompatibilityDecision( + enable_cpu_offload=True, + reason_code="mps_log_mel_frontend_cpu_offload_required", + **decision_kwargs, + ) + _logger.debug( + "Ignored non-compatibility failure while probing MPS log-mel frontend support.", + exc_info=True, + ) + return _MpsLogMelCompatibilityDecision( + enable_cpu_offload=False, + reason_code="mps_log_mel_probe_unclassified_failure", + **decision_kwargs, + ) + return _MpsLogMelCompatibilityDecision( + enable_cpu_offload=False, + reason_code="mps_log_mel_frontend_supported", + **decision_kwargs, + ) + + +def _probe_mps_log_mel_frontend() -> None: + """Exercises the exact whisper log-mel frontend on MPS for compatibility 
gating.""" + whisper_audio_module = importlib.import_module("whisper.audio") + log_mel_spectrogram = getattr(whisper_audio_module, "log_mel_spectrogram", None) + if not callable(log_mel_spectrogram): + raise RuntimeError("whisper.audio does not expose log_mel_spectrogram().") + probe_audio = torch.zeros(2048, dtype=torch.float32, device="mps") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + log_mel_spectrogram(probe_audio, 80, 0) + + +def _is_mps_log_mel_compatibility_error(err: Exception) -> bool: + """Returns whether one log-mel probe error requires CPU frontend offload.""" + message = " ".join(str(err).split()).lower() + if isinstance(err, NotImplementedError): + if "aten::_fft_r2c" in message and "mps" in message: + return True + if "not currently implemented" in message and "mps" in message and "aten::" in message: + return True + if any(marker in message for marker in ("complexfloat", "complexhalf", "complex64")): + return "mps" in message + return False + + +def _resolve_mps_log_mel_target_device( + audio: object, + device: object | None, +) -> torch.device | None: + """Returns the MPS device that should receive the frontend output, when applicable.""" + if device is not None: + if not isinstance(device, (str, int, torch.device)): + return None + try: + explicit_device = torch.device(device) + except (TypeError, RuntimeError, ValueError): + explicit_device = None + if explicit_device is None or explicit_device.type != "mps": + return None + return explicit_device + if torch.is_tensor(audio) and cast(torch.Tensor, audio).device.type == "mps": + return cast(torch.Tensor, audio).device + return None + + +def _build_mps_safe_log_mel_spectrogram( + original_log_mel_spectrogram: _LogMelSpectrogramCallable, +) -> _LogMelSpectrogramCallable: + """Builds one log-mel adapter that computes the frontend on CPU for MPS callers.""" + + def _log_mel_cpu_safe( + audio: object, + n_mels: int = 80, + padding: int = 0, + device: 
object | None = None, + ) -> torch.Tensor: + target_device = _resolve_mps_log_mel_target_device(audio, device) + if target_device is None: + return cast( + torch.Tensor, + original_log_mel_spectrogram( + audio, + n_mels=n_mels, + padding=padding, + device=device, + ), + ) + cpu_audio = cast(torch.Tensor, audio).float().cpu() if torch.is_tensor(audio) else audio + cpu_log_mel = cast( + torch.Tensor, + original_log_mel_spectrogram( + cpu_audio, + n_mels=n_mels, + padding=padding, + device="cpu", + ), + ) + return cpu_log_mel.to(device=target_device) + + return _log_mel_cpu_safe + + @contextmanager def stable_whisper_mps_timing_compatibility_context() -> Any: - """Patches stable-whisper timing aliases for MPS-compatible word alignment.""" + """Patches stable-whisper MPS gaps for frontend and timing compatibility.""" with _TIMING_PATCH_LOCK: timing_module = importlib.import_module("stable_whisper.timing") compatibility_module = importlib.import_module("stable_whisper.whisper_compatibility") + original_whisper_module = importlib.import_module( + "stable_whisper.whisper_word_level.original_whisper" + ) + whisper_audio_module = importlib.import_module("whisper.audio") timing_module_any = cast(Any, timing_module) compatibility_module_any = cast(Any, compatibility_module) + original_whisper_module_any = cast(Any, original_whisper_module) + whisper_audio_module_any = cast(Any, whisper_audio_module) whisper_timing_module = importlib.import_module("whisper.timing") dtw_cpu = getattr(whisper_timing_module, "dtw_cpu", None) if not callable(dtw_cpu): @@ -144,6 +294,15 @@ def stable_whisper_mps_timing_compatibility_context() -> Any: original_compat_dtw = cast(_DtwCallable, compatibility_module_any.dtw) original_timing_dtw = cast(_DtwCallable, timing_module_any.dtw) original_timing_compute_qks = getattr(timing_module_any, "_compute_qks", None) + original_compat_log_mel = getattr(compatibility_module_any, "log_mel_spectrogram", None) + original_original_whisper_log_mel = getattr( + 
original_whisper_module_any, + "log_mel_spectrogram", + None, + ) + original_whisper_audio_log_mel = getattr( + whisper_audio_module_any, "log_mel_spectrogram", None + ) safe_dtw = _build_cpu_safe_dtw(dtw_cpu) safe_std_mean = _build_mps_safe_std_mean(original_std_mean) @@ -151,6 +310,21 @@ def stable_whisper_mps_timing_compatibility_context() -> Any: compatibility_module_any.dtw = safe_dtw timing_module_any.dtw = safe_dtw torch.std_mean = cast(_StdMeanCallable, safe_std_mean) + frontend_decision = _resolve_mps_log_mel_compatibility_decision() + if frontend_decision.enable_cpu_offload and callable(original_whisper_audio_log_mel): + safe_log_mel = _build_mps_safe_log_mel_spectrogram(original_whisper_audio_log_mel) + if callable(original_compat_log_mel): + compatibility_module_any.log_mel_spectrogram = safe_log_mel + if callable(original_original_whisper_log_mel): + original_whisper_module_any.log_mel_spectrogram = safe_log_mel + whisper_audio_module_any.log_mel_spectrogram = safe_log_mel + _logger.debug( + "Enabled stable-whisper CPU log-mel fallback patch for MPS compatibility " + "context (reason=%s, python=%s, torch=%s).", + frontend_decision.reason_code, + frontend_decision.python_version, + frontend_decision.torch_version, + ) if callable(original_timing_compute_qks): disable_sdpa = getattr(compatibility_module_any, "disable_sdpa", None) if callable(disable_sdpa): @@ -175,6 +349,12 @@ def stable_whisper_mps_timing_compatibility_context() -> Any: torch.std_mean = original_std_mean compatibility_module_any.dtw = original_compat_dtw timing_module_any.dtw = original_timing_dtw + if callable(original_compat_log_mel): + compatibility_module_any.log_mel_spectrogram = original_compat_log_mel + if callable(original_original_whisper_log_mel): + original_whisper_module_any.log_mel_spectrogram = original_original_whisper_log_mel + if callable(original_whisper_audio_log_mel): + whisper_audio_module_any.log_mel_spectrogram = original_whisper_audio_log_mel if 
original_timing_compute_qks is not None: timing_module_any._compute_qks = original_timing_compute_qks diff --git a/ser/transcript/transcript_extractor.py b/ser/transcript/transcript_extractor.py index e15f0ab..2b8658d 100644 --- a/ser/transcript/transcript_extractor.py +++ b/ser/transcript/transcript_extractor.py @@ -3,62 +3,13 @@ from __future__ import annotations import logging -import multiprocessing as mp from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Literal, Never, Protocol, cast +from typing import TYPE_CHECKING, Literal, Protocol, cast from ser._internal.transcription import public_boundary_support as _boundary_support -from ser._internal.transcription.process_isolation import ( - WorkerMessage, -) -from ser._internal.transcription.process_isolation import ( - raise_worker_error as _raise_worker_error_impl, -) -from ser._internal.transcription.process_isolation import ( - recv_worker_message as _recv_worker_message_impl, -) -from ser._internal.transcription.process_isolation import ( - should_use_process_isolated_path as _should_use_process_isolated_path_impl, -) -from ser._internal.transcription.process_isolation import ( - terminate_worker_process as _terminate_worker_process_impl, -) -from ser._internal.transcription.process_isolation import ( - transcription_worker_entry as _transcription_worker_entry_impl, -) -from ser._internal.transcription.process_worker import ( - TranscriptionProcessPayload as _TranscriptionProcessPayload, -) -from ser._internal.transcription.process_worker import ( - TranscriptionWorkerModelsConfig as _TranscriptionWorkerModelsConfig, -) -from ser._internal.transcription.process_worker import ( - TranscriptionWorkerSettings as _TranscriptionWorkerSettings, -) from ser._internal.transcription.process_worker import ( release_transcription_runtime_memory as _release_transcription_runtime_memory_impl, ) -from ser._internal.transcription.public_boundary_process import ( - 
raise_worker_error_from_public_boundary as _raise_worker_error_boundary_impl, -) -from ser._internal.transcription.public_boundary_process import ( - recv_worker_message_from_public_boundary as _recv_worker_message_boundary_impl, -) -from ser._internal.transcription.public_boundary_process import ( - resolve_transcription_adapter_from_public_boundary as _resolve_transcription_adapter_boundary_impl, -) -from ser._internal.transcription.public_boundary_process import ( - spawn_context_for_public_boundary as _spawn_context_boundary_impl, -) -from ser._internal.transcription.public_boundary_process import ( - terminate_worker_process_from_public_boundary as _terminate_worker_process_boundary_impl, -) -from ser._internal.transcription.public_boundary_process import ( - transcription_worker_entry_from_public_boundary as _transcription_worker_entry_boundary_impl, -) -from ser._internal.transcription.runtime_profile import ( - runtime_request_from_profile as _runtime_request_from_profile_impl, -) from ser.config import AppConfig, reload_settings from ser.domain import TranscriptWord from ser.profiles import ( @@ -71,29 +22,13 @@ log_phase_failed, log_phase_started, ) -from ser.transcript.backends import ( - BackendRuntimeRequest, - CompatibilityReport, - resolve_transcription_backend_adapter, -) -from ser.transcript.runtime_policy import ( - DEFAULT_MPS_LOW_MEMORY_THRESHOLD_GB, - resolve_transcription_runtime_policy, -) from ser.utils.logger import get_logger if TYPE_CHECKING: from stable_whisper.result import WhisperResult logger: logging.Logger = get_logger(__name__) -_TERMINATE_GRACE_SECONDS = 5.0 -_KILL_GRACE_SECONDS = 2.0 - -_PROCESS_WORKER_NAMESPACE = ( - _TranscriptionProcessPayload, - _TranscriptionWorkerModelsConfig, - _TranscriptionWorkerSettings, -) +type _CompatibilityIssueKind = Literal["noise", "operational"] class TranscriptionError(RuntimeError): @@ -126,12 +61,6 @@ class TranscriptionProfile: ) -type _WorkerMessage = WorkerMessage -type 
_CompatibilityIssueKind = Literal["noise", "operational"] - -_EMITTED_COMPATIBILITY_ISSUE_KEYS: set[tuple[str, str, str]] = set() - - def _resolve_catalog_transcription_defaults( profile: Literal["fast", "medium", "accurate", "accurate-research"], ) -> ProfileTranscriptionDefaults: @@ -144,65 +73,23 @@ def _resolve_boundary_settings(settings: AppConfig | None) -> AppConfig: return settings if settings is not None else reload_settings() -def _resolve_transcription_profile_for_settings( +def resolve_transcription_profile( profile: TranscriptionProfile | None = None, *, - settings: AppConfig, + settings: AppConfig | None = None, ) -> TranscriptionProfile: - """Resolves one transcription profile against an explicit settings snapshot.""" + """Resolves profile overrides or falls back to configured defaults.""" return cast( TranscriptionProfile, _boundary_support.resolve_transcription_profile_for_settings( profile, - settings=settings, + settings=_resolve_boundary_settings(settings), profile_factory=TranscriptionProfile, error_factory=TranscriptionError, ), ) -def resolve_transcription_profile( - profile: TranscriptionProfile | None = None, - *, - settings: AppConfig | None = None, -) -> TranscriptionProfile: - """Resolves profile overrides or falls back to configured defaults.""" - return _resolve_transcription_profile_for_settings( - profile, - settings=_resolve_boundary_settings(settings), - ) - - -def _runtime_request_from_profile( - active_profile: TranscriptionProfile, - settings: AppConfig, -) -> BackendRuntimeRequest: - """Builds one backend runtime request from transcription profile settings.""" - return _runtime_request_from_profile_impl( - active_profile=active_profile, - settings=settings, - runtime_policy_resolver=resolve_transcription_runtime_policy, - default_mps_low_memory_threshold_gb=DEFAULT_MPS_LOW_MEMORY_THRESHOLD_GB, - ) - - -def _check_adapter_compatibility( - *, - active_profile: TranscriptionProfile, - settings: AppConfig, - runtime_request: 
BackendRuntimeRequest | None = None, -) -> CompatibilityReport: - """Validates backend compatibility and logs non-blocking compatibility issues.""" - return _boundary_support.check_adapter_compatibility( - active_profile=active_profile, - settings=settings, - runtime_request=runtime_request, - emitted_issue_keys=_EMITTED_COMPATIBILITY_ISSUE_KEYS, - logger=logger, - error_factory=TranscriptionError, - ) - - def mark_compatibility_issues_as_emitted( *, backend_id: TranscriptionBackendId, @@ -214,53 +101,6 @@ def mark_compatibility_issues_as_emitted( backend_id=backend_id, issue_kind=issue_kind, issue_codes=issue_codes, - emitted_issue_keys=_EMITTED_COMPATIBILITY_ISSUE_KEYS, - ) - - -def _transcription_setup_required( - *, - active_profile: TranscriptionProfile, - settings: AppConfig, -) -> bool: - """Returns whether a setup/download phase is needed before model load.""" - return _boundary_support.transcription_setup_required( - active_profile=active_profile, - settings=settings, - emitted_issue_keys=_EMITTED_COMPATIBILITY_ISSUE_KEYS, - logger=logger, - error_factory=TranscriptionError, - ) - - -def _prepare_transcription_assets( - *, - active_profile: TranscriptionProfile, - settings: AppConfig, -) -> None: - """Ensures required stable-whisper model assets are present locally.""" - _boundary_support.prepare_transcription_assets( - active_profile=active_profile, - settings=settings, - emitted_issue_keys=_EMITTED_COMPATIBILITY_ISSUE_KEYS, - logger=logger, - error_factory=TranscriptionError, - ) - - -def _load_whisper_model_for_settings( - profile: TranscriptionProfile | None = None, - *, - settings: AppConfig, -) -> object: - """Loads one transcription model for an explicit settings snapshot.""" - return _boundary_support.load_whisper_model_for_settings( - profile=profile, - settings=settings, - profile_factory=TranscriptionProfile, - logger=logger, - emitted_issue_keys=_EMITTED_COMPATIBILITY_ISSUE_KEYS, - error_factory=TranscriptionError, ) @@ -269,112 +109,13 @@ def 
load_whisper_model( *, settings: AppConfig | None = None, ) -> object: - """Loads the configured transcription model for resolved runtime settings. - - Returns: - The loaded Whisper model instance. - """ - return _load_whisper_model_for_settings( + """Loads the configured transcription model for resolved runtime settings.""" + return _boundary_support.load_whisper_model_for_settings( profile=profile, settings=_resolve_boundary_settings(settings), - ) - - -def _should_use_process_isolated_path(profile: TranscriptionProfile) -> bool: - """Returns whether one transcription profile should execute in a worker process.""" - return _should_use_process_isolated_path_impl(profile) - - -def _runtime_request_for_isolated_faster_whisper( - profile: TranscriptionProfile, - settings: AppConfig, -) -> BackendRuntimeRequest: - """Builds one faster-whisper runtime request without importing torch in worker.""" - return _boundary_support._runtime_request_for_isolated_faster_whisper( - profile=profile, - settings=settings, - error_factory=TranscriptionError, - logger=logger, - ) - - -def _run_faster_whisper_process_isolated( - *, - file_path: str, - language: str, - profile: TranscriptionProfile, - settings: AppConfig, -) -> list[TranscriptWord]: - """Runs faster-whisper setup/load/transcribe inside one spawned worker process.""" - return _boundary_support.run_faster_whisper_process_isolated( - file_path=file_path, - language=language, - profile=profile, - settings=settings, - transcript_word_factory=TranscriptWord, - spawn_context_resolver=_spawn_context, - worker_entry=_transcription_worker_entry, - recv_worker_message_fn=_recv_worker_message, - raise_worker_error_fn=_raise_worker_error, - terminate_worker_process_fn=_terminate_worker_process, + profile_factory=TranscriptionProfile, logger=logger, error_factory=TranscriptionError, - terminate_grace_seconds=_TERMINATE_GRACE_SECONDS, - ) - - -def _recv_worker_message(connection: object, *, stage: str) -> _WorkerMessage: - """Receives 
one worker message and validates tuple envelope shape.""" - return _recv_worker_message_boundary_impl( - connection, - recv_worker_message_impl=_recv_worker_message_impl, - stage=stage, - error_factory=TranscriptionError, - ) - - -def _raise_worker_error(message: object) -> Never: - """Raises one transcription-domain error from a worker payload.""" - _raise_worker_error_boundary_impl( - cast(_WorkerMessage, message), - raise_worker_error_impl=_raise_worker_error_impl, - error_factory=TranscriptionError, - ) - - -def _terminate_worker_process(process: object) -> None: - """Terminates a worker process with kill fallback.""" - _terminate_worker_process_boundary_impl( - process, - terminate_worker_process_impl=_terminate_worker_process_impl, - terminate_grace_seconds=_TERMINATE_GRACE_SECONDS, - kill_grace_seconds=_KILL_GRACE_SECONDS, - ) - - -def _spawn_context() -> object: - """Returns the spawn context used for faster-whisper process isolation.""" - return _spawn_context_boundary_impl(get_context=mp.get_context) - - -def _resolve_transcription_adapter(backend_id: TranscriptionBackendId) -> object: - """Resolves one transcription adapter for worker execution.""" - return _resolve_transcription_adapter_boundary_impl( - backend_id, - adapter_resolver=resolve_transcription_backend_adapter, - ) - - -def _transcription_worker_entry( - payload: object, - connection: object, -) -> None: - """Executes faster-whisper transcription inside one isolated worker process.""" - _transcription_worker_entry_boundary_impl( - payload, - connection, - transcription_worker_entry_impl=_transcription_worker_entry_impl, - adapter_resolver=_resolve_transcription_adapter, ) @@ -385,123 +126,24 @@ def extract_transcript( *, settings: AppConfig | None = None, ) -> list[TranscriptWord]: - """Extracts a transcript with per-word timing for an input audio file. - - Args: - file_path: Path to the audio file. - language: Language code used by Whisper during transcription. 
- profile: Optional runtime overrides for model and preprocessing toggles. - - Returns: - A list of transcript word entries with timing metadata. - """ + """Extracts a transcript with per-word timing for an input audio file.""" active_settings = _resolve_boundary_settings(settings) - active_language: str = language or active_settings.default_language - return _extract_transcript( + active_language = language or active_settings.default_language + return _boundary_support.extract_transcript( file_path, active_language, profile, settings=active_settings, - ) - - -def _extract_transcript( - file_path: str, - language: str, - profile: TranscriptionProfile | None = None, - *, - settings: AppConfig, -) -> list[TranscriptWord]: - """Internal transcript workflow with backend-specific execution strategy.""" - return _boundary_support.extract_transcript( - file_path, - language, - profile, - settings=settings, profile_factory=TranscriptionProfile, - transcript_word_factory=TranscriptWord, logger=logger, - emitted_issue_keys=_EMITTED_COMPATIBILITY_ISSUE_KEYS, error_factory=TranscriptionError, - terminate_grace_seconds=_TERMINATE_GRACE_SECONDS, - spawn_context_resolver=_spawn_context, - worker_entry=_transcription_worker_entry, - recv_worker_message_fn=_recv_worker_message, - raise_worker_error_fn=_raise_worker_error, - terminate_worker_process_fn=_terminate_worker_process, - release_memory_fn=_release_transcription_runtime_memory, - phase_started_fn=log_phase_started, - phase_completed_fn=log_phase_completed, - phase_failed_fn=log_phase_failed, - ) - - -def _extract_transcript_in_process( - *, - file_path: str, - language: str, - profile: TranscriptionProfile, - settings: AppConfig, -) -> list[TranscriptWord]: - """Runs one in-process transcript workflow with phase-aware logging.""" - return _boundary_support.extract_transcript_in_process( - file_path=file_path, - language=language, - profile=profile, - settings=settings, - profile_factory=TranscriptionProfile, - 
emitted_issue_keys=_EMITTED_COMPATIBILITY_ISSUE_KEYS, - error_factory=TranscriptionError, - release_memory_fn=_release_transcription_runtime_memory, + release_memory_fn=lambda *, model: _release_transcription_runtime_memory_impl( + model=model, + logger=logger, + ), phase_started_fn=log_phase_started, phase_completed_fn=log_phase_completed, phase_failed_fn=log_phase_failed, - logger=logger, - ) - - -def _release_transcription_runtime_memory(*, model: object | None) -> None: - """Releases best-effort Torch runtime memory after one in-process transcript run.""" - _release_transcription_runtime_memory_impl(model=model, logger=logger) - - -def __transcribe_file( - model: object, - language: str, - file_path: str, - *, - settings: AppConfig, -) -> list[TranscriptWord]: - """Runs a Whisper transcription call and normalizes return types.""" - return _transcribe_file_with_profile( - model, - language, - file_path, - profile=None, - settings=settings, - ) - - -def _transcribe_file_with_profile( - model: object, - language: str, - file_path: str, - profile: TranscriptionProfile | None, - *, - settings: AppConfig, -) -> list[TranscriptWord]: - """Runs a Whisper transcription call using an explicit runtime profile.""" - return _boundary_support.transcribe_with_profile( - model, - language, - file_path, - profile, - settings=settings, - profile_factory=TranscriptionProfile, - emitted_issue_keys=_EMITTED_COMPATIBILITY_ISSUE_KEYS, - error_factory=TranscriptionError, - passthrough_error_cls=TranscriptionError, - logger=logger, ) @@ -514,27 +156,37 @@ def transcribe_with_model( settings: AppConfig | None = None, ) -> list[TranscriptWord]: """Transcribes one file with a pre-loaded model for profiling workloads.""" - return _transcribe_file_with_profile( + return _boundary_support.transcribe_with_profile( model, language, file_path, - profile=profile, + profile, settings=_resolve_boundary_settings(settings), + profile_factory=TranscriptionProfile, + logger=logger, + 
error_factory=TranscriptionError, + passthrough_error_cls=TranscriptionError, ) def format_transcript(result: WhisperResult) -> list[TranscriptWord]: - """Formats a Whisper result object into a word-level timestamp list. - - Args: - result: Whisper transcription result. - - Returns: - A list of transcript word entries with timing metadata. - """ + """Formats a Whisper result object into a word-level timestamp list.""" return _boundary_support.format_transcript( result, transcript_word_factory=TranscriptWord, logger=logger, error_factory=TranscriptionError, ) + + +__all__ = [ + "TranscriptionError", + "TranscriptionProfile", + "WhisperWord", + "extract_transcript", + "format_transcript", + "load_whisper_model", + "mark_compatibility_issues_as_emitted", + "resolve_transcription_profile", + "transcribe_with_model", +] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..2a71537 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test package marker for importable spawned-process support modules.""" diff --git a/tests/conftest.py b/tests/conftest.py index a671d3b..9fd5dc9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,8 @@ import pytest +pytest_plugins = ("tests.fixtures.settings",) + @pytest.fixture(scope="session") def repo_root(pytestconfig: pytest.Config) -> Path: diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 0000000..2f5b783 --- /dev/null +++ b/tests/fixtures/__init__.py @@ -0,0 +1 @@ +"""Reusable pytest fixture modules.""" diff --git a/tests/fixtures/settings.py b/tests/fixtures/settings.py new file mode 100644 index 0000000..2378e0f --- /dev/null +++ b/tests/fixtures/settings.py @@ -0,0 +1,17 @@ +"""Shared fixtures for ambient settings state.""" + +from __future__ import annotations + +from collections.abc import Generator + +import pytest + +import ser.config as config + + +@pytest.fixture +def reset_ambient_settings() -> Generator[None]: + """Resets 
ambient settings before and after one test.""" + config.reload_settings() + yield + config.reload_settings() diff --git a/tests/suites/integration/test_accurate_inference.py b/tests/suites/integration/test_accurate_inference.py new file mode 100644 index 0000000..58332d1 --- /dev/null +++ b/tests/suites/integration/test_accurate_inference.py @@ -0,0 +1,364 @@ +"""Public accurate inference behavior tests.""" + +from __future__ import annotations + +import pickle +import threading +import time + +import numpy as np +import pytest +from sklearn.neural_network import MLPClassifier + +import ser.config as config +from ser._internal.runtime import accurate_public_boundary as accurate_boundary +from ser.models import emotion_model +from ser.repr import EncodedSequence +from ser.runtime.accurate_inference import ( + AccurateInferenceExecutionError, + AccurateInferenceTimeoutError, + AccurateModelUnavailableError, + AccurateRuntimeDependencyError, + AccurateTransientBackendError, + run_accurate_inference, +) +from ser.runtime.contracts import InferenceRequest +from ser.runtime.schema import OUTPUT_SCHEMA_VERSION + +pytestmark = [pytest.mark.integration, pytest.mark.usefixtures("reset_ambient_settings")] + + +class _PredictModel(MLPClassifier): + """Deterministic model stub for accurate runtime tests.""" + + def __init__(self) -> None: + super().__init__(hidden_layer_sizes=(1,), max_iter=1, random_state=0) + self.classes_ = np.asarray(["happy", "sad"], dtype=object) + + def predict(self, X: np.ndarray) -> np.ndarray: # noqa: N803 + return np.asarray(["happy"] * int(X.shape[0]), dtype=object) + + def predict_proba(self, X: np.ndarray) -> np.ndarray: # noqa: N803 + return np.asarray([[0.9, 0.1]] * int(X.shape[0]), dtype=np.float64) + + +class _FakeBackend: + """Deterministic accurate backend stub.""" + + def __init__(self) -> None: + self.encode_calls = 0 + + def encode_sequence(self, _audio: np.ndarray, _sample_rate: int) -> EncodedSequence: + self.encode_calls += 1 + return 
EncodedSequence( + embeddings=np.asarray( + [ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + ], + dtype=np.float32, + ), + frame_start_seconds=np.asarray([0.0, 1.0, 2.0], dtype=np.float64), + frame_end_seconds=np.asarray([1.0, 2.0, 3.0], dtype=np.float64), + backend_id="hf_whisper", + ) + + +def _accurate_metadata( + feature_vector_size: int = 4, + *, + backend_model_id: str | None = emotion_model.ACCURATE_MODEL_ID, + backend_id: str = "hf_whisper", + profile: str = "accurate", +) -> dict[str, object]: + """Builds minimal accurate-profile artifact metadata for runtime tests.""" + metadata: dict[str, object] = { + "artifact_version": emotion_model.MODEL_ARTIFACT_VERSION, + "artifact_schema_version": "v2", + "created_at_utc": "2026-02-21T00:00:00+00:00", + "feature_vector_size": feature_vector_size, + "training_samples": 8, + "labels": ["happy", "sad"], + "backend_id": backend_id, + "profile": profile, + "feature_dim": feature_vector_size, + "frame_size_seconds": 1.0, + "frame_stride_seconds": 1.0, + "pooling_strategy": "mean_std", + } + if backend_model_id is not None: + metadata["backend_model_id"] = backend_model_id + return metadata + + +def _patch_runtime_prerequisites( + monkeypatch: pytest.MonkeyPatch, + *, + backend_model_id: str, + backend: _FakeBackend | None = None, + metadata: dict[str, object] | None = None, +) -> _FakeBackend: + """Patches model/audio/backend prerequisites for accurate runtime tests.""" + resolved_backend = backend or _FakeBackend() + active_metadata = metadata or _accurate_metadata(backend_model_id=backend_model_id) + monkeypatch.setattr( + accurate_boundary, + "load_model", + lambda **_kwargs: emotion_model.LoadedModel( + model=_PredictModel(), + expected_feature_size=4, + artifact_metadata=active_metadata, + ), + ) + monkeypatch.setattr( + accurate_boundary, + "read_audio_file", + lambda _file_path, *, audio_read_config=None: ( + np.linspace(0.0, 1.0, 16, dtype=np.float32), + 4, + ), + ) + monkeypatch.setattr( + accurate_boundary, + 
"WhisperBackend", + lambda **_kwargs: resolved_backend, + ) + return resolved_backend + + +def test_accurate_timeout_retries_up_to_configured_budget( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Timeouts should retry up to `max_retries + 1` attempts and then fail.""" + monkeypatch.setenv("SER_ACCURATE_MAX_TIMEOUT_RETRIES", "2") + monkeypatch.setenv("SER_ACCURATE_RETRY_BACKOFF_SECONDS", "0.5") + settings = config.reload_settings() + _patch_runtime_prerequisites( + monkeypatch, + backend_model_id=settings.models.accurate_model_id, + ) + + calls = {"attempts": 0, "sleeps": 0} + + def fake_timeout_runner(*_args: object, **_kwargs: object) -> object: + calls["attempts"] += 1 + raise AccurateInferenceTimeoutError("timeout") + + monkeypatch.setattr(accurate_boundary, "_run_with_timeout_impl", fake_timeout_runner) + monkeypatch.setattr( + accurate_boundary, + "retry_delay_seconds", + lambda **_kwargs: 0.1, + ) + monkeypatch.setattr( + "ser.runtime.policy.time.sleep", + lambda _delay: calls.__setitem__("sleeps", calls["sleeps"] + 1), + ) + + with pytest.raises(AccurateInferenceTimeoutError, match="timeout"): + run_accurate_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + assert calls["attempts"] == 3 + assert calls["sleeps"] == 2 + + +def test_accurate_process_isolation_uses_spawn_safe_worker_target( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Process-isolated accurate calls should pass a top-level picklable worker target.""" + monkeypatch.setenv("SER_ENABLE_PROFILE_PIPELINE", "true") + monkeypatch.setenv("SER_ACCURATE_PROCESS_ISOLATION", "true") + settings = config.reload_settings() + captured: dict[str, object] = {} + expected = accurate_boundary.InferenceResult( + schema_version=OUTPUT_SCHEMA_VERSION, + segments=[], + frames=[], + ) + + def fail_if_called(**_kwargs: object) -> object: + raise AssertionError("In-process load_model path should not run in process mode.") + + def 
fake_process_runner(*_args: object, **kwargs: object) -> accurate_boundary.InferenceResult: + captured.update(kwargs) + return expected + + monkeypatch.setattr(accurate_boundary, "load_model", fail_if_called) + monkeypatch.setattr(accurate_boundary, "_run_with_process_timeout_impl", fake_process_runner) + + result = run_accurate_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + worker_target = captured["worker_target"] + assert result == expected + assert callable(worker_target) + qualname = getattr(worker_target, "__qualname__", "") + assert isinstance(qualname, str) + assert "<locals>" not in qualname + assert pickle.dumps(worker_target) + + +def test_accurate_transient_backend_failure_respects_retry_upper_bound( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Transient backend failures should stop after bounded retry attempts.""" + monkeypatch.setenv("SER_ACCURATE_MAX_TRANSIENT_RETRIES", "2") + monkeypatch.setenv("SER_ACCURATE_RETRY_BACKOFF_SECONDS", "0") + settings = config.reload_settings() + _patch_runtime_prerequisites( + monkeypatch, + backend_model_id=settings.models.accurate_model_id, + ) + + calls = {"attempts": 0} + + def fake_attempt(**_kwargs: object) -> object: + calls["attempts"] += 1 + raise AccurateTransientBackendError("transient backend failure") + + monkeypatch.setattr( + accurate_boundary, + "_run_with_timeout_impl", + lambda **kwargs: kwargs["operation"](), + ) + monkeypatch.setattr(accurate_boundary, "run_accurate_inference_once", fake_attempt) + + with pytest.raises(AccurateInferenceExecutionError, match="retry budget"): + run_accurate_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + assert calls["attempts"] == 3 + + +def test_accurate_dependency_error_is_not_retried( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Dependency failures should bypass retry policy.""" +
monkeypatch.setenv("SER_ACCURATE_MAX_TRANSIENT_RETRIES", "4") + settings = config.reload_settings() + _patch_runtime_prerequisites( + monkeypatch, + backend_model_id=settings.models.accurate_model_id, + ) + + calls = {"attempts": 0} + + def fake_attempt(**_kwargs: object) -> object: + calls["attempts"] += 1 + raise AccurateRuntimeDependencyError("missing runtime dependency") + + monkeypatch.setattr( + accurate_boundary, + "_run_with_timeout_impl", + lambda **kwargs: kwargs["operation"](), + ) + monkeypatch.setattr(accurate_boundary, "run_accurate_inference_once", fake_attempt) + + with pytest.raises(AccurateRuntimeDependencyError, match="missing runtime dependency"): + run_accurate_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + assert calls["attempts"] == 1 + + +def test_accurate_inference_returns_expected_schema( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Accurate runtime should return deterministic inference schema payload.""" + settings = config.reload_settings() + backend = _patch_runtime_prerequisites( + monkeypatch, + backend_model_id=settings.models.accurate_model_id, + ) + + result = run_accurate_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + assert backend.encode_calls == 1 + assert result.schema_version == OUTPUT_SCHEMA_VERSION + assert len(result.frames) == 3 + assert [frame.emotion for frame in result.frames] == ["happy", "happy", "happy"] + + +def test_accurate_inference_rejects_non_accurate_artifact_metadata( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Mismatched artifact metadata should be rejected before inference.""" + settings = config.reload_settings() + _patch_runtime_prerequisites( + monkeypatch, + backend_model_id=settings.models.accurate_model_id, + metadata=_accurate_metadata( + backend_model_id=settings.models.accurate_model_id, + backend_id="hf_xlsr", + ), + ) + + with 
pytest.raises(AccurateModelUnavailableError, match="backend_id"): + run_accurate_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + +def test_accurate_single_flight_serializes_same_profile_model_calls( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Single-flight registry should serialize same-profile accurate calls.""" + settings = config.reload_settings() + _patch_runtime_prerequisites( + monkeypatch, + backend_model_id=settings.models.accurate_model_id, + ) + monkeypatch.setattr( + accurate_boundary, + "_run_with_timeout_impl", + lambda **kwargs: kwargs["operation"](), + ) + + state = {"active": 0, "max_active": 0} + state_lock = threading.Lock() + + def fake_attempt(**_kwargs: object) -> object: + with state_lock: + state["active"] += 1 + state["max_active"] = max(state["max_active"], state["active"]) + try: + time.sleep(0.05) + return accurate_boundary.InferenceResult( + schema_version=OUTPUT_SCHEMA_VERSION, + segments=[], + frames=[], + ) + finally: + with state_lock: + state["active"] -= 1 + + monkeypatch.setattr(accurate_boundary, "run_accurate_inference_once", fake_attempt) + + def invoke() -> None: + result = run_accurate_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + assert result.schema_version == OUTPUT_SCHEMA_VERSION + + threads = [threading.Thread(target=invoke) for _ in range(2)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert state["max_active"] == 1 diff --git a/tests/test_backend_hooks.py b/tests/suites/integration/test_backend_hooks.py similarity index 97% rename from tests/test_backend_hooks.py rename to tests/suites/integration/test_backend_hooks.py index e685b8a..e3fb9cf 100644 --- a/tests/test_backend_hooks.py +++ b/tests/suites/integration/test_backend_hooks.py @@ -2,8 +2,6 @@ from __future__ import annotations -from collections.abc import Generator - import pytest 
import ser.config as config @@ -11,13 +9,7 @@ from ser.runtime.contracts import InferenceRequest from ser.runtime.schema import OUTPUT_SCHEMA_VERSION, InferenceResult - -@pytest.fixture(autouse=True) -def _reset_settings() -> Generator[None]: - """Keeps global settings stable across tests.""" - config.reload_settings() - yield - config.reload_settings() +pytestmark = [pytest.mark.integration, pytest.mark.usefixtures("reset_ambient_settings")] def test_build_backend_hooks_registers_fast_by_default() -> None: diff --git a/tests/suites/integration/test_fast_inference.py b/tests/suites/integration/test_fast_inference.py new file mode 100644 index 0000000..dd483e6 --- /dev/null +++ b/tests/suites/integration/test_fast_inference.py @@ -0,0 +1,220 @@ +"""Public fast inference behavior tests.""" + +from __future__ import annotations + +import pickle +import threading +import time +from typing import cast + +import pytest + +import ser.config as config +import ser.models.training_support as training_support +from ser._internal.runtime import fast_public_boundary as fast_boundary +from ser.models import emotion_model +from ser.runtime.contracts import InferenceRequest +from ser.runtime.fast_inference import ( + FastInferenceTimeoutError, + FastModelUnavailableError, + run_fast_inference, +) +from ser.runtime.schema import OUTPUT_SCHEMA_VERSION, InferenceResult + +pytestmark = [pytest.mark.integration, pytest.mark.usefixtures("reset_ambient_settings")] + + +def _fast_metadata( + *, + backend_id: str = "handcrafted", + profile: str = "fast", +) -> dict[str, object]: + """Builds minimal fast-profile artifact metadata for runtime tests.""" + return { + "artifact_version": emotion_model.MODEL_ARTIFACT_VERSION, + "artifact_schema_version": "v2", + "created_at_utc": "2026-02-24T00:00:00+00:00", + "feature_vector_size": 193, + "training_samples": 8, + "labels": ["happy", "sad"], + "backend_id": backend_id, + "profile": profile, + "feature_dim": 193, + "frame_size_seconds": 3.0, + 
"frame_stride_seconds": 1.0, + "pooling_strategy": "mean", + } + + +def _patch_fast_prerequisites( + monkeypatch: pytest.MonkeyPatch, + *, + metadata: dict[str, object] | None = None, +) -> None: + """Patches model/detail prerequisites for fast runtime tests.""" + monkeypatch.setattr( + fast_boundary, + "load_model", + lambda **_kwargs: emotion_model.LoadedModel( + model=cast(training_support.EmotionClassifier, object()), + expected_feature_size=193, + artifact_metadata=metadata or _fast_metadata(), + ), + ) + + +def test_fast_inference_returns_expected_schema( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Fast runtime should return deterministic inference schema payload.""" + settings = config.reload_settings() + _patch_fast_prerequisites(monkeypatch) + expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) + monkeypatch.setattr( + fast_boundary, + "predict_emotions_detailed", + lambda _file_path, loaded_model=None: expected, + ) + + result = run_fast_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + assert result == expected + + +def test_fast_timeout_retries_up_to_configured_budget( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Timeouts should retry up to `max_timeout_retries + 1` attempts and then fail.""" + monkeypatch.setenv("SER_FAST_MAX_TIMEOUT_RETRIES", "2") + monkeypatch.setenv("SER_FAST_RETRY_BACKOFF_SECONDS", "0.5") + monkeypatch.setenv("SER_FAST_TIMEOUT_SECONDS", "1.0") + settings = config.reload_settings() + _patch_fast_prerequisites(monkeypatch) + + calls = {"attempts": 0, "sleeps": 0} + + def fake_timeout_runner(*_args: object, **_kwargs: object) -> object: + calls["attempts"] += 1 + raise FastInferenceTimeoutError("timeout") + + monkeypatch.setattr(fast_boundary, "_run_with_timeout_impl", fake_timeout_runner) + monkeypatch.setattr( + "ser.runtime.policy.time.sleep", + lambda _delay: calls.__setitem__("sleeps", calls["sleeps"] + 1), + ) + + 
with pytest.raises(FastInferenceTimeoutError, match="timeout"): + run_fast_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + assert calls["attempts"] == 3 + assert calls["sleeps"] == 2 + + +def test_fast_profile_pipeline_uses_process_timeout_runner( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Profile-pipeline fast calls should route attempts through process timeout path.""" + monkeypatch.setenv("SER_ENABLE_PROFILE_PIPELINE", "true") + monkeypatch.setenv("SER_FAST_PROCESS_ISOLATION", "true") + monkeypatch.setenv("SER_FAST_MAX_TIMEOUT_RETRIES", "1") + monkeypatch.setenv("SER_FAST_RETRY_BACKOFF_SECONDS", "0.1") + settings = config.reload_settings() + + calls = {"process": 0, "sleep": 0} + captured: dict[str, object] = {} + expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) + + def fail_if_called(**_kwargs: object) -> object: + raise AssertionError("In-process load_model path should not run in process mode.") + + def fake_process_runner(*_args: object, **_kwargs: object) -> InferenceResult: + captured.update(_kwargs) + calls["process"] += 1 + if calls["process"] == 1: + raise FastInferenceTimeoutError("timeout") + return expected + + monkeypatch.setattr(fast_boundary, "load_model", fail_if_called) + monkeypatch.setattr(fast_boundary, "_run_with_process_timeout_impl", fake_process_runner) + monkeypatch.setattr( + "ser.runtime.policy.time.sleep", + lambda _delay: calls.__setitem__("sleep", calls["sleep"] + 1), + ) + + result = run_fast_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + assert result == expected + assert calls["process"] == 2 + assert calls["sleep"] == 1 + worker_target = captured["worker_target"] + assert callable(worker_target) + qualname = getattr(worker_target, "__qualname__", "") + assert isinstance(qualname, str) + assert "<locals>" not in qualname + assert pickle.dumps(worker_target) + +
+def test_fast_inference_rejects_mismatched_artifact_metadata( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Mismatched fast artifact metadata should be rejected before inference.""" + settings = config.reload_settings() + _patch_fast_prerequisites( + monkeypatch, + metadata=_fast_metadata(backend_id="hf_whisper"), + ) + + with pytest.raises(FastModelUnavailableError, match="backend_id"): + run_fast_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + +def test_fast_single_flight_serializes_calls( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Single-flight registry should serialize same-profile fast calls.""" + settings = config.reload_settings() + _patch_fast_prerequisites(monkeypatch) + + state = {"active": 0, "max_active": 0} + state_lock = threading.Lock() + + def fake_predict(_file_path: str, *, loaded_model: object | None = None) -> InferenceResult: + del loaded_model + with state_lock: + state["active"] += 1 + state["max_active"] = max(state["max_active"], state["active"]) + try: + time.sleep(0.05) + return InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) + finally: + with state_lock: + state["active"] -= 1 + + monkeypatch.setattr(fast_boundary, "predict_emotions_detailed", fake_predict) + + def invoke() -> None: + result = run_fast_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + assert result.schema_version == OUTPUT_SCHEMA_VERSION + + threads = [threading.Thread(target=invoke) for _ in range(2)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert state["max_active"] == 1 diff --git a/tests/suites/integration/test_medium_inference.py b/tests/suites/integration/test_medium_inference.py new file mode 100644 index 0000000..a1c3d66 --- /dev/null +++ b/tests/suites/integration/test_medium_inference.py @@ -0,0 +1,364 @@ +"""Public medium inference behavior 
tests.""" + +from __future__ import annotations + +import pickle +import threading +import time + +import numpy as np +import pytest +from numpy.typing import NDArray +from sklearn.neural_network import MLPClassifier + +import ser.config as config +from ser._internal.runtime import medium_public_boundary as medium_boundary +from ser.models import emotion_model +from ser.repr import EncodedSequence +from ser.runtime.contracts import InferenceRequest +from ser.runtime.medium_inference import ( + MediumInferenceExecutionError, + MediumInferenceTimeoutError, + MediumModelUnavailableError, + MediumRuntimeDependencyError, + MediumTransientBackendError, + run_medium_inference, +) +from ser.runtime.schema import OUTPUT_SCHEMA_VERSION + +pytestmark = [pytest.mark.integration, pytest.mark.usefixtures("reset_ambient_settings")] + + +class _PredictModel(MLPClassifier): + """Deterministic classifier stub for medium inference contract tests.""" + + def __init__(self) -> None: + super().__init__(hidden_layer_sizes=(1,), max_iter=1, random_state=0) + self.classes_ = np.asarray(["happy", "sad"], dtype=object) + self.last_features: NDArray[np.float64] | None = None + + def predict(self, X: np.ndarray) -> np.ndarray: # noqa: N803 + self.last_features = np.asarray(X, dtype=np.float64) + return np.asarray(["happy"] * int(X.shape[0]), dtype=object) + + def predict_proba(self, X: np.ndarray) -> np.ndarray: # noqa: N803 + self.last_features = np.asarray(X, dtype=np.float64) + return np.asarray([[0.9, 0.1]] * int(X.shape[0]), dtype=np.float64) + + +class _FakeBackend: + """Deterministic backend stub that tracks encode invocation count.""" + + def __init__(self) -> None: + self.encode_calls = 0 + + def encode_sequence( + self, + _audio: NDArray[np.float32], + _sample_rate: int, + ) -> EncodedSequence: + self.encode_calls += 1 + return EncodedSequence( + embeddings=np.asarray( + [ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + ], + dtype=np.float32, + ), + frame_start_seconds=np.asarray([0.0, 
1.0, 2.0], dtype=np.float64), + frame_end_seconds=np.asarray([1.0, 2.0, 3.0], dtype=np.float64), + backend_id="hf_xlsr", + ) + + +def _medium_metadata( + feature_vector_size: int = 4, + *, + backend_model_id: str | None = emotion_model.MEDIUM_MODEL_ID, + backend_id: str = "hf_xlsr", + profile: str = "medium", +) -> dict[str, object]: + """Builds minimal medium-profile artifact metadata for loader tests.""" + metadata: dict[str, object] = { + "artifact_version": emotion_model.MODEL_ARTIFACT_VERSION, + "artifact_schema_version": "v2", + "created_at_utc": "2026-02-19T00:00:00+00:00", + "feature_vector_size": feature_vector_size, + "training_samples": 8, + "labels": ["happy", "sad"], + "backend_id": backend_id, + "profile": profile, + "feature_dim": feature_vector_size, + "frame_size_seconds": 1.0, + "frame_stride_seconds": 1.0, + "pooling_strategy": "mean_std", + } + if backend_model_id is not None: + metadata["backend_model_id"] = backend_model_id + return metadata + + +def _patch_runtime_prerequisites( + monkeypatch: pytest.MonkeyPatch, + *, + backend_model_id: str, + backend: _FakeBackend | None = None, + metadata: dict[str, object] | None = None, +) -> _FakeBackend: + """Patches model/audio/backend prerequisites for medium runtime tests.""" + resolved_backend = backend or _FakeBackend() + active_metadata = metadata or _medium_metadata(backend_model_id=backend_model_id) + monkeypatch.setattr( + medium_boundary, + "load_model", + lambda **_kwargs: emotion_model.LoadedModel( + model=_PredictModel(), + expected_feature_size=4, + artifact_metadata=active_metadata, + ), + ) + monkeypatch.setattr( + medium_boundary, + "read_audio_file", + lambda _file_path, *, audio_read_config=None: ( + np.linspace(0.0, 1.0, 16, dtype=np.float32), + 4, + ), + ) + monkeypatch.setattr(medium_boundary, "XLSRBackend", lambda **_kwargs: resolved_backend) + return resolved_backend + + +def test_run_medium_inference_uses_encode_once_and_returns_schema_result( + monkeypatch: pytest.MonkeyPatch, 
+) -> None: + """Medium inference should encode once and return deterministic segments.""" + settings = config.reload_settings() + backend = _patch_runtime_prerequisites( + monkeypatch, + backend_model_id=settings.models.medium_model_id, + ) + + result = run_medium_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + assert backend.encode_calls == 1 + assert result.schema_version == OUTPUT_SCHEMA_VERSION + assert len(result.frames) == 3 + assert [frame.emotion for frame in result.frames] == ["happy", "happy", "happy"] + + +def test_run_medium_inference_fails_fast_for_non_medium_artifact( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Non-medium artifacts should be rejected before expensive encoding work.""" + settings = config.reload_settings() + _patch_runtime_prerequisites( + monkeypatch, + backend_model_id=settings.models.medium_model_id, + metadata=_medium_metadata( + backend_model_id=settings.models.medium_model_id, + backend_id="hf_whisper", + ), + ) + + with pytest.raises(MediumModelUnavailableError, match="backend_id"): + run_medium_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + +def test_medium_timeout_retries_up_to_configured_budget( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Timeouts should retry up to `max_retries + 1` attempts and then fail.""" + monkeypatch.setenv("SER_MEDIUM_MAX_TIMEOUT_RETRIES", "2") + monkeypatch.setenv("SER_MEDIUM_RETRY_BACKOFF_SECONDS", "0.5") + settings = config.reload_settings() + _patch_runtime_prerequisites( + monkeypatch, + backend_model_id=settings.models.medium_model_id, + ) + + calls = {"attempts": 0, "sleeps": 0} + + def fake_timeout_runner(*_args: object, **_kwargs: object) -> object: + calls["attempts"] += 1 + raise MediumInferenceTimeoutError("timeout") + + monkeypatch.setattr(medium_boundary, "_run_with_timeout_impl", fake_timeout_runner) + monkeypatch.setattr(medium_boundary,
"retry_delay_seconds", lambda **_kwargs: 0.1) + monkeypatch.setattr( + "ser.runtime.policy.time.sleep", + lambda _delay: calls.__setitem__("sleeps", calls["sleeps"] + 1), + ) + + with pytest.raises(MediumInferenceTimeoutError, match="timeout"): + run_medium_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + assert calls["attempts"] == 3 + assert calls["sleeps"] == 2 + + +def test_medium_process_isolation_uses_spawn_safe_worker_target( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Process-isolated medium calls should pass a top-level picklable worker target.""" + monkeypatch.setenv("SER_ENABLE_PROFILE_PIPELINE", "true") + monkeypatch.setenv("SER_MEDIUM_PROCESS_ISOLATION", "true") + settings = config.reload_settings() + captured: dict[str, object] = {} + expected = medium_boundary.InferenceResult( + schema_version=OUTPUT_SCHEMA_VERSION, + segments=[], + frames=[], + ) + + def fail_if_called(**_kwargs: object) -> object: + raise AssertionError("In-process load_model path should not run in process mode.") + + def fake_process_runner(*_args: object, **kwargs: object) -> medium_boundary.InferenceResult: + captured.update(kwargs) + return expected + + monkeypatch.setattr(medium_boundary, "load_model", fail_if_called) + monkeypatch.setattr(medium_boundary, "_run_with_process_timeout_impl", fake_process_runner) + + result = run_medium_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + worker_target = captured["worker_target"] + assert result == expected + assert callable(worker_target) + qualname = getattr(worker_target, "__qualname__", "") + assert isinstance(qualname, str) + assert "<locals>" not in qualname + assert pickle.dumps(worker_target) + + +def test_medium_dependency_error_is_not_retried( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Dependency failures should bypass retry policy.""" +
monkeypatch.setenv("SER_MEDIUM_MAX_TRANSIENT_RETRIES", "4") + settings = config.reload_settings() + _patch_runtime_prerequisites( + monkeypatch, + backend_model_id=settings.models.medium_model_id, + ) + + calls = {"attempts": 0} + + def fake_attempt(**_kwargs: object) -> object: + calls["attempts"] += 1 + raise MediumRuntimeDependencyError("missing runtime dependency") + + monkeypatch.setattr( + medium_boundary, + "_run_with_timeout_impl", + lambda **kwargs: kwargs["operation"](), + ) + monkeypatch.setattr(medium_boundary, "run_medium_inference_once", fake_attempt) + + with pytest.raises(MediumRuntimeDependencyError, match="missing runtime dependency"): + run_medium_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + assert calls["attempts"] == 1 + + +def test_medium_transient_failure_respects_retry_upper_bound( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Transient backend failures should stop after bounded retry attempts.""" + monkeypatch.setenv("SER_MEDIUM_MAX_TRANSIENT_RETRIES", "2") + monkeypatch.setenv("SER_MEDIUM_RETRY_BACKOFF_SECONDS", "0") + settings = config.reload_settings() + _patch_runtime_prerequisites( + monkeypatch, + backend_model_id=settings.models.medium_model_id, + ) + + calls = {"attempts": 0} + + def fake_attempt(**_kwargs: object) -> object: + calls["attempts"] += 1 + raise MediumTransientBackendError("transient backend failure") + + monkeypatch.setattr( + medium_boundary, + "_run_with_timeout_impl", + lambda **kwargs: kwargs["operation"](), + ) + monkeypatch.setattr(medium_boundary, "run_medium_inference_once", fake_attempt) + + with pytest.raises(MediumInferenceExecutionError, match="retry budget"): + run_medium_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + + assert calls["attempts"] == 3 + + +def test_medium_single_flight_serializes_same_profile_model_calls( + monkeypatch: pytest.MonkeyPatch, +) -> None: + 
"""Single-flight registry should serialize same-profile medium calls.""" + settings = config.reload_settings() + _patch_runtime_prerequisites( + monkeypatch, + backend_model_id=settings.models.medium_model_id, + ) + monkeypatch.setattr( + medium_boundary, + "_run_with_timeout_impl", + lambda **kwargs: kwargs["operation"](), + ) + + state = {"active": 0, "max_active": 0} + state_lock = threading.Lock() + + def fake_attempt(**_kwargs: object) -> object: + with state_lock: + state["active"] += 1 + state["max_active"] = max(state["max_active"], state["active"]) + try: + time.sleep(0.05) + return medium_boundary.InferenceResult( + schema_version=OUTPUT_SCHEMA_VERSION, + segments=[], + frames=[], + ) + finally: + with state_lock: + state["active"] -= 1 + + monkeypatch.setattr(medium_boundary, "run_medium_inference_once", fake_attempt) + + def invoke() -> None: + result = run_medium_inference( + InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), + settings, + ) + assert result.schema_version == OUTPUT_SCHEMA_VERSION + + threads = [threading.Thread(target=invoke) for _ in range(2)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert state["max_active"] == 1 diff --git a/tests/suites/integration/test_process_isolation.py b/tests/suites/integration/test_process_isolation.py new file mode 100644 index 0000000..d90dcaf --- /dev/null +++ b/tests/suites/integration/test_process_isolation.py @@ -0,0 +1,314 @@ +"""Integration coverage for spawned runtime and transcription process paths.""" + +from __future__ import annotations + +import logging +import multiprocessing as mp +import pickle +from collections.abc import Callable +from multiprocessing.connection import Connection +from multiprocessing.process import BaseProcess +from pathlib import Path +from typing import cast + +import pytest +from tests.utils.helpers import process_spawn_support + +from ser._internal.runtime import process_timeout as 
runtime_process_timeout +from ser._internal.runtime import worker_lifecycle as runtime_worker_lifecycle +from ser._internal.transcription import process_isolation as transcription_process_isolation +from ser._internal.transcription import public_boundary_support as transcription_boundary_support +from ser.config import AppConfig, reload_settings +from ser.domain import TranscriptWord +from ser.transcript.backends import BackendRuntimeRequest +from ser.transcript.transcript_extractor import TranscriptionProfile + +pytestmark = [ + pytest.mark.integration, + pytest.mark.process_isolation, + pytest.mark.usefixtures("reset_ambient_settings"), +] + + +class _RuntimeTimeoutError(TimeoutError): + """Timeout-domain error used by spawned runtime process integration tests.""" + + +def _log_phase_started( + _logger: object, + *, + phase_name: str, + profile: str, +) -> float: + del phase_name, profile + return 0.0 + + +def _log_phase_completed( + _logger: object, + *, + phase_name: str, + started_at: float | None = None, + profile: str | None = None, +) -> None: + del phase_name, started_at, profile + + +def _log_phase_failed( + _logger: object, + *, + phase_name: str, + started_at: float | None = None, + profile: str | None = None, +) -> None: + del phase_name, started_at, profile + + +def _run_runtime_process_timeout( + payload: process_spawn_support.RuntimeWorkerPayload, + *, + timeout_seconds: float, +) -> str: + logger = logging.getLogger("ser.tests.runtime_process_spawn") + + def _recv_worker_message(connection: Connection, *, stage: str) -> tuple[object, ...]: + return runtime_worker_lifecycle.recv_worker_message( + connection=connection, + stage=stage, + worker_label="Synthetic runtime", + error_factory=RuntimeError, + ) + + def _is_setup_complete_message(message: tuple[object, ...]) -> bool: + return runtime_worker_lifecycle.is_setup_complete_message( + message=message, + worker_label="Synthetic runtime", + error_factory=RuntimeError, + ) + + def 
_terminate_worker_process(process: object) -> None: + runtime_worker_lifecycle.terminate_worker_process( + process=cast(BaseProcess, process), + terminate_grace_seconds=0.1, + kill_grace_seconds=0.1, + ) + + def _raise_worker_error(error_type: str, message: str) -> None: + runtime_worker_lifecycle.raise_worker_error( + error_type=error_type, + message=message, + known_error_factories={"ValueError": ValueError}, + unknown_error_factory=RuntimeError, + worker_label="Synthetic runtime", + ) + + def _parse_worker_completion_message(message: tuple[object, ...]) -> str: + return runtime_worker_lifecycle.parse_worker_completion_message( + worker_message=message, + worker_label="Synthetic runtime", + error_factory=RuntimeError, + raise_worker_error=_raise_worker_error, + result_type=str, + ) + + return runtime_process_timeout.run_with_process_timeout( + payload=payload, + resolve_profile=lambda _payload: "synthetic", + timeout_seconds=timeout_seconds, + get_context=mp.get_context, + logger=logger, + setup_phase_name="setup", + inference_phase_name="compute", + log_phase_started=_log_phase_started, + log_phase_completed=_log_phase_completed, + log_phase_failed=_log_phase_failed, + run_process_setup_compute_handshake=runtime_worker_lifecycle.run_process_setup_compute_handshake, + worker_target=process_spawn_support.runtime_worker_entry, + recv_worker_message=_recv_worker_message, + is_setup_complete_message=_is_setup_complete_message, + terminate_worker_process=_terminate_worker_process, + timeout_error_factory=_RuntimeTimeoutError, + execution_error_factory=RuntimeError, + worker_label="Synthetic runtime", + process_join_grace_seconds=0.1, + parse_worker_completion_message=_parse_worker_completion_message, + ) + + +def _transcription_runtime_request() -> BackendRuntimeRequest: + return BackendRuntimeRequest( + model_name="tiny", + use_demucs=False, + use_vad=True, + device_spec="cpu", + device_type="cpu", + precision_candidates=("float32",), + memory_tier="not_applicable", 
+ ) + + +def test_runtime_process_timeout_executes_real_spawn_worker() -> None: + """Runtime helper should complete the real spawned-worker happy path.""" + result = _run_runtime_process_timeout( + process_spawn_support.RuntimeWorkerPayload(result="done"), + timeout_seconds=1.0, + ) + + assert result == "done" + + +def test_runtime_process_timeout_maps_worker_errors_from_real_spawn_worker() -> None: + """Runtime helper should rehydrate worker errors from a real spawned child.""" + with pytest.raises(ValueError, match="bad input"): + _run_runtime_process_timeout( + process_spawn_support.RuntimeWorkerPayload( + error_type="ValueError", + error_message="bad input", + ), + timeout_seconds=1.0, + ) + + +def test_runtime_process_timeout_enforces_timeout_for_real_spawn_worker() -> None: + """Runtime helper should terminate a real spawned worker on timeout.""" + with pytest.raises(_RuntimeTimeoutError, match="Synthetic runtime exceeded timeout"): + _run_runtime_process_timeout( + process_spawn_support.RuntimeWorkerPayload(compute_delay_seconds=0.2), + timeout_seconds=0.01, + ) + + +def test_transcription_process_isolation_executes_real_spawn_worker( + tmp_path: Path, +) -> None: + """Transcription helper should complete one real spawned-worker success path.""" + settings = process_spawn_support.FakeSettings( + models=process_spawn_support.FakeModelsConfig(whisper_download_root=tmp_path), + ) + profile = process_spawn_support.FakeTranscriptionProfile() + + result = transcription_process_isolation.run_faster_whisper_process_isolated( + file_path="sample.wav", + language="en", + profile=profile, + settings_resolver=lambda: cast(AppConfig, settings), + runtime_request_resolver=lambda _profile, _settings: _transcription_runtime_request(), + payload_factory=cast( + Callable[..., object], + process_spawn_support.build_transcription_payload, + ), + get_spawn_context=lambda: mp.get_context("spawn"), + worker_entry=cast( + Callable[[object, object], None], + 
process_spawn_support.transcription_success_worker, + ), + recv_worker_message_fn=lambda connection, *, stage: ( + transcription_process_isolation.recv_worker_message( + connection, + stage=stage, + error_factory=RuntimeError, + ) + ), + raise_worker_error_fn=lambda message: transcription_process_isolation.raise_worker_error( + message, + error_factory=RuntimeError, + ), + terminate_worker_process_fn=lambda process: ( + transcription_process_isolation.terminate_worker_process( + process, + terminate_grace_seconds=0.1, + kill_grace_seconds=0.1, + ) + ), + transcript_word_factory=TranscriptWord, + logger=logging.getLogger("ser.tests.transcription_process_spawn"), + error_factory=RuntimeError, + terminate_grace_seconds=0.1, + ) + + assert result == [TranscriptWord("hello", 0.0, 0.5), TranscriptWord("world", 0.5, 1.0)] + + +def test_transcription_process_isolation_maps_real_spawn_worker_errors( + tmp_path: Path, +) -> None: + """Transcription helper should surface errors from a real spawned worker.""" + settings = process_spawn_support.FakeSettings( + models=process_spawn_support.FakeModelsConfig(whisper_download_root=tmp_path), + ) + profile = process_spawn_support.FakeTranscriptionProfile() + + with pytest.raises( + RuntimeError, + match="Transcription worker model_load failed with RuntimeError: boom", + ): + transcription_process_isolation.run_faster_whisper_process_isolated( + file_path="sample.wav", + language="en", + profile=profile, + settings_resolver=lambda: cast(AppConfig, settings), + runtime_request_resolver=lambda _profile, _settings: _transcription_runtime_request(), + payload_factory=cast( + Callable[..., object], + process_spawn_support.build_transcription_payload, + ), + get_spawn_context=lambda: mp.get_context("spawn"), + worker_entry=cast( + Callable[[object, object], None], + process_spawn_support.transcription_error_worker, + ), + recv_worker_message_fn=lambda connection, *, stage: ( + transcription_process_isolation.recv_worker_message( + 
connection, + stage=stage, + error_factory=RuntimeError, + ) + ), + raise_worker_error_fn=lambda message: transcription_process_isolation.raise_worker_error( + message, + error_factory=RuntimeError, + ), + terminate_worker_process_fn=lambda process: ( + transcription_process_isolation.terminate_worker_process( + process, + terminate_grace_seconds=0.1, + kill_grace_seconds=0.1, + ) + ), + transcript_word_factory=TranscriptWord, + logger=logging.getLogger("ser.tests.transcription_process_spawn"), + error_factory=RuntimeError, + terminate_grace_seconds=0.1, + ) + + +def test_transcription_public_boundary_worker_target_is_spawn_safe( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Public faster-whisper boundary should pass a top-level picklable worker target.""" + captured: dict[str, object] = {} + + monkeypatch.setattr( + transcription_boundary_support, + "_run_faster_whisper_process_isolated_boundary_impl", + lambda **kwargs: (captured.update(kwargs), [])[1], + ) + + transcription_boundary_support.run_faster_whisper_process_isolated( + file_path="sample.wav", + language="en", + profile=TranscriptionProfile( + backend_id="faster_whisper", + model_name="small", + use_demucs=False, + use_vad=True, + ), + settings=reload_settings(), + logger=logging.getLogger("ser.tests.transcription_boundary"), + error_factory=RuntimeError, + ) + + worker_entry = cast(Callable[..., None], captured["worker_entry"]) + assert callable(worker_entry) + assert "<locals>" not in worker_entry.__qualname__ + assert pickle.dumps(worker_entry) diff --git a/tests/test_runtime_pipeline.py b/tests/suites/integration/test_runtime_pipeline.py similarity index 99% rename from tests/test_runtime_pipeline.py rename to tests/suites/integration/test_runtime_pipeline.py index 88a4f48..2a031e2 100644 --- a/tests/test_runtime_pipeline.py +++ b/tests/suites/integration/test_runtime_pipeline.py @@ -1,7 +1,7 @@ """Tests for runtime pipeline orchestration seam.""" import sys -from collections.abc import Callable,
Generator +from collections.abc import Callable from dataclasses import replace from types import ModuleType from typing import Any, cast @@ -34,13 +34,7 @@ type PrintTimelineCallable = Callable[[list[TimelineEntry]], None] type SaveTimelineCallable = Callable[[list[TimelineEntry], str], str] - -@pytest.fixture(autouse=True) -def _reset_settings() -> Generator[None, None, None]: - """Keeps global settings stable across tests.""" - config.reload_settings() - yield - config.reload_settings() +pytestmark = [pytest.mark.integration, pytest.mark.usefixtures("reset_ambient_settings")] def _build_test_pipeline( diff --git a/tests/test_runtime_registry.py b/tests/suites/integration/test_runtime_registry.py similarity index 97% rename from tests/test_runtime_registry.py rename to tests/suites/integration/test_runtime_registry.py index b1b1386..75e89d9 100644 --- a/tests/test_runtime_registry.py +++ b/tests/suites/integration/test_runtime_registry.py @@ -1,7 +1,5 @@ """Tests for runtime profile capability registry.""" -from collections.abc import Generator - import pytest import ser.config as config @@ -12,13 +10,7 @@ resolve_runtime_capability, ) - -@pytest.fixture(autouse=True) -def _reset_settings() -> Generator[None, None, None]: - """Keeps global settings stable across tests.""" - config.reload_settings() - yield - config.reload_settings() +pytestmark = [pytest.mark.integration, pytest.mark.usefixtures("reset_ambient_settings")] def test_resolve_runtime_capability_defaults_to_fast() -> None: diff --git a/tests/suites/integration/test_transcript_extractor.py b/tests/suites/integration/test_transcript_extractor.py new file mode 100644 index 0000000..c24c9d7 --- /dev/null +++ b/tests/suites/integration/test_transcript_extractor.py @@ -0,0 +1,271 @@ +"""Public transcript extractor behavior tests.""" + +from __future__ import annotations + +import logging +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +from ser.domain import 
TranscriptWord +from ser.transcript import transcript_extractor as te + +pytestmark = pytest.mark.integration + + +class FakeResult: + """Whisper-like result object with configurable word payload.""" + + def __init__(self, words: list[SimpleNamespace]) -> None: + self._words = words + + def all_words(self) -> list[SimpleNamespace]: + return self._words + + +def test_resolve_transcription_profile_delegates_to_boundary_owner( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Profile resolution should delegate to the internal owner with explicit settings.""" + settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) + captured: dict[str, object] = {} + expected = te.TranscriptionProfile(backend_id="faster_whisper", model_name="small") + + def _fake_boundary_impl( + profile: te.TranscriptionProfile | None, + *, + settings: te.AppConfig, + profile_factory: object, + error_factory: object, + ) -> te.TranscriptionProfile: + captured["profile"] = profile + captured["settings"] = settings + captured["profile_factory"] = profile_factory + captured["error_factory"] = error_factory + return expected + + monkeypatch.setattr( + te._boundary_support, + "resolve_transcription_profile_for_settings", + _fake_boundary_impl, + ) + + resolved = te.resolve_transcription_profile(None, settings=settings) + + assert resolved == expected + assert captured["settings"] is settings + assert captured["profile_factory"] is te.TranscriptionProfile + assert captured["error_factory"] is te.TranscriptionError + + +def test_extract_transcript_uses_default_language_from_settings( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Public extractor should pass the resolved default language to the owner.""" + settings = cast(te.AppConfig, SimpleNamespace(default_language="pt")) + captured: dict[str, object] = {} + expected = [TranscriptWord("ola", 0.0, 0.5)] + + def _fake_boundary_impl( + file_path: str, + language: str, + profile: te.TranscriptionProfile | None, + *, + settings: 
te.AppConfig, + profile_factory: object, + logger: logging.Logger, + error_factory: object, + release_memory_fn: object, + phase_started_fn: object, + phase_completed_fn: object, + phase_failed_fn: object, + ) -> list[TranscriptWord]: + captured["file_path"] = file_path + captured["language"] = language + captured["profile"] = profile + captured["settings"] = settings + captured["profile_factory"] = profile_factory + captured["logger"] = logger + captured["error_factory"] = error_factory + captured["release_memory_fn"] = release_memory_fn + captured["phase_started_fn"] = phase_started_fn + captured["phase_completed_fn"] = phase_completed_fn + captured["phase_failed_fn"] = phase_failed_fn + return expected + + monkeypatch.setattr(te._boundary_support, "extract_transcript", _fake_boundary_impl) + + resolved = te.extract_transcript("sample.wav", profile=None, settings=settings) + + assert resolved == expected + assert captured["file_path"] == "sample.wav" + assert captured["language"] == "pt" + assert captured["settings"] is settings + assert captured["profile_factory"] is te.TranscriptionProfile + assert captured["logger"] is te.logger + assert captured["error_factory"] is te.TranscriptionError + + +def test_extract_transcript_propagates_transcription_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Operational failures should propagate as TranscriptionError.""" + settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) + monkeypatch.setattr( + te._boundary_support, + "extract_transcript", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + te.TranscriptionError("Failed to transcribe audio.") + ), + ) + + with pytest.raises(te.TranscriptionError, match="Failed to transcribe audio"): + te.extract_transcript("sample.wav", settings=settings) + + +def test_load_whisper_model_uses_explicit_settings_without_ambient_lookup( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Explicit settings should be passed directly to the owner.""" + settings = 
cast(te.AppConfig, SimpleNamespace(default_language="en")) + expected_model = object() + captured: dict[str, object] = {} + + def _fake_load_model( + profile: te.TranscriptionProfile | None = None, + *, + settings: te.AppConfig, + profile_factory: object, + logger: logging.Logger, + error_factory: object, + ) -> object: + captured["profile"] = profile + captured["settings"] = settings + captured["profile_factory"] = profile_factory + captured["logger"] = logger + captured["error_factory"] = error_factory + return expected_model + + monkeypatch.setattr(te._boundary_support, "load_whisper_model_for_settings", _fake_load_model) + + resolved = te.load_whisper_model(settings=settings) + + assert resolved is expected_model + assert captured["settings"] is settings + assert captured["profile_factory"] is te.TranscriptionProfile + assert captured["logger"] is te.logger + assert captured["error_factory"] is te.TranscriptionError + + +def test_transcribe_with_model_delegates_to_boundary_owner( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Pre-loaded transcription calls should delegate to the boundary owner.""" + settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) + profile = te.TranscriptionProfile(backend_id="stable_whisper", model_name="large-v2") + expected = [TranscriptWord("hello", 0.1, 0.3)] + captured: dict[str, object] = {} + + def _fake_boundary_impl( + model: object, + language: str, + file_path: str, + active_profile: te.TranscriptionProfile | None, + *, + settings: te.AppConfig, + profile_factory: object, + logger: logging.Logger, + error_factory: object, + passthrough_error_cls: object, + ) -> list[TranscriptWord]: + captured["model"] = model + captured["language"] = language + captured["file_path"] = file_path + captured["profile"] = active_profile + captured["settings"] = settings + captured["profile_factory"] = profile_factory + captured["logger"] = logger + captured["error_factory"] = error_factory + captured["passthrough_error_cls"] 
= passthrough_error_cls + return expected + + monkeypatch.setattr(te._boundary_support, "transcribe_with_profile", _fake_boundary_impl) + + model = object() + resolved = te.transcribe_with_model( + model, + "sample.wav", + "en", + profile, + settings=settings, + ) + + assert resolved == expected + assert captured["model"] is model + assert captured["profile"] == profile + assert captured["settings"] is settings + assert captured["profile_factory"] is te.TranscriptionProfile + assert captured["logger"] is te.logger + assert captured["error_factory"] is te.TranscriptionError + assert captured["passthrough_error_cls"] is te.TranscriptionError + + +def test_format_transcript_formats_word_timestamps() -> None: + """Word-level timestamps should be preserved in formatted output.""" + result = FakeResult([SimpleNamespace(word="hello", start=0.1, end=0.3)]) + + assert te.format_transcript(cast(Any, result)) == [TranscriptWord("hello", 0.1, 0.3)] + + +def test_format_transcript_raises_for_invalid_result() -> None: + """Invalid Whisper results should raise TranscriptionError.""" + with pytest.raises(te.TranscriptionError, match="Invalid Whisper result object"): + te.format_transcript(cast(Any, object())) + + +def test_mark_compatibility_issues_as_emitted_suppresses_duplicate_operational_logs( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + """Pre-emitted compatibility issues should not be logged again.""" + + class _Adapter: + def check_compatibility(self, **_kwargs: object) -> Any: + issue = SimpleNamespace( + code="pytest-operational", + message="sample remediation guidance", + impact="informational", + ) + return SimpleNamespace( + noise_issues=[], + operational_issues=[issue], + functional_issues=[], + has_blocking_issues=False, + ) + + monkeypatch.setattr( + te._boundary_support, + "resolve_transcription_backend_adapter", + lambda _backend_id: _Adapter(), + ) + + te.mark_compatibility_issues_as_emitted( + backend_id="stable_whisper", 
+ issue_kind="operational", + issue_codes=("pytest-operational",), + ) + + with caplog.at_level(logging.INFO): + _ = te._boundary_support.check_adapter_compatibility( + active_profile=te.TranscriptionProfile( + backend_id="stable_whisper", + model_name="large-v2", + ), + settings=cast(te.AppConfig, SimpleNamespace()), + runtime_request=cast(Any, SimpleNamespace()), + logger=te.logger, + error_factory=te.TranscriptionError, + ) + + assert "pytest-operational" not in caplog.text diff --git a/tests/suites/smoke/test_cli_runtime_workflows.py b/tests/suites/smoke/test_cli_runtime_workflows.py new file mode 100644 index 0000000..2912a4f --- /dev/null +++ b/tests/suites/smoke/test_cli_runtime_workflows.py @@ -0,0 +1,110 @@ +"""Cheap smoke coverage for user-visible CLI training and inference workflows.""" + +from __future__ import annotations + +import pytest + +import ser.__main__ as cli +import ser.config as config_module +import ser.models.emotion_model as emotion_model +import ser.runtime.fast_inference as fast_inference +import ser.utils.timeline_utils as timeline_utils +from ser.runtime.schema import OUTPUT_SCHEMA_VERSION, InferenceResult, SegmentPrediction + +pytestmark = [pytest.mark.smoke, pytest.mark.usefixtures("reset_ambient_settings")] + + +def _patch_cli_runtime_prerequisites(monkeypatch: pytest.MonkeyPatch) -> None: + """Keeps CLI runtime smokes deterministic without bypassing pipeline building.""" + monkeypatch.setattr(cli, "load_dotenv", lambda: None) + monkeypatch.setattr(cli, "configure_logging", lambda _level=None: None) + monkeypatch.setattr( + cli, + "run_restricted_backend_cli_gate", + lambda **_kwargs: ((), None), + ) + monkeypatch.setattr( + cli, + "run_startup_preflight_cli_gate", + lambda **_kwargs: ((), None), + ) + + +def test_cli_fast_inference_smoke_uses_real_pipeline_builder( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Fast CLI inference should succeed without patching away the pipeline builder.""" + 
_patch_cli_runtime_prerequisites(monkeypatch) + monkeypatch.setattr( + cli.sys, + "argv", + ["ser", "--file", "sample.wav", "--profile", "fast", "--no-transcript"], + ) + calls = {"legacy": 0, "detailed": 0, "timeline": 0} + + def _fake_run_fast_inference( + _request: object, + _settings: object, + **_kwargs: object, + ) -> InferenceResult: + calls["detailed"] += 1 + return InferenceResult( + schema_version=OUTPUT_SCHEMA_VERSION, + segments=[ + SegmentPrediction( + emotion="calm", + start_seconds=0.0, + end_seconds=1.0, + confidence=1.0, + ) + ], + frames=[], + ) + + def _fake_build_timeline(_transcript: object, _emotions: object) -> list[object]: + calls["timeline"] += 1 + return [] + + monkeypatch.setattr( + fast_inference, + "run_fast_inference", + _fake_run_fast_inference, + ) + monkeypatch.setattr( + "ser.transcript.extract_transcript", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + AssertionError("Transcription should not run for --no-transcript.") + ), + ) + monkeypatch.setattr( + timeline_utils, + "build_timeline", + _fake_build_timeline, + ) + monkeypatch.setattr(timeline_utils, "print_timeline", lambda _timeline: None) + + cli.main() + + assert calls["timeline"] == 1 + assert calls["legacy"] == 0 + assert calls["detailed"] == 1 + + +def test_cli_fast_training_smoke_uses_real_pipeline_builder( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Fast CLI training should succeed without patching away the pipeline builder.""" + _patch_cli_runtime_prerequisites(monkeypatch) + monkeypatch.setattr(cli.sys, "argv", ["ser", "--train", "--profile", "fast"]) + captured: dict[str, object] = {} + + def _fake_train_model(*, settings: config_module.AppConfig | None = None) -> None: + captured["settings"] = settings + + monkeypatch.setattr(emotion_model, "train_model", _fake_train_model) + + with pytest.raises(SystemExit) as exc_info: + cli.main() + + assert exc_info.value.code == 0 + assert captured["settings"] is not None diff --git 
a/tests/suites/unit/__init__.py b/tests/suites/unit/__init__.py new file mode 100644 index 0000000..4ca66fe --- /dev/null +++ b/tests/suites/unit/__init__.py @@ -0,0 +1 @@ +"""Unit test suite package.""" diff --git a/tests/test_accurate_execution.py b/tests/suites/unit/test_accurate_execution.py similarity index 95% rename from tests/test_accurate_execution.py rename to tests/suites/unit/test_accurate_execution.py index beb7406..086eb64 100644 --- a/tests/test_accurate_execution.py +++ b/tests/suites/unit/test_accurate_execution.py @@ -3,7 +3,6 @@ from __future__ import annotations import logging -from collections.abc import Generator from dataclasses import dataclass import numpy as np @@ -17,6 +16,8 @@ ) from ser.runtime.accurate_prediction import predict_labels as _predict_labels_impl +pytestmark = [pytest.mark.unit, pytest.mark.usefixtures("reset_ambient_settings")] + class _PredictModel: """Deterministic classifier stub for accurate execution tests.""" @@ -40,15 +41,6 @@ class _LoadedModelStub: expected_feature_size: int | None -@pytest.fixture(autouse=True) -def _reset_settings() -> Generator[None]: - """Keep runtime settings stable across tests.""" - - config.reload_settings() - yield - config.reload_settings() - - def _encoded_sequence() -> EncodedSequence: """Build a deterministic encoded timeline.""" diff --git a/tests/test_accurate_execution_flow.py b/tests/suites/unit/test_accurate_execution_flow.py similarity index 99% rename from tests/test_accurate_execution_flow.py rename to tests/suites/unit/test_accurate_execution_flow.py index 43bf3b7..19f2ccf 100644 --- a/tests/test_accurate_execution_flow.py +++ b/tests/suites/unit/test_accurate_execution_flow.py @@ -7,6 +7,7 @@ from dataclasses import dataclass import numpy as np +import pytest from numpy.typing import NDArray import ser.config as config @@ -21,6 +22,8 @@ ) from ser.runtime.schema import OUTPUT_SCHEMA_VERSION, InferenceResult +pytestmark = pytest.mark.unit + class _BackendStub(FeatureBackend): 
"""Minimal feature-backend stub for orchestration tests.""" diff --git a/tests/test_accurate_operation_setup.py b/tests/suites/unit/test_accurate_operation_setup.py similarity index 99% rename from tests/test_accurate_operation_setup.py rename to tests/suites/unit/test_accurate_operation_setup.py index 5215b84..cc0ef9d 100644 --- a/tests/test_accurate_operation_setup.py +++ b/tests/suites/unit/test_accurate_operation_setup.py @@ -13,6 +13,8 @@ from ser.runtime import accurate_operation_setup from ser.runtime.accurate_worker_operation import PreparedAccurateOperation +pytestmark = pytest.mark.unit + @dataclass(frozen=True, slots=True) class _RequestStub: diff --git a/tests/test_accurate_process_timeout.py b/tests/suites/unit/test_accurate_process_timeout.py similarity index 99% rename from tests/test_accurate_process_timeout.py rename to tests/suites/unit/test_accurate_process_timeout.py index 70d913c..42e66ec 100644 --- a/tests/test_accurate_process_timeout.py +++ b/tests/suites/unit/test_accurate_process_timeout.py @@ -9,6 +9,8 @@ from ser.runtime import accurate_process_timeout +pytestmark = pytest.mark.unit + @dataclass(frozen=True) class _Payload: diff --git a/tests/test_accurate_retry_operation.py b/tests/suites/unit/test_accurate_retry_operation.py similarity index 99% rename from tests/test_accurate_retry_operation.py rename to tests/suites/unit/test_accurate_retry_operation.py index 29422cd..e583a72 100644 --- a/tests/test_accurate_retry_operation.py +++ b/tests/suites/unit/test_accurate_retry_operation.py @@ -8,6 +8,8 @@ from ser.runtime.accurate_retry_operation import run_accurate_retry_operation +pytestmark = pytest.mark.unit + def test_run_accurate_retry_operation_requires_process_payload() -> None: """Process-isolated operation should fail fast when payload is absent.""" diff --git a/tests/test_accurate_runtime_support.py b/tests/suites/unit/test_accurate_runtime_support.py similarity index 99% rename from tests/test_accurate_runtime_support.py rename 
to tests/suites/unit/test_accurate_runtime_support.py index 4532bf3..8d12ec8 100644 --- a/tests/test_accurate_runtime_support.py +++ b/tests/suites/unit/test_accurate_runtime_support.py @@ -17,6 +17,8 @@ prepare_accurate_backend_runtime, ) +pytestmark = pytest.mark.unit + class _BackendStub: """Structural feature-backend stub for support helper tests.""" diff --git a/tests/test_accurate_worker_lifecycle.py b/tests/suites/unit/test_accurate_worker_lifecycle.py similarity index 99% rename from tests/test_accurate_worker_lifecycle.py rename to tests/suites/unit/test_accurate_worker_lifecycle.py index 8f940e9..e272081 100644 --- a/tests/test_accurate_worker_lifecycle.py +++ b/tests/suites/unit/test_accurate_worker_lifecycle.py @@ -13,6 +13,8 @@ from ser.runtime import accurate_worker_lifecycle +pytestmark = pytest.mark.unit + @dataclass(frozen=True, slots=True) class _PayloadStub: diff --git a/tests/test_accurate_worker_operation.py b/tests/suites/unit/test_accurate_worker_operation.py similarity index 99% rename from tests/test_accurate_worker_operation.py rename to tests/suites/unit/test_accurate_worker_operation.py index 27e0b54..9848f79 100644 --- a/tests/test_accurate_worker_operation.py +++ b/tests/suites/unit/test_accurate_worker_operation.py @@ -24,6 +24,8 @@ ) from ser.runtime.contracts import InferenceRequest +pytestmark = pytest.mark.unit + @dataclass(frozen=True) class _Request: diff --git a/tests/test_medium_execution.py b/tests/suites/unit/test_medium_execution.py similarity index 95% rename from tests/test_medium_execution.py rename to tests/suites/unit/test_medium_execution.py index e31b2c5..ffe9eb6 100644 --- a/tests/test_medium_execution.py +++ b/tests/suites/unit/test_medium_execution.py @@ -3,7 +3,6 @@ from __future__ import annotations import logging -from collections.abc import Generator from dataclasses import dataclass import numpy as np @@ -17,6 +16,8 @@ ) from ser.runtime.medium_prediction import predict_labels as _predict_labels_impl 
+pytestmark = [pytest.mark.unit, pytest.mark.usefixtures("reset_ambient_settings")] + class _PredictModel: """Deterministic classifier stub for medium execution tests.""" @@ -40,15 +41,6 @@ class _LoadedModelStub: expected_feature_size: int | None -@pytest.fixture(autouse=True) -def _reset_settings() -> Generator[None]: - """Keep runtime settings stable across tests.""" - - config.reload_settings() - yield - config.reload_settings() - - def _encoded_sequence() -> EncodedSequence: """Build a deterministic encoded timeline.""" diff --git a/tests/test_medium_execution_context.py b/tests/suites/unit/test_medium_execution_context.py similarity index 99% rename from tests/test_medium_execution_context.py rename to tests/suites/unit/test_medium_execution_context.py index d26da97..41f97ff 100644 --- a/tests/test_medium_execution_context.py +++ b/tests/suites/unit/test_medium_execution_context.py @@ -7,12 +7,16 @@ from types import SimpleNamespace from typing import cast +import pytest + from ser.config import AppConfig from ser.repr.runtime_policy import FeatureRuntimePolicy from ser.runtime.contracts import InferenceRequest from ser.runtime.medium_execution_context import prepare_execution_context from ser.runtime.medium_worker_operation import MediumRetryOperationState +pytestmark = pytest.mark.unit + @dataclass(frozen=True, slots=True) class _PayloadStub: diff --git a/tests/test_medium_execution_flow.py b/tests/suites/unit/test_medium_execution_flow.py similarity index 99% rename from tests/test_medium_execution_flow.py rename to tests/suites/unit/test_medium_execution_flow.py index bad84f4..8bdba43 100644 --- a/tests/test_medium_execution_flow.py +++ b/tests/suites/unit/test_medium_execution_flow.py @@ -6,6 +6,8 @@ from dataclasses import dataclass from typing import cast +import pytest + from ser import config from ser.config import AppConfig from ser.repr.runtime_policy import FeatureRuntimePolicy @@ -14,6 +16,8 @@ from ser.runtime.medium_execution_flow import 
execute_medium_inference_with_retry from ser.runtime.medium_worker_operation import MediumRetryOperationState +pytestmark = pytest.mark.unit + @dataclass(frozen=True, slots=True) class _PayloadStub: diff --git a/tests/test_medium_process_operation.py b/tests/suites/unit/test_medium_process_operation.py similarity index 99% rename from tests/test_medium_process_operation.py rename to tests/suites/unit/test_medium_process_operation.py index 55bddcb..9316de1 100644 --- a/tests/test_medium_process_operation.py +++ b/tests/suites/unit/test_medium_process_operation.py @@ -15,6 +15,8 @@ ) from ser.runtime.medium_worker_operation import PreparedMediumOperation +pytestmark = pytest.mark.unit + @dataclass(frozen=True, slots=True) class _LoadedModelStub: diff --git a/tests/test_medium_process_timeout.py b/tests/suites/unit/test_medium_process_timeout.py similarity index 99% rename from tests/test_medium_process_timeout.py rename to tests/suites/unit/test_medium_process_timeout.py index 1fd7d49..56d7e47 100644 --- a/tests/test_medium_process_timeout.py +++ b/tests/suites/unit/test_medium_process_timeout.py @@ -9,6 +9,8 @@ from ser.runtime import medium_process_timeout +pytestmark = pytest.mark.unit + @dataclass(frozen=True) class _Payload: diff --git a/tests/test_medium_retry_operation.py b/tests/suites/unit/test_medium_retry_operation.py similarity index 99% rename from tests/test_medium_retry_operation.py rename to tests/suites/unit/test_medium_retry_operation.py index f755d00..131275d 100644 --- a/tests/test_medium_retry_operation.py +++ b/tests/suites/unit/test_medium_retry_operation.py @@ -8,6 +8,8 @@ from ser.runtime.medium_retry_operation import run_medium_inference_with_retry_policy +pytestmark = pytest.mark.unit + def test_run_medium_inference_with_retry_policy_delegates() -> None: """Helper should delegate to the shared retry-policy runner.""" diff --git a/tests/test_medium_runtime_support.py b/tests/suites/unit/test_medium_runtime_support.py similarity index 99% 
rename from tests/test_medium_runtime_support.py rename to tests/suites/unit/test_medium_runtime_support.py index fd81a35..adb33e9 100644 --- a/tests/test_medium_runtime_support.py +++ b/tests/suites/unit/test_medium_runtime_support.py @@ -25,6 +25,8 @@ validate_medium_loaded_model_runtime_contract, ) +pytestmark = pytest.mark.unit + @dataclass(frozen=True, slots=True) class _LoadedModelStub: diff --git a/tests/test_medium_worker_lifecycle.py b/tests/suites/unit/test_medium_worker_lifecycle.py similarity index 99% rename from tests/test_medium_worker_lifecycle.py rename to tests/suites/unit/test_medium_worker_lifecycle.py index e196150..c54fc30 100644 --- a/tests/test_medium_worker_lifecycle.py +++ b/tests/suites/unit/test_medium_worker_lifecycle.py @@ -12,6 +12,8 @@ from ser.runtime import medium_worker_lifecycle +pytestmark = pytest.mark.unit + def test_run_with_process_timeout_delegates_to_timeout_helper( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/test_medium_worker_operation.py b/tests/suites/unit/test_medium_worker_operation.py similarity index 99% rename from tests/test_medium_worker_operation.py rename to tests/suites/unit/test_medium_worker_operation.py index 2cd0880..df5481a 100644 --- a/tests/test_medium_worker_operation.py +++ b/tests/suites/unit/test_medium_worker_operation.py @@ -23,6 +23,8 @@ run_process_operation, ) +pytestmark = pytest.mark.unit + @dataclass(frozen=True, slots=True) class _LoadedModelStub: diff --git a/tests/test_runtime_worker_error_timeout.py b/tests/suites/unit/test_runtime_worker_error_timeout.py similarity index 99% rename from tests/test_runtime_worker_error_timeout.py rename to tests/suites/unit/test_runtime_worker_error_timeout.py index 7ed28dd..dff1e7b 100644 --- a/tests/test_runtime_worker_error_timeout.py +++ b/tests/suites/unit/test_runtime_worker_error_timeout.py @@ -14,6 +14,8 @@ run_with_timeout, ) +pytestmark = pytest.mark.unit + class _WorkerTimeoutError(TimeoutError): """Synthetic timeout error used for 
helper behavior tests.""" diff --git a/tests/test_accurate_inference.py b/tests/test_accurate_inference.py deleted file mode 100644 index 85a69c2..0000000 --- a/tests/test_accurate_inference.py +++ /dev/null @@ -1,893 +0,0 @@ -"""Tests for accurate runtime timeout/retry and fallback behavior.""" - -from __future__ import annotations - -import threading -import time -from collections.abc import Generator -from pathlib import Path -from typing import cast - -import numpy as np -import pytest -from sklearn.neural_network import MLPClassifier - -import ser.config as config -import ser.runtime.accurate_inference as accurate_inference -from ser.models import emotion_model -from ser.runtime.accurate_inference import ( - AccurateInferenceExecutionError, - AccurateInferenceTimeoutError, - AccurateModelUnavailableError, - AccurateRuntimeDependencyError, - AccurateTransientBackendError, - run_accurate_inference, -) -from ser.runtime.contracts import InferenceRequest -from ser.runtime.schema import OUTPUT_SCHEMA_VERSION, InferenceResult - - -class _PredictModel(MLPClassifier): - """Deterministic model stub for accurate runtime tests.""" - - def __init__(self) -> None: - super().__init__(hidden_layer_sizes=(1,), max_iter=1, random_state=0) - self.classes_ = np.asarray(["happy", "sad"], dtype=object) - - def predict(self, X: np.ndarray) -> np.ndarray: # noqa: N803 - return np.asarray(["happy"] * int(X.shape[0]), dtype=object) - - def predict_proba(self, X: np.ndarray) -> np.ndarray: # noqa: N803 - return np.asarray([[0.9, 0.1]] * int(X.shape[0]), dtype=np.float64) - - -@pytest.fixture(autouse=True) -def _reset_settings() -> Generator[None]: - """Keeps global settings stable across tests.""" - config.reload_settings() - yield - config.reload_settings() - - -def _accurate_metadata( - feature_vector_size: int = 4, - *, - backend_model_id: str | None = emotion_model.ACCURATE_MODEL_ID, -) -> dict[str, object]: - """Builds minimal accurate-profile artifact metadata for runtime 
tests.""" - metadata: dict[str, object] = { - "artifact_version": emotion_model.MODEL_ARTIFACT_VERSION, - "artifact_schema_version": "v2", - "created_at_utc": "2026-02-21T00:00:00+00:00", - "feature_vector_size": feature_vector_size, - "training_samples": 8, - "labels": ["happy", "sad"], - "backend_id": "hf_whisper", - "profile": "accurate", - "feature_dim": feature_vector_size, - "frame_size_seconds": 1.0, - "frame_stride_seconds": 1.0, - "pooling_strategy": "mean_std", - } - if backend_model_id is not None: - metadata["backend_model_id"] = backend_model_id - return metadata - - -def _patch_runtime_prerequisites( - monkeypatch: pytest.MonkeyPatch, - *, - backend_model_id: str, -) -> None: - """Patches model/audio/backend prerequisites for retry/timeout tests.""" - monkeypatch.setattr( - "ser.runtime.accurate_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=_PredictModel(), - expected_feature_size=4, - artifact_metadata=_accurate_metadata(backend_model_id=backend_model_id), - ), - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference.read_audio_file", - lambda _file_path, *, audio_read_config=None: ( - np.linspace(0.0, 1.0, 16, dtype=np.float32), - 4, - ), - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference.WhisperBackend", - lambda **_kwargs: object(), - ) - - -def test_accurate_timeout_retries_up_to_configured_budget( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Timeouts should retry up to `max_retries + 1` attempts and then fail.""" - monkeypatch.setenv("SER_ACCURATE_MAX_TIMEOUT_RETRIES", "2") - monkeypatch.setenv("SER_ACCURATE_RETRY_BACKOFF_SECONDS", "0.5") - settings = config.reload_settings() - _patch_runtime_prerequisites( - monkeypatch, - backend_model_id=settings.models.accurate_model_id, - ) - - calls = {"attempts": 0, "sleeps": 0} - - def fake_timeout_runner(*_args: object, **_kwargs: object) -> object: - calls["attempts"] += 1 - raise AccurateInferenceTimeoutError("timeout") - - monkeypatch.setattr( - 
"ser.runtime.accurate_inference._run_with_timeout_impl", fake_timeout_runner - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference._retry_delay_seconds", - lambda **_kwargs: 0.1, - ) - monkeypatch.setattr( - "ser.runtime.policy.time.sleep", - lambda _delay: calls.__setitem__("sleeps", calls["sleeps"] + 1), - ) - - with pytest.raises(AccurateInferenceTimeoutError, match="timeout"): - run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert calls["attempts"] == 3 - assert calls["sleeps"] == 2 - - -def test_accurate_transient_backend_failure_respects_retry_upper_bound( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Transient backend failures should stop after bounded retry attempts.""" - monkeypatch.setenv("SER_ACCURATE_MAX_TRANSIENT_RETRIES", "2") - monkeypatch.setenv("SER_ACCURATE_RETRY_BACKOFF_SECONDS", "0") - settings = config.reload_settings() - _patch_runtime_prerequisites( - monkeypatch, - backend_model_id=settings.models.accurate_model_id, - ) - - calls = {"attempts": 0} - - def fake_attempt( - *, - loaded_model: emotion_model.LoadedModel, - backend: object, - audio: np.ndarray, - sample_rate: int, - runtime_config: object, - ) -> object: - del loaded_model, backend, audio, sample_rate, runtime_config - calls["attempts"] += 1 - raise AccurateTransientBackendError("transient backend failure") - - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_with_timeout_impl", - lambda **kwargs: kwargs["operation"](), - ) - monkeypatch.setattr("ser.runtime.accurate_inference._run_accurate_inference_once", fake_attempt) - - with pytest.raises(AccurateInferenceExecutionError, match="retry budget"): - run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert calls["attempts"] == 3 - - -def test_accurate_non_retryable_value_error_exits_without_retries( - monkeypatch: pytest.MonkeyPatch, -) -> None: - 
"""Feature-contract value errors should not trigger retries.""" - monkeypatch.setenv("SER_ACCURATE_MAX_TIMEOUT_RETRIES", "3") - monkeypatch.setenv("SER_ACCURATE_MAX_TRANSIENT_RETRIES", "3") - settings = config.reload_settings() - _patch_runtime_prerequisites( - monkeypatch, - backend_model_id=settings.models.accurate_model_id, - ) - - calls = {"attempts": 0, "sleeps": 0} - - def fake_attempt( - *, - loaded_model: emotion_model.LoadedModel, - backend: object, - audio: np.ndarray, - sample_rate: int, - runtime_config: object, - ) -> object: - del loaded_model, backend, audio, sample_rate, runtime_config - calls["attempts"] += 1 - raise ValueError("Feature vector size mismatch") - - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_with_timeout_impl", - lambda **kwargs: kwargs["operation"](), - ) - monkeypatch.setattr("ser.runtime.accurate_inference._run_accurate_inference_once", fake_attempt) - monkeypatch.setattr( - "ser.runtime.policy.time.sleep", - lambda _delay: calls.__setitem__("sleeps", calls["sleeps"] + 1), - ) - - with pytest.raises(ValueError, match="Feature vector size mismatch"): - run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert calls["attempts"] == 1 - assert calls["sleeps"] == 0 - - -def test_accurate_dependency_error_is_not_retried( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Dependency failures should fail immediately without retry loop.""" - monkeypatch.setenv("SER_ACCURATE_MAX_TIMEOUT_RETRIES", "3") - monkeypatch.setenv("SER_ACCURATE_MAX_TRANSIENT_RETRIES", "3") - settings = config.reload_settings() - _patch_runtime_prerequisites( - monkeypatch, - backend_model_id=settings.models.accurate_model_id, - ) - - calls = {"attempts": 0} - - def fake_attempt( - *, - loaded_model: emotion_model.LoadedModel, - backend: object, - audio: np.ndarray, - sample_rate: int, - runtime_config: object, - ) -> object: - del loaded_model, backend, audio, sample_rate, 
runtime_config - calls["attempts"] += 1 - raise AccurateRuntimeDependencyError("transformers missing") - - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_with_timeout_impl", - lambda **kwargs: kwargs["operation"](), - ) - monkeypatch.setattr("ser.runtime.accurate_inference._run_accurate_inference_once", fake_attempt) - - with pytest.raises(AccurateRuntimeDependencyError, match="transformers"): - run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert calls["attempts"] == 1 - - -def test_accurate_inference_rejects_non_accurate_artifact_metadata( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Accurate runtime should reject artifacts with non-accurate metadata.""" - monkeypatch.setattr( - "ser.runtime.accurate_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=_PredictModel(), - expected_feature_size=4, - artifact_metadata={ - **_accurate_metadata(), - "backend_id": "hf_xlsr", - "profile": "medium", - }, - ), - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference.read_audio_file", - lambda _file_path, *, audio_read_config=None: ( - np.linspace(0.0, 1.0, 16, dtype=np.float32), - 4, - ), - ) - - with pytest.raises(AccurateModelUnavailableError, match="hf_whisper"): - run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - config.reload_settings(), - ) - - -def test_accurate_inference_returns_expected_schema( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Accurate runtime should return deterministic inference schema payload.""" - settings = config.reload_settings() - _patch_runtime_prerequisites( - monkeypatch, - backend_model_id=settings.models.accurate_model_id, - ) - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_with_timeout_impl", - lambda **kwargs: kwargs["operation"](), - ) - 
monkeypatch.setattr( - "ser.runtime.accurate_inference._run_accurate_inference_once", - lambda **_kwargs: expected, - ) - - result = run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert result == expected - - -def test_accurate_backend_setup_runs_before_timeout_wrapper( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Backend setup should execute before timeout-wrapped compute operation.""" - settings = config.reload_settings() - backend_model_id = settings.models.accurate_model_id - setup_calls = {"count": 0} - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - - class _BackendStub: - def prepare_runtime(self) -> None: - setup_calls["count"] += 1 - - backend = _BackendStub() - monkeypatch.setattr( - "ser.runtime.accurate_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=_PredictModel(), - expected_feature_size=4, - artifact_metadata=_accurate_metadata(backend_model_id=backend_model_id), - ), - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference.read_audio_file", - lambda _file_path, *, audio_read_config=None: ( - np.linspace(0.0, 1.0, 16, dtype=np.float32), - 4, - ), - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference._build_backend_for_profile", - lambda **_kwargs: backend, - ) - - def fake_timeout_runner(*_args: object, **_kwargs: object) -> InferenceResult: - assert setup_calls["count"] == 1 - return expected - - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_with_timeout_impl", - fake_timeout_runner, - ) - - result = run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert result == expected - assert setup_calls["count"] == 1 - - -def test_accurate_inference_uses_configured_accurate_model_id( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Backend initialization should honor the configured accurate model 
id.""" - monkeypatch.setenv("SER_ACCURATE_MODEL_ID", "unit-test/whisper-tiny") - settings = config.reload_settings() - monkeypatch.setattr( - "ser.runtime.accurate_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=_PredictModel(), - expected_feature_size=4, - artifact_metadata=_accurate_metadata( - backend_model_id=settings.models.accurate_model_id - ), - ), - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference.read_audio_file", - lambda _file_path, *, audio_read_config=None: ( - np.linspace(0.0, 1.0, 16, dtype=np.float32), - 4, - ), - ) - captured: dict[str, object] = {} - - class _BackendStub: - def __init__( - self, - *, - model_id: str, - cache_dir: Path, - device: str = "auto", - dtype: str = "auto", - ) -> None: - captured["model_id"] = model_id - captured["cache_dir"] = cache_dir - captured["device"] = device - captured["dtype"] = dtype - - monkeypatch.setattr("ser.runtime.accurate_inference.WhisperBackend", _BackendStub) - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_with_timeout_impl", - lambda **kwargs: kwargs["operation"](), - ) - - def _fake_run_once(**kwargs: object) -> InferenceResult: - captured["backend"] = kwargs["backend"] - return InferenceResult( - schema_version=OUTPUT_SCHEMA_VERSION, - segments=[], - frames=[], - ) - - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_accurate_inference_once", - _fake_run_once, - ) - - run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert captured["model_id"] == "unit-test/whisper-tiny" - assert captured["cache_dir"] == settings.models.huggingface_cache_root - assert captured["device"] == settings.torch_runtime.device - assert captured["dtype"] == settings.torch_runtime.dtype - assert isinstance(captured["backend"], _BackendStub) - - -def test_accurate_inference_rejects_mismatched_backend_model_id( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Runtime should reject 
artifacts with backend_model_id mismatch.""" - monkeypatch.setenv("SER_ACCURATE_MODEL_ID", "unit-test/whisper-large") - settings = config.reload_settings() - monkeypatch.setattr( - "ser.runtime.accurate_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=_PredictModel(), - expected_feature_size=768, - artifact_metadata=_accurate_metadata( - feature_vector_size=768, - backend_model_id="unit-test/whisper-tiny", - ), - ), - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference.read_audio_file", - lambda _file_path, *, audio_read_config=None: ( - np.linspace(0.0, 1.0, 16, dtype=np.float32), - 4, - ), - ) - with pytest.raises(AccurateModelUnavailableError, match="backend_model_id"): - run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - -def test_accurate_inference_requires_backend_model_id_metadata( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Artifacts missing backend_model_id should be rejected by strict checks.""" - monkeypatch.setenv("SER_ACCURATE_MODEL_ID", "openai/whisper-tiny") - settings = config.reload_settings() - monkeypatch.setattr( - "ser.runtime.accurate_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=_PredictModel(), - expected_feature_size=768, - artifact_metadata=_accurate_metadata( - feature_vector_size=768, - backend_model_id=None, - ), - ), - ) - with pytest.raises(AccurateModelUnavailableError, match="backend_model_id"): - run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - -def test_accurate_inference_warns_on_torch_runtime_metadata_mismatch( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Runtime should warn when artifact and runtime torch selectors differ.""" - monkeypatch.setenv("SER_TORCH_DEVICE", "cpu") - monkeypatch.setenv("SER_TORCH_DTYPE", "float32") - settings = config.reload_settings() - metadata = 
_accurate_metadata(backend_model_id=settings.models.accurate_model_id) - metadata["torch_device"] = "cuda:0" - metadata["torch_dtype"] = "float16" - warnings: list[str] = [] - monkeypatch.setattr( - "ser.runtime.accurate_inference.logger.warning", - lambda msg, *args: warnings.append(msg % args), - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=_PredictModel(), - expected_feature_size=4, - artifact_metadata=metadata, - ), - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference.read_audio_file", - lambda _file_path, *, audio_read_config=None: ( - np.linspace(0.0, 1.0, 16, dtype=np.float32), - 4, - ), - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference.WhisperBackend", - lambda **_kwargs: object(), - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_with_timeout_impl", - lambda **kwargs: kwargs["operation"](), - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_accurate_inference_once", - lambda **_kwargs: InferenceResult( - schema_version=OUTPUT_SCHEMA_VERSION, - segments=[], - frames=[], - ), - ) - - run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert warnings - assert "torch runtime selectors differ" in warnings[0] - - -def test_accurate_single_flight_serializes_same_profile_model_calls( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Concurrent accurate calls should execute one-at-a-time for one profile/model tuple.""" - settings = config.reload_settings() - _patch_runtime_prerequisites( - monkeypatch, - backend_model_id=settings.models.accurate_model_id, - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_with_timeout_impl", - lambda **kwargs: kwargs["operation"](), - ) - - counters = {"active": 0, "max_active": 0} - counter_lock = threading.Lock() - - def fake_attempt( - *, - loaded_model: emotion_model.LoadedModel, - backend: object, - audio: 
np.ndarray, - sample_rate: int, - runtime_config: object, - ) -> InferenceResult: - del loaded_model, backend, audio, sample_rate, runtime_config - with counter_lock: - counters["active"] += 1 - counters["max_active"] = max(counters["max_active"], counters["active"]) - time.sleep(0.05) - with counter_lock: - counters["active"] -= 1 - return InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_accurate_inference_once", - fake_attempt, - ) - - errors: list[Exception] = [] - request = InferenceRequest(file_path="sample.wav", language="en", save_transcript=False) - - def invoke() -> None: - try: - run_accurate_inference(request, settings) - except Exception as err: # pragma: no cover - defensive capture for assertion clarity - errors.append(err) - - first = threading.Thread(target=invoke) - second = threading.Thread(target=invoke) - first.start() - second.start() - first.join() - second.join() - - assert errors == [] - assert counters["max_active"] == 1 - assert accurate_inference._SINGLE_FLIGHT_REGISTRY.active_key_count() == 0 - - -def test_accurate_profile_pipeline_uses_process_timeout_runner( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Profile-pipeline accurate calls should route attempts through process timeout path.""" - monkeypatch.setenv("SER_ENABLE_PROFILE_PIPELINE", "true") - monkeypatch.setenv("SER_ACCURATE_MAX_TIMEOUT_RETRIES", "1") - settings = config.reload_settings() - - calls = {"process": 0, "sleep": 0} - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - - def fail_if_called(**_kwargs: object) -> object: - raise AssertionError("In-process load_model path should not run in process mode.") - - def fake_process_runner(*_args: object, **_kwargs: object) -> InferenceResult: - calls["process"] += 1 - if calls["process"] == 1: - raise AccurateInferenceTimeoutError("timeout") - return expected - - 
monkeypatch.setattr("ser.runtime.accurate_inference.load_model", fail_if_called) - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_with_process_timeout", - fake_process_runner, - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference._retry_delay_seconds", - lambda **_kwargs: 0.1, - ) - monkeypatch.setattr( - "ser.runtime.policy.time.sleep", - lambda _delay: calls.__setitem__("sleep", calls["sleep"] + 1), - ) - - result = run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert result == expected - assert calls["process"] == 2 - assert calls["sleep"] == 1 - - -def test_accurate_profile_pipeline_retries_on_cpu_after_mps_oom( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Process-mode accurate retries should demote to CPU after one MPS OOM failure.""" - monkeypatch.setenv("SER_ENABLE_PROFILE_PIPELINE", "true") - monkeypatch.setenv("SER_ACCURATE_MAX_TRANSIENT_RETRIES", "1") - monkeypatch.setenv("SER_ACCURATE_RETRY_BACKOFF_SECONDS", "0") - settings = config.reload_settings() - - devices: list[str] = [] - infos: list[str] = [] - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - oom_message = ( - "MPS backend out of memory (MPS allocated: 3.13 GB, other allocations: " - "220.17 MB, max allowed: 3.40 GB). Tried to allocate 85.83 MB on private pool." 
- ) - - def fail_if_called(**_kwargs: object) -> object: - raise AssertionError("In-process load_model path should not run in process mode.") - - def fake_process_runner( - payload: accurate_inference.AccurateProcessPayload, - *, - timeout_seconds: float, - ) -> InferenceResult: - del timeout_seconds - devices.append(payload.settings.torch_runtime.device) - if len(devices) == 1: - raise AccurateTransientBackendError(oom_message) - return expected - - monkeypatch.setattr("ser.runtime.accurate_inference.load_model", fail_if_called) - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_with_process_timeout", - fake_process_runner, - ) - monkeypatch.setattr( - "ser.runtime.accurate_inference.logger.info", - lambda msg, *args: infos.append(msg % args), - ) - - result = run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert result == expected - assert devices == [settings.torch_runtime.device, "cpu"] - assert any( - "retry on CPU after MPS OOM (required=85.8MB available=" in message for message in infos - ) - - -def test_accurate_profile_pipeline_allows_timeout_disable( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Profile-pipeline accurate calls should pass zero timeout to disable budgets.""" - monkeypatch.setenv("SER_ENABLE_PROFILE_PIPELINE", "true") - monkeypatch.setenv("SER_ACCURATE_TIMEOUT_SECONDS", "0") - settings = config.reload_settings() - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - calls = {"process": 0} - - def fail_if_called(**_kwargs: object) -> object: - raise AssertionError("In-process load_model path should not run in process mode.") - - def fake_process_runner(*_args: object, **_kwargs: object) -> InferenceResult: - timeout_seconds = _kwargs.get("timeout_seconds") - assert timeout_seconds == 0.0 - calls["process"] += 1 - return expected - - monkeypatch.setattr("ser.runtime.accurate_inference.load_model", 
fail_if_called) - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_with_process_timeout", - fake_process_runner, - ) - - result = run_accurate_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert result == expected - assert calls["process"] == 1 - - -def test_accurate_retryable_operation_delegates_to_worker_helper( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Retryable accurate wrapper should delegate attempt execution to worker helper seam.""" - settings = config.reload_settings() - expected = InferenceResult( - schema_version=OUTPUT_SCHEMA_VERSION, - segments=[], - frames=[], - ) - prepared = cast( - accurate_inference._PreparedAccurateOperation, - accurate_inference.PreparedAccurateOperation( - loaded_model=cast(accurate_inference.LoadedModel, object()), - backend=cast(accurate_inference.FeatureBackend, object()), - audio=np.ones(8, dtype=np.float32), - sample_rate=4, - runtime_config=settings.accurate_runtime, - ), - ) - retry_state = cast( - accurate_inference.AccurateRetryOperationState[ - accurate_inference.AccurateProcessPayload, - accurate_inference.FeatureBackend, - ], - accurate_inference.AccurateRetryOperationState( - process_payload=None, - active_backend=prepared.backend, - ), - ) - captured: dict[str, object] = {} - - monkeypatch.setattr( - "ser.runtime.accurate_inference._run_inference_operation_impl", - lambda **kwargs: captured.update(kwargs) or expected, - ) - - result = accurate_inference._run_accurate_retryable_operation( - enforce_timeout=True, - use_process_isolation=False, - retry_state=retry_state, - prepared_operation=prepared, - timeout_seconds=11.0, - expected_profile="accurate", - ) - - assert result == expected - assert captured["enforce_timeout"] is True - assert captured["use_process_isolation"] is False - assert captured["process_payload"] is None - assert captured["prepared_operation"] is prepared - assert captured["active_backend"] is 
prepared.backend - assert captured["timeout_seconds"] == 11.0 - assert captured["expected_profile"] == "accurate" - assert captured["inference_phase_name"] == accurate_inference.PHASE_EMOTION_INFERENCE - - -def test_accurate_process_timeout_applies_after_setup_phase( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Process timeout budget should start only after setup phase is acknowledged.""" - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - poll_calls: list[float | None] = [] - settings = config.reload_settings() - - class _ParentConnection: - def __init__(self) -> None: - self._messages: list[tuple[object, ...]] = [ - ("phase", "setup_complete"), - ("ok", expected), - ] - - def recv(self) -> tuple[object, ...]: - if not self._messages: - raise EOFError - return self._messages.pop(0) - - def poll(self, timeout: float | None = None) -> bool: - poll_calls.append(timeout) - return True - - def close(self) -> None: - return None - - class _ChildConnection: - def close(self) -> None: - return None - - class _Process: - def __init__(self) -> None: - self.exitcode = 0 - self._alive = True - - def start(self) -> None: - return None - - def join(self, timeout: float | None = None) -> None: - del timeout - self._alive = False - - def is_alive(self) -> bool: - return self._alive - - def terminate(self) -> None: - self._alive = False - - def kill(self) -> None: - self._alive = False - - class _Context: - def __init__(self) -> None: - self._parent = _ParentConnection() - self._child = _ChildConnection() - self._process = _Process() - - def Pipe( - self, duplex: bool = False - ) -> tuple[_ParentConnection, _ChildConnection]: # noqa: N802 - del duplex - return self._parent, self._child - - def Process(self, **_kwargs: object) -> _Process: # noqa: N802 - return self._process - - monkeypatch.setattr( - "ser.runtime.accurate_inference.mp.get_context", - lambda _name: _Context(), - ) - payload = accurate_inference.AccurateProcessPayload( 
- request=InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings=settings, - expected_backend_id="hf_whisper", - expected_profile="accurate", - expected_backend_model_id=settings.models.accurate_model_id, - ) - - result = accurate_inference._run_with_process_timeout(payload, timeout_seconds=7.0) - - assert result == expected - assert poll_calls == [7.0] diff --git a/tests/test_fast_inference.py b/tests/test_fast_inference.py deleted file mode 100644 index 9f1f9a7..0000000 --- a/tests/test_fast_inference.py +++ /dev/null @@ -1,331 +0,0 @@ -"""Tests for fast runtime timeout/retry and policy-wrapper behavior.""" - -from __future__ import annotations - -import threading -import time -from collections.abc import Generator -from typing import cast - -import pytest - -import ser.config as config -import ser.models.training_support as training_support -import ser.runtime.fast_inference as fast_inference -from ser.models import emotion_model -from ser.runtime.contracts import InferenceRequest -from ser.runtime.fast_inference import ( - FastInferenceTimeoutError, - run_fast_inference, -) -from ser.runtime.schema import OUTPUT_SCHEMA_VERSION, InferenceResult - - -@pytest.fixture(autouse=True) -def _reset_settings() -> Generator[None]: - """Keeps global settings stable across tests.""" - config.reload_settings() - yield - config.reload_settings() - - -def _fast_metadata() -> dict[str, object]: - """Builds minimal fast-profile artifact metadata for runtime tests.""" - return { - "artifact_version": emotion_model.MODEL_ARTIFACT_VERSION, - "artifact_schema_version": "v2", - "created_at_utc": "2026-02-24T00:00:00+00:00", - "feature_vector_size": 193, - "training_samples": 8, - "labels": ["happy", "sad"], - "backend_id": "handcrafted", - "profile": "fast", - "feature_dim": 193, - "frame_size_seconds": 3.0, - "frame_stride_seconds": 1.0, - "pooling_strategy": "mean", - } - - -def _patch_fast_prerequisites( - monkeypatch: pytest.MonkeyPatch, -) -> 
None: - """Patches model/detail prerequisites for fast runtime tests.""" - monkeypatch.setattr( - "ser.runtime.fast_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=cast(training_support.EmotionClassifier, object()), - expected_feature_size=193, - artifact_metadata=_fast_metadata(), - ), - ) - - -def test_fast_inference_returns_expected_schema( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Fast runtime should return deterministic inference schema payload.""" - settings = config.reload_settings() - _patch_fast_prerequisites(monkeypatch) - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - monkeypatch.setattr( - "ser.runtime.fast_inference.predict_emotions_detailed", - lambda _file_path, loaded_model=None: expected, - ) - - result = run_fast_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert result == expected - - -def test_fast_timeout_retries_up_to_configured_budget( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Timeouts should retry up to `max_timeout_retries + 1` attempts and then fail.""" - monkeypatch.setenv("SER_FAST_MAX_TIMEOUT_RETRIES", "2") - monkeypatch.setenv("SER_FAST_RETRY_BACKOFF_SECONDS", "0.5") - monkeypatch.setenv("SER_FAST_TIMEOUT_SECONDS", "1.0") - settings = config.reload_settings() - _patch_fast_prerequisites(monkeypatch) - - calls = {"attempts": 0, "sleeps": 0} - - def fake_timeout_runner(*_args: object, **_kwargs: object) -> object: - calls["attempts"] += 1 - raise FastInferenceTimeoutError("timeout") - - monkeypatch.setattr( - "ser.runtime.fast_inference._run_with_timeout_impl", - fake_timeout_runner, - ) - monkeypatch.setattr( - "ser.runtime.fast_inference._retry_delay_seconds", - lambda **_kwargs: 0.1, - ) - monkeypatch.setattr( - "ser.runtime.policy.time.sleep", - lambda _delay: calls.__setitem__("sleeps", calls["sleeps"] + 1), - ) - - with pytest.raises(FastInferenceTimeoutError, 
match="timeout"): - run_fast_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert calls["attempts"] == 3 - assert calls["sleeps"] == 2 - - -def test_fast_profile_pipeline_uses_process_timeout_runner( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Profile-pipeline fast calls should route attempts through process timeout path.""" - monkeypatch.setenv("SER_ENABLE_PROFILE_PIPELINE", "true") - monkeypatch.setenv("SER_FAST_PROCESS_ISOLATION", "true") - monkeypatch.setenv("SER_FAST_MAX_TIMEOUT_RETRIES", "1") - settings = config.reload_settings() - - calls = {"process": 0, "sleep": 0} - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - - def fail_if_called(**_kwargs: object) -> object: - raise AssertionError("In-process load_model path should not run in process mode.") - - def fake_process_runner(*_args: object, **_kwargs: object) -> InferenceResult: - calls["process"] += 1 - if calls["process"] == 1: - raise FastInferenceTimeoutError("timeout") - return expected - - monkeypatch.setattr("ser.runtime.fast_inference.load_model", fail_if_called) - monkeypatch.setattr( - "ser.runtime.fast_inference._run_with_process_timeout", - fake_process_runner, - ) - monkeypatch.setattr( - "ser.runtime.fast_inference._retry_delay_seconds", - lambda **_kwargs: 0.1, - ) - monkeypatch.setattr( - "ser.runtime.policy.time.sleep", - lambda _delay: calls.__setitem__("sleep", calls["sleep"] + 1), - ) - - result = run_fast_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert result == expected - assert calls["process"] == 2 - assert calls["sleep"] == 1 - - -def test_fast_profile_pipeline_allows_timeout_disable( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Profile-pipeline fast calls should pass zero timeout to disable budgets.""" - monkeypatch.setenv("SER_ENABLE_PROFILE_PIPELINE", "true") - 
monkeypatch.setenv("SER_FAST_PROCESS_ISOLATION", "true") - monkeypatch.setenv("SER_FAST_TIMEOUT_SECONDS", "0") - settings = config.reload_settings() - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - calls = {"process": 0} - - def fail_if_called(**_kwargs: object) -> object: - raise AssertionError("In-process load_model path should not run in process mode.") - - def fake_process_runner(*_args: object, **_kwargs: object) -> InferenceResult: - timeout_seconds = _kwargs.get("timeout_seconds") - assert timeout_seconds == 0.0 - calls["process"] += 1 - return expected - - monkeypatch.setattr("ser.runtime.fast_inference.load_model", fail_if_called) - monkeypatch.setattr( - "ser.runtime.fast_inference._run_with_process_timeout", - fake_process_runner, - ) - - result = run_fast_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert result == expected - assert calls["process"] == 1 - - -def test_fast_process_timeout_applies_after_setup_phase( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Process timeout budget should start only after setup phase is acknowledged.""" - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - poll_calls: list[float | None] = [] - settings = config.reload_settings() - - class _ParentConnection: - def __init__(self) -> None: - self._messages: list[tuple[object, ...]] = [ - ("phase", "setup_complete"), - ("ok", expected), - ] - - def recv(self) -> tuple[object, ...]: - if not self._messages: - raise EOFError - return self._messages.pop(0) - - def poll(self, timeout: float | None = None) -> bool: - poll_calls.append(timeout) - return True - - def close(self) -> None: - return None - - class _ChildConnection: - def close(self) -> None: - return None - - class _Process: - def __init__(self) -> None: - self.exitcode = 0 - self._alive = True - - def start(self) -> None: - return None - - def join(self, 
timeout: float | None = None) -> None: - del timeout - self._alive = False - - def is_alive(self) -> bool: - return self._alive - - def terminate(self) -> None: - self._alive = False - - def kill(self) -> None: - self._alive = False - - class _Context: - def __init__(self) -> None: - self._parent = _ParentConnection() - self._child = _ChildConnection() - self._process = _Process() - - def Pipe( - self, duplex: bool = False - ) -> tuple[_ParentConnection, _ChildConnection]: # noqa: N802 - del duplex - return self._parent, self._child - - def Process(self, **_kwargs: object) -> _Process: # noqa: N802 - return self._process - - monkeypatch.setattr( - "ser.runtime.fast_inference.mp.get_context", - lambda _name: _Context(), - ) - payload = fast_inference.FastProcessPayload( - request=InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings=settings, - ) - - result = fast_inference._run_with_process_timeout(payload, timeout_seconds=5.0) - - assert result == expected - assert poll_calls == [5.0] - - -def test_fast_single_flight_serializes_calls( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Concurrent fast calls should execute one-at-a-time for one profile tuple.""" - settings = config.reload_settings() - _patch_fast_prerequisites(monkeypatch) - monkeypatch.setattr( - "ser.runtime.fast_inference._run_with_timeout_impl", - lambda **kwargs: kwargs["operation"](), - ) - - counters = {"active": 0, "max_active": 0} - counter_lock = threading.Lock() - - def fake_attempt(**_kwargs: object) -> InferenceResult: - with counter_lock: - counters["active"] += 1 - counters["max_active"] = max(counters["max_active"], counters["active"]) - time.sleep(0.05) - with counter_lock: - counters["active"] -= 1 - return InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - - monkeypatch.setattr( - "ser.runtime.fast_inference._run_fast_inference_once", - fake_attempt, - ) - - errors: list[Exception] = [] - request = 
InferenceRequest(file_path="sample.wav", language="en", save_transcript=False) - - def invoke() -> None: - try: - run_fast_inference(request, settings) - except Exception as err: # pragma: no cover - defensive capture for assertion clarity - errors.append(err) - - first = threading.Thread(target=invoke) - second = threading.Thread(target=invoke) - first.start() - second.start() - first.join() - second.join() - - assert errors == [] - assert counters["max_active"] == 1 - assert fast_inference._SINGLE_FLIGHT_REGISTRY.active_key_count() == 0 diff --git a/tests/test_medium_inference.py b/tests/test_medium_inference.py deleted file mode 100644 index 36c23a8..0000000 --- a/tests/test_medium_inference.py +++ /dev/null @@ -1,740 +0,0 @@ -"""Tests for medium-profile encode-once inference execution.""" - -from __future__ import annotations - -import threading -import time -from collections.abc import Generator -from pathlib import Path - -import numpy as np -import pytest -from numpy.typing import NDArray -from sklearn.neural_network import MLPClassifier - -import ser.config as config -import ser.runtime.medium_inference as medium_inference -from ser.models import emotion_model -from ser.repr import EncodedSequence -from ser.repr.runtime_policy import resolve_feature_runtime_policy -from ser.runtime.contracts import InferenceRequest -from ser.runtime.medium_inference import ( - MediumModelUnavailableError, - run_medium_inference, -) -from ser.runtime.schema import OUTPUT_SCHEMA_VERSION, InferenceResult - - -class _PredictModel(MLPClassifier): - """Deterministic classifier stub for medium inference contract tests.""" - - def __init__( - self, - *, - predictions: list[str], - probabilities: list[list[float]], - classes: list[str], - ) -> None: - super().__init__(hidden_layer_sizes=(1,), max_iter=1, random_state=0) - self._predictions = np.asarray(predictions, dtype=object) - self._probabilities = np.asarray(probabilities, dtype=np.float64) - self.classes_ = np.asarray(classes, 
dtype=object) - self.last_features: NDArray[np.float64] | None = None - - def predict(self, X: np.ndarray) -> np.ndarray: - self.last_features = np.asarray(X, dtype=np.float64) - return self._predictions - - def predict_proba(self, X: np.ndarray) -> np.ndarray: - self.last_features = np.asarray(X, dtype=np.float64) - return self._probabilities - - -class _FakeBackend: - """Deterministic backend stub that tracks encode invocation count.""" - - def __init__(self) -> None: - self.encode_calls = 0 - - def encode_sequence( - self, - _audio: NDArray[np.float32], - _sample_rate: int, - ) -> EncodedSequence: - self.encode_calls += 1 - return EncodedSequence( - embeddings=np.asarray( - [ - [1.0, 2.0], - [3.0, 4.0], - [5.0, 6.0], - ], - dtype=np.float32, - ), - frame_start_seconds=np.asarray([0.0, 1.0, 2.0], dtype=np.float64), - frame_end_seconds=np.asarray([1.0, 2.0, 3.0], dtype=np.float64), - backend_id="hf_xlsr", - ) - - -@pytest.fixture(autouse=True) -def _reset_settings() -> Generator[None]: - """Keeps global settings stable across tests.""" - config.reload_settings() - yield - config.reload_settings() - - -def _medium_metadata( - feature_vector_size: int = 4, - *, - backend_model_id: str | None = emotion_model.MEDIUM_MODEL_ID, -) -> dict[str, object]: - """Builds minimal medium-profile artifact metadata for loader tests.""" - metadata: dict[str, object] = { - "artifact_version": emotion_model.MODEL_ARTIFACT_VERSION, - "artifact_schema_version": "v2", - "created_at_utc": "2026-02-19T00:00:00+00:00", - "feature_vector_size": feature_vector_size, - "training_samples": 8, - "labels": ["happy", "sad"], - "backend_id": "hf_xlsr", - "profile": "medium", - "feature_dim": feature_vector_size, - "frame_size_seconds": 1.0, - "frame_stride_seconds": 1.0, - "pooling_strategy": "mean_std", - } - if backend_model_id is not None: - metadata["backend_model_id"] = backend_model_id - return metadata - - -def test_run_medium_inference_uses_encode_once_and_returns_schema_result( - 
monkeypatch: pytest.MonkeyPatch, -) -> None: - """Medium inference should encode once and return deterministic segments.""" - backend = _FakeBackend() - model = _PredictModel( - predictions=["happy", "happy", "sad"], - probabilities=[[0.9, 0.1], [0.75, 0.25], [0.2, 0.8]], - classes=["happy", "sad"], - ) - monkeypatch.setattr( - "ser.runtime.medium_inference.read_audio_file", - lambda _file_path, *, audio_read_config=None: ( - np.linspace(0.0, 1.0, 16, dtype=np.float32), - 4, - ), - ) - monkeypatch.setattr("ser.runtime.medium_inference.XLSRBackend", lambda **_kwargs: backend) - monkeypatch.setattr( - "ser.runtime.medium_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=model, - expected_feature_size=4, - artifact_metadata=_medium_metadata(), - ), - ) - - result = run_medium_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - config.reload_settings(), - ) - - assert backend.encode_calls == 1 - assert result.schema_version == OUTPUT_SCHEMA_VERSION - assert len(result.frames) == 3 - assert result.frames[0].emotion == "happy" - assert result.frames[2].emotion == "sad" - assert [ - (segment.emotion, segment.start_seconds, segment.end_seconds) for segment in result.segments - ] == [("happy", 0.0, 2.0), ("sad", 2.0, 3.0)] - assert [segment.confidence for segment in result.segments] == pytest.approx([0.825, 0.8]) - assert result.segments[0].probabilities == pytest.approx({"happy": 0.825, "sad": 0.175}) - assert result.segments[1].probabilities == pytest.approx({"happy": 0.2, "sad": 0.8}) - assert model.last_features is not None - np.testing.assert_allclose( - model.last_features, - np.asarray( - [ - [1.0, 2.0, 0.0, 0.0], - [3.0, 4.0, 0.0, 0.0], - [5.0, 6.0, 0.0, 0.0], - ], - dtype=np.float64, - ), - ) - - -def test_run_medium_inference_fails_fast_for_non_medium_artifact( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Non-medium artifacts should be rejected before expensive encoding work.""" - 
backend = _FakeBackend() - model = _PredictModel( - predictions=["happy"], - probabilities=[[1.0, 0.0]], - classes=["happy", "sad"], - ) - monkeypatch.setattr("ser.runtime.medium_inference.XLSRBackend", lambda **_kwargs: backend) - monkeypatch.setattr( - "ser.runtime.medium_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=model, - expected_feature_size=4, - artifact_metadata={ - **_medium_metadata(), - "backend_id": "handcrafted", - "profile": "fast", - }, - ), - ) - - with pytest.raises(MediumModelUnavailableError, match="No medium-profile model"): - run_medium_inference( - InferenceRequest( - file_path="sample.wav", - language="en", - save_transcript=False, - ), - config.reload_settings(), - ) - assert backend.encode_calls == 0 - - -def test_run_medium_inference_rejects_feature_size_mismatch( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Feature width mismatches should fail with actionable error details.""" - backend = _FakeBackend() - model = _PredictModel( - predictions=["happy", "sad", "sad"], - probabilities=[[0.8, 0.2], [0.4, 0.6], [0.3, 0.7]], - classes=["happy", "sad"], - ) - monkeypatch.setattr( - "ser.runtime.medium_inference.read_audio_file", - lambda _file_path, *, audio_read_config=None: ( - np.ones(8, dtype=np.float32), - 4, - ), - ) - monkeypatch.setattr("ser.runtime.medium_inference.XLSRBackend", lambda **_kwargs: backend) - monkeypatch.setattr( - "ser.runtime.medium_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=model, - expected_feature_size=8, - artifact_metadata=_medium_metadata(feature_vector_size=8), - ), - ) - - with pytest.raises(ValueError, match="Feature vector size mismatch"): - run_medium_inference( - InferenceRequest( - file_path="sample.wav", - language="en", - save_transcript=False, - ), - config.reload_settings(), - ) - assert backend.encode_calls == 1 - - -def test_run_medium_inference_uses_configured_medium_model_id( - monkeypatch: pytest.MonkeyPatch, -) -> None: - 
"""Medium runtime should initialize XLSR backend with configured model id.""" - monkeypatch.setenv("SER_MEDIUM_MODEL_ID", "unit-test/xlsr") - settings = config.reload_settings() - captured: dict[str, object] = {} - monkeypatch.setattr( - "ser.runtime.medium_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=_PredictModel( - predictions=["happy"], - probabilities=[[1.0, 0.0]], - classes=["happy", "sad"], - ), - expected_feature_size=4, - artifact_metadata=_medium_metadata(backend_model_id="unit-test/xlsr"), - ), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference.read_audio_file", - lambda _file_path, *, audio_read_config=None: ( - np.ones(8, dtype=np.float32), - 4, - ), - ) - - class _BackendStub: - def __init__( - self, - *, - model_id: str, - cache_dir: Path, - device: str = "auto", - dtype: str = "auto", - ) -> None: - captured["model_id"] = model_id - captured["cache_dir"] = cache_dir - captured["device"] = device - captured["dtype"] = dtype - - monkeypatch.setattr("ser.runtime.medium_inference.XLSRBackend", _BackendStub) - monkeypatch.setattr( - "ser.runtime.medium_inference._run_with_timeout_impl", - lambda **kwargs: kwargs["operation"](), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference._run_medium_inference_once", - lambda **_kwargs: InferenceResult( - schema_version=OUTPUT_SCHEMA_VERSION, - segments=[], - frames=[], - ), - ) - - run_medium_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - backend_override = settings.feature_runtime_policy.for_backend("hf_xlsr") - expected_runtime_policy = resolve_feature_runtime_policy( - backend_id="hf_xlsr", - requested_device=settings.torch_runtime.device, - requested_dtype=settings.torch_runtime.dtype, - backend_override_device=(backend_override.device if backend_override is not None else None), - backend_override_dtype=(backend_override.dtype if backend_override is not None else None), - ) - assert 
captured["model_id"] == "unit-test/xlsr" - assert captured["cache_dir"] == settings.models.huggingface_cache_root - assert captured["device"] == expected_runtime_policy.device - assert captured["dtype"] == expected_runtime_policy.dtype - - -def test_run_medium_inference_requires_backend_model_id_metadata( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Medium runtime should reject artifacts missing backend_model_id metadata.""" - monkeypatch.setenv("SER_MEDIUM_MODEL_ID", "unit-test/xlsr") - settings = config.reload_settings() - monkeypatch.setattr( - "ser.runtime.medium_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=_PredictModel( - predictions=["happy"], - probabilities=[[1.0, 0.0]], - classes=["happy", "sad"], - ), - expected_feature_size=4, - artifact_metadata=_medium_metadata( - backend_model_id=None, - ), - ), - ) - - with pytest.raises(MediumModelUnavailableError, match="backend_model_id"): - run_medium_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - -def test_run_medium_inference_warns_on_torch_runtime_metadata_mismatch( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Runtime should warn when artifact and runtime torch selectors differ.""" - monkeypatch.setenv("SER_TORCH_DEVICE", "cpu") - monkeypatch.setenv("SER_TORCH_DTYPE", "float32") - settings = config.reload_settings() - metadata = _medium_metadata(backend_model_id=settings.models.medium_model_id) - metadata["torch_device"] = "cuda:0" - metadata["torch_dtype"] = "float16" - warnings: list[str] = [] - monkeypatch.setattr( - "ser.runtime.medium_inference.logger.warning", - lambda msg, *args: warnings.append(msg % args), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=_PredictModel( - predictions=["happy", "happy", "happy"], - probabilities=[[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]], - classes=["happy", "sad"], - ), - 
expected_feature_size=4, - artifact_metadata=metadata, - ), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference.read_audio_file", - lambda _file_path, *, audio_read_config=None: ( - np.ones(8, dtype=np.float32), - 4, - ), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference.XLSRBackend", - lambda **_kwargs: _FakeBackend(), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference._run_with_timeout_impl", - lambda **kwargs: kwargs["operation"](), - ) - - run_medium_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert warnings - assert "torch runtime selectors differ" in warnings[0] - - -def test_run_medium_inference_delegates_in_process_setup_to_helper( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """In-process runtime should delegate setup orchestration to helper seam.""" - settings = config.reload_settings() - request = InferenceRequest( - file_path="sample.wav", - language="en", - save_transcript=False, - ) - captured: dict[str, object] = {} - prepared = medium_inference.medium_worker_operation_helpers.PreparedMediumOperation( - loaded_model=object(), - backend=object(), - audio=np.ones(8, dtype=np.float32), - sample_rate=4, - runtime_config=settings.medium_runtime, - ) - expected = InferenceResult( - schema_version=OUTPUT_SCHEMA_VERSION, - segments=[], - frames=[], - ) - - def fake_prepare(**kwargs: object) -> object: - captured.update(kwargs) - return prepared - - monkeypatch.setattr( - "ser.runtime.medium_inference._prepare_in_process_operation", - fake_prepare, - ) - monkeypatch.setattr( - "ser.runtime.medium_inference._prepare_medium_backend_runtime", - lambda *, backend: captured.update({"prepared_backend": backend}), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference._run_process_operation", - lambda prepared: expected, - ) - monkeypatch.setattr( - "ser.runtime.medium_inference._run_with_timeout", - lambda operation, timeout_seconds: operation(), - ) - - result = 
run_medium_inference(request, settings) - - assert result == expected - assert captured["request"] == request - assert captured["settings"] == settings - assert captured["loaded_model"] is None - assert captured["backend"] is None - assert captured["expected_backend_model_id"] == settings.models.medium_model_id - assert captured["prepared_backend"] is prepared.backend - - -def test_run_medium_inference_delegates_operation_to_worker_helper( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Runtime loop should delegate attempt execution to worker helper seam.""" - settings = config.reload_settings() - request = InferenceRequest( - file_path="sample.wav", - language="en", - save_transcript=False, - ) - prepared = medium_inference.medium_worker_operation_helpers.PreparedMediumOperation( - loaded_model=object(), - backend=object(), - audio=np.ones(8, dtype=np.float32), - sample_rate=4, - runtime_config=settings.medium_runtime, - ) - expected = InferenceResult( - schema_version=OUTPUT_SCHEMA_VERSION, - segments=[], - frames=[], - ) - captured: dict[str, object] = {} - - monkeypatch.setattr( - "ser.runtime.medium_inference._prepare_in_process_operation", - lambda **_kwargs: prepared, - ) - monkeypatch.setattr( - "ser.runtime.medium_inference._prepare_medium_backend_runtime", - lambda *, backend: None, - ) - monkeypatch.setattr( - "ser.runtime.medium_inference._run_medium_retry_policy_impl", - lambda **kwargs: kwargs["operation"](), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference.medium_worker_operation_helpers.run_inference_operation", - lambda **kwargs: captured.update(kwargs) or expected, - ) - - result = run_medium_inference(request, settings) - - assert result == expected - assert captured["enforce_timeout"] is True - assert captured["use_process_isolation"] is False - assert captured["process_payload"] is None - assert captured["prepared_operation"] is prepared - assert captured["timeout_seconds"] == settings.medium_runtime.timeout_seconds - assert 
captured["profile"] == "medium" - assert captured["inference_phase_name"] == medium_inference.PHASE_EMOTION_INFERENCE - - -def test_run_medium_inference_delegates_retry_policy_wrapper( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Runtime loop should delegate retry-policy execution to helper seam.""" - settings = config.reload_settings() - request = InferenceRequest( - file_path="sample.wav", - language="en", - save_transcript=False, - ) - prepared = medium_inference.medium_worker_operation_helpers.PreparedMediumOperation( - loaded_model=object(), - backend=object(), - audio=np.ones(8, dtype=np.float32), - sample_rate=4, - runtime_config=settings.medium_runtime, - ) - expected = InferenceResult( - schema_version=OUTPUT_SCHEMA_VERSION, - segments=[], - frames=[], - ) - captured: dict[str, object] = {} - - monkeypatch.setattr( - "ser.runtime.medium_inference._prepare_in_process_operation", - lambda **_kwargs: prepared, - ) - monkeypatch.setattr( - "ser.runtime.medium_inference._prepare_medium_backend_runtime", - lambda *, backend: None, - ) - monkeypatch.setattr( - "ser.runtime.medium_inference._run_medium_retry_policy_impl", - lambda **kwargs: captured.update(kwargs) or expected, - ) - - result = run_medium_inference(request, settings) - - assert result == expected - assert captured["runtime_config"] is settings.medium_runtime - assert captured["allow_retries"] is True - assert captured["profile_label"] == "Medium" - assert captured["timeout_error_type"] is medium_inference.MediumInferenceTimeoutError - assert captured["transient_error_type"] is medium_inference.MediumTransientBackendError - assert callable(captured["operation"]) - assert callable(captured["on_transient_failure"]) - assert callable(captured["run_with_retry_policy"]) - - -def test_run_medium_inference_delegates_execution_context_preparation( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Runtime loop should delegate pre-lock execution context preparation.""" - settings = 
config.reload_settings() - request = InferenceRequest( - file_path="sample.wav", - language="en", - save_transcript=False, - ) - retry_state = medium_inference.medium_worker_operation_helpers.MediumRetryOperationState[ - medium_inference.MediumProcessPayload, - emotion_model.LoadedModel, - object, - ]( - process_payload=None, - prepared_operation=None, - ) - captured: dict[str, object] = {} - - monkeypatch.setattr( - "ser.runtime.medium_inference._prepare_execution_context", - lambda **kwargs: captured.update(kwargs) - or medium_inference._MediumExecutionContext( - runtime_config=settings.medium_runtime, - expected_backend_model_id=settings.models.medium_model_id, - runtime_policy=resolve_feature_runtime_policy( - backend_id="hf_xlsr", - requested_device=settings.torch_runtime.device, - requested_dtype=settings.torch_runtime.dtype, - ), - use_process_isolation=False, - retry_state=retry_state, - setup_started_at=None, - ), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference.medium_worker_operation_helpers.finalize_in_process_setup", - lambda **_kwargs: None, - ) - monkeypatch.setattr( - "ser.runtime.medium_inference._run_medium_retry_policy_impl", - lambda **_kwargs: InferenceResult( - schema_version=OUTPUT_SCHEMA_VERSION, - segments=[], - frames=[], - ), - ) - - _ = run_medium_inference(request, settings) - - assert captured["request"] == request - assert captured["settings"] == settings - assert captured["loaded_model"] is None - assert captured["backend"] is None - assert captured["enforce_timeout"] is True - - -def test_run_medium_inference_delegates_lock_body_execution( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Runtime loop should delegate single-flight execution body to helper seam.""" - settings = config.reload_settings() - request = InferenceRequest( - file_path="sample.wav", - language="en", - save_transcript=False, - ) - execution_context = medium_inference._MediumExecutionContext( - runtime_config=settings.medium_runtime, - 
expected_backend_model_id=settings.models.medium_model_id, - runtime_policy=resolve_feature_runtime_policy( - backend_id="hf_xlsr", - requested_device=settings.torch_runtime.device, - requested_dtype=settings.torch_runtime.dtype, - ), - use_process_isolation=False, - retry_state=medium_inference.medium_worker_operation_helpers.MediumRetryOperationState[ - medium_inference.MediumProcessPayload, - emotion_model.LoadedModel, - object, - ]( - process_payload=None, - prepared_operation=None, - ), - setup_started_at=None, - ) - captured: dict[str, object] = {} - expected = InferenceResult( - schema_version=OUTPUT_SCHEMA_VERSION, - segments=[], - frames=[], - ) - - monkeypatch.setattr( - "ser.runtime.medium_inference._prepare_execution_context", - lambda **_kwargs: execution_context, - ) - monkeypatch.setattr( - "ser.runtime.medium_inference._execute_medium_inference_with_retry", - lambda **kwargs: captured.update(kwargs) or expected, - ) - - result = run_medium_inference(request, settings) - - assert result == expected - assert captured["execution_context"] is execution_context - assert captured["settings"] is settings - assert captured["injected_backend"] is None - assert captured["enforce_timeout"] is True - assert captured["allow_retries"] is True - assert captured["expected_backend_model_id"] == settings.models.medium_model_id - - -def test_medium_single_flight_serializes_same_profile_model_calls( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Concurrent medium calls should execute one-at-a-time for one profile/model tuple.""" - settings = config.reload_settings() - monkeypatch.setattr( - "ser.runtime.medium_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=_PredictModel( - predictions=["happy"], - probabilities=[[1.0, 0.0]], - classes=["happy", "sad"], - ), - expected_feature_size=4, - artifact_metadata=_medium_metadata(backend_model_id=settings.models.medium_model_id), - ), - ) - monkeypatch.setattr( - 
"ser.runtime.medium_inference.read_audio_file", - lambda _file_path, *, audio_read_config=None: ( - np.ones(8, dtype=np.float32), - 4, - ), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference.XLSRBackend", - lambda **_kwargs: object(), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference._run_with_timeout_impl", - lambda **kwargs: kwargs["operation"](), - ) - - counters = {"active": 0, "max_active": 0} - counter_lock = threading.Lock() - - def fake_attempt(**_kwargs: object) -> InferenceResult: - with counter_lock: - counters["active"] += 1 - counters["max_active"] = max(counters["max_active"], counters["active"]) - time.sleep(0.05) - with counter_lock: - counters["active"] -= 1 - return InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - - monkeypatch.setattr( - "ser.runtime.medium_inference._run_medium_inference_once", - fake_attempt, - ) - - errors: list[Exception] = [] - request = InferenceRequest(file_path="sample.wav", language="en", save_transcript=False) - - def invoke() -> None: - try: - run_medium_inference(request, settings) - except Exception as err: # pragma: no cover - defensive capture for assertion clarity - errors.append(err) - - first = threading.Thread(target=invoke) - second = threading.Thread(target=invoke) - first.start() - second.start() - first.join() - second.join() - - assert errors == [] - assert counters["max_active"] == 1 - assert medium_inference._SINGLE_FLIGHT_REGISTRY.active_key_count() == 0 diff --git a/tests/test_medium_timeout_and_fallback.py b/tests/test_medium_timeout_and_fallback.py deleted file mode 100644 index 9defda5..0000000 --- a/tests/test_medium_timeout_and_fallback.py +++ /dev/null @@ -1,574 +0,0 @@ -"""Tests for medium runtime timeout/retry and fallback behavior.""" - -from __future__ import annotations - -from collections.abc import Generator - -import numpy as np -import pytest -from sklearn.neural_network import MLPClassifier - -import ser.config as config -import 
ser.runtime.medium_inference as medium_inference -from ser.models import emotion_model -from ser.repr.runtime_policy import FeatureRuntimePolicy -from ser.runtime.contracts import InferenceRequest -from ser.runtime.medium_inference import ( - MediumInferenceExecutionError, - MediumInferenceTimeoutError, - MediumRuntimeDependencyError, - MediumTransientBackendError, - run_medium_inference, -) -from ser.runtime.schema import OUTPUT_SCHEMA_VERSION, InferenceResult - - -class _PredictModel(MLPClassifier): - """Deterministic model stub for medium runtime tests.""" - - def __init__(self) -> None: - super().__init__(hidden_layer_sizes=(1,), max_iter=1, random_state=0) - self.classes_ = np.asarray(["happy", "sad"], dtype=object) - - def predict(self, X: np.ndarray) -> np.ndarray: # noqa: N803 - return np.asarray(["happy"] * int(X.shape[0]), dtype=object) - - def predict_proba(self, X: np.ndarray) -> np.ndarray: # noqa: N803 - return np.asarray([[0.9, 0.1]] * int(X.shape[0]), dtype=np.float64) - - -@pytest.fixture(autouse=True) -def _reset_settings() -> Generator[None]: - """Keeps global settings stable across tests.""" - config.reload_settings() - yield - config.reload_settings() - - -def _medium_metadata(feature_vector_size: int = 4) -> dict[str, object]: - """Builds minimal medium-profile artifact metadata for runtime tests.""" - return { - "artifact_version": emotion_model.MODEL_ARTIFACT_VERSION, - "artifact_schema_version": "v2", - "created_at_utc": "2026-02-19T00:00:00+00:00", - "feature_vector_size": feature_vector_size, - "training_samples": 8, - "labels": ["happy", "sad"], - "backend_id": "hf_xlsr", - "profile": "medium", - "feature_dim": feature_vector_size, - "frame_size_seconds": 1.0, - "frame_stride_seconds": 1.0, - "pooling_strategy": "mean_std", - "backend_model_id": emotion_model.MEDIUM_MODEL_ID, - } - - -def _patch_runtime_prerequisites(monkeypatch: pytest.MonkeyPatch) -> None: - """Patches model/audio/backend prerequisites for retry/timeout tests.""" - 
monkeypatch.setattr( - "ser.runtime.medium_inference.load_model", - lambda **_kwargs: emotion_model.LoadedModel( - model=_PredictModel(), - expected_feature_size=4, - artifact_metadata=_medium_metadata(), - ), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference.read_audio_file", - lambda _file_path, *, audio_read_config=None: ( - np.linspace(0.0, 1.0, 16, dtype=np.float32), - 4, - ), - ) - monkeypatch.setattr("ser.runtime.medium_inference.XLSRBackend", lambda **_kwargs: object()) - - -def test_medium_timeout_retries_up_to_configured_budget( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Timeouts should retry up to `max_retries + 1` attempts and then fail.""" - monkeypatch.setenv("SER_MEDIUM_MAX_TIMEOUT_RETRIES", "2") - monkeypatch.setenv("SER_MEDIUM_RETRY_BACKOFF_SECONDS", "0.5") - settings = config.reload_settings() - _patch_runtime_prerequisites(monkeypatch) - - calls = {"attempts": 0, "sleeps": 0} - - def fake_timeout_runner(*_args: object, **_kwargs: object) -> object: - calls["attempts"] += 1 - raise MediumInferenceTimeoutError("timeout") - - monkeypatch.setattr("ser.runtime.medium_inference._run_with_timeout", fake_timeout_runner) - monkeypatch.setattr( - "ser.runtime.medium_inference._retry_delay_seconds", - lambda **_kwargs: 0.1, - ) - monkeypatch.setattr( - "ser.runtime.policy.time.sleep", - lambda _delay: calls.__setitem__("sleeps", calls["sleeps"] + 1), - ) - - with pytest.raises(MediumInferenceTimeoutError, match="timeout"): - run_medium_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert calls["attempts"] == 3 - assert calls["sleeps"] == 2 - - -def test_medium_transient_backend_failure_respects_retry_upper_bound( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Transient backend failures should stop after bounded retry attempts.""" - monkeypatch.setenv("SER_MEDIUM_MAX_TRANSIENT_RETRIES", "2") - monkeypatch.setenv("SER_MEDIUM_RETRY_BACKOFF_SECONDS", "0") - settings = 
config.reload_settings() - _patch_runtime_prerequisites(monkeypatch) - - calls = {"attempts": 0} - - def fake_attempt( - *, - loaded_model: emotion_model.LoadedModel, - backend: object, - audio: np.ndarray, - sample_rate: int, - runtime_config: object, - ) -> object: - del loaded_model, backend, audio, sample_rate, runtime_config - calls["attempts"] += 1 - raise MediumTransientBackendError("transient backend failure") - - monkeypatch.setattr( - "ser.runtime.medium_inference._run_with_timeout", - lambda operation, timeout_seconds: operation(), - ) - monkeypatch.setattr("ser.runtime.medium_inference._run_medium_inference_once", fake_attempt) - - with pytest.raises(MediumInferenceExecutionError, match="retry budget"): - run_medium_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert calls["attempts"] == 3 - - -def test_medium_non_retryable_value_error_exits_without_retries( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Feature-contract value errors should not trigger retries.""" - monkeypatch.setenv("SER_MEDIUM_MAX_TIMEOUT_RETRIES", "3") - monkeypatch.setenv("SER_MEDIUM_MAX_TRANSIENT_RETRIES", "3") - settings = config.reload_settings() - _patch_runtime_prerequisites(monkeypatch) - - calls = {"attempts": 0, "sleeps": 0} - - def fake_attempt( - *, - loaded_model: emotion_model.LoadedModel, - backend: object, - audio: np.ndarray, - sample_rate: int, - runtime_config: object, - ) -> object: - del loaded_model, backend, audio, sample_rate, runtime_config - calls["attempts"] += 1 - raise ValueError("Feature vector size mismatch") - - monkeypatch.setattr( - "ser.runtime.medium_inference._run_with_timeout", - lambda operation, timeout_seconds: operation(), - ) - monkeypatch.setattr("ser.runtime.medium_inference._run_medium_inference_once", fake_attempt) - monkeypatch.setattr( - "ser.runtime.policy.time.sleep", - lambda _delay: calls.__setitem__("sleeps", calls["sleeps"] + 1), - ) - - with 
pytest.raises(ValueError, match="Feature vector size mismatch"): - run_medium_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert calls["attempts"] == 1 - assert calls["sleeps"] == 0 - - -def test_medium_dependency_error_is_not_retried( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Dependency failures should fail immediately without retry loop.""" - monkeypatch.setenv("SER_MEDIUM_MAX_TIMEOUT_RETRIES", "3") - monkeypatch.setenv("SER_MEDIUM_MAX_TRANSIENT_RETRIES", "3") - settings = config.reload_settings() - _patch_runtime_prerequisites(monkeypatch) - - calls = {"attempts": 0} - - def fake_attempt( - *, - loaded_model: emotion_model.LoadedModel, - backend: object, - audio: np.ndarray, - sample_rate: int, - runtime_config: object, - ) -> object: - del loaded_model, backend, audio, sample_rate, runtime_config - calls["attempts"] += 1 - raise MediumRuntimeDependencyError("transformers missing") - - monkeypatch.setattr( - "ser.runtime.medium_inference._run_with_timeout", - lambda operation, timeout_seconds: operation(), - ) - monkeypatch.setattr("ser.runtime.medium_inference._run_medium_inference_once", fake_attempt) - - with pytest.raises(MediumRuntimeDependencyError, match="transformers"): - run_medium_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert calls["attempts"] == 1 - - -def test_medium_backend_setup_runs_before_timeout_wrapper( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Backend setup should execute before timeout-wrapped compute operation.""" - settings = config.reload_settings() - setup_calls = {"count": 0} - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - - class _BackendStub: - def prepare_runtime(self) -> None: - setup_calls["count"] += 1 - - backend = _BackendStub() - monkeypatch.setattr( - "ser.runtime.medium_inference.load_model", - lambda **_kwargs: 
emotion_model.LoadedModel( - model=_PredictModel(), - expected_feature_size=4, - artifact_metadata=_medium_metadata(), - ), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference.read_audio_file", - lambda _file_path, *, audio_read_config=None: ( - np.linspace(0.0, 1.0, 16, dtype=np.float32), - 4, - ), - ) - monkeypatch.setattr( - "ser.runtime.medium_inference.XLSRBackend", - lambda **_kwargs: backend, - ) - - def fake_timeout_runner(*_args: object, **_kwargs: object) -> InferenceResult: - assert setup_calls["count"] == 1 - return expected - - monkeypatch.setattr( - "ser.runtime.medium_inference._run_with_timeout", - fake_timeout_runner, - ) - - result = run_medium_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert result == expected - assert setup_calls["count"] == 1 - - -def test_medium_profile_pipeline_uses_process_timeout_runner( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Profile-pipeline medium calls should route attempts through process timeout path.""" - monkeypatch.setenv("SER_ENABLE_PROFILE_PIPELINE", "true") - monkeypatch.setenv("SER_MEDIUM_MAX_TIMEOUT_RETRIES", "1") - settings = config.reload_settings() - - calls = {"process": 0, "sleep": 0} - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - - def fail_if_called(**_kwargs: object) -> object: - raise AssertionError("In-process load_model path should not run in process mode.") - - def fake_process_runner(*_args: object, **_kwargs: object) -> InferenceResult: - calls["process"] += 1 - if calls["process"] == 1: - raise MediumInferenceTimeoutError("timeout") - return expected - - monkeypatch.setattr("ser.runtime.medium_inference.load_model", fail_if_called) - monkeypatch.setattr( - "ser.runtime.medium_inference._run_with_process_timeout", - fake_process_runner, - ) - monkeypatch.setattr( - "ser.runtime.medium_inference._retry_delay_seconds", - lambda **_kwargs: 0.1, - ) - 
monkeypatch.setattr( - "ser.runtime.policy.time.sleep", - lambda _delay: calls.__setitem__("sleep", calls["sleep"] + 1), - ) - - result = run_medium_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert result == expected - assert calls["process"] == 2 - assert calls["sleep"] == 1 - - -def test_medium_profile_pipeline_allows_timeout_disable( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Profile-pipeline medium calls should pass zero timeout to disable budgets.""" - monkeypatch.setenv("SER_ENABLE_PROFILE_PIPELINE", "true") - monkeypatch.setenv("SER_MEDIUM_TIMEOUT_SECONDS", "0") - settings = config.reload_settings() - calls = {"process": 0} - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - - def fail_if_called(**_kwargs: object) -> object: - raise AssertionError("In-process load_model path should not run in process mode.") - - def fake_process_runner(*_args: object, **_kwargs: object) -> InferenceResult: - timeout_seconds = _kwargs.get("timeout_seconds") - assert timeout_seconds == 0.0 - calls["process"] += 1 - return expected - - monkeypatch.setattr("ser.runtime.medium_inference.load_model", fail_if_called) - monkeypatch.setattr( - "ser.runtime.medium_inference._run_with_process_timeout", - fake_process_runner, - ) - - result = run_medium_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert result == expected - assert calls["process"] == 1 - - -def test_medium_profile_pipeline_retries_on_cpu_after_mps_transient_failure( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Process-mode medium retries should demote to CPU after one MPS failure.""" - monkeypatch.setenv("SER_ENABLE_PROFILE_PIPELINE", "true") - monkeypatch.setenv("SER_MEDIUM_MAX_TRANSIENT_RETRIES", "1") - monkeypatch.setenv("SER_MEDIUM_RETRY_BACKOFF_SECONDS", "0") - settings = config.reload_settings() - - devices: 
list[tuple[str, str]] = [] - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - oom_message = ( - "MPS backend out of memory (MPS allocated: 3.13 GB, other allocations: " - "220.17 MB, max allowed: 3.40 GB). Tried to allocate 85.83 MB on private pool." - ) - - monkeypatch.setattr( - "ser.runtime.medium_inference._resolve_medium_feature_runtime_policy", - lambda **_kwargs: FeatureRuntimePolicy( - device="mps", - dtype="float16", - reason="test_policy", - ), - ) - - def fail_if_called(**_kwargs: object) -> object: - raise AssertionError("In-process load_model path should not run in process mode.") - - def fake_process_runner( - payload: medium_inference.MediumProcessPayload, - *, - timeout_seconds: float, - ) -> InferenceResult: - del timeout_seconds - devices.append( - ( - payload.settings.torch_runtime.device, - payload.settings.torch_runtime.dtype, - ) - ) - if len(devices) == 1: - raise MediumTransientBackendError(oom_message) - return expected - - monkeypatch.setattr("ser.runtime.medium_inference.load_model", fail_if_called) - monkeypatch.setattr( - "ser.runtime.medium_inference._run_with_process_timeout", - fake_process_runner, - ) - - result = run_medium_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert result == expected - assert devices == [("mps", "float16"), ("cpu", "float32")] - - -def test_medium_in_process_rebuilds_backend_on_cpu_after_transient_failure( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """In-process medium retries should rebuild backend on CPU after MPS failures.""" - monkeypatch.setenv("SER_MEDIUM_MAX_TRANSIENT_RETRIES", "1") - monkeypatch.setenv("SER_MEDIUM_RETRY_BACKOFF_SECONDS", "0") - settings = config.reload_settings() - _patch_runtime_prerequisites(monkeypatch) - - backend_selectors: list[tuple[str, str]] = [] - calls = {"attempts": 0} - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], 
frames=[]) - - class _BackendStub: - def __init__(self, *, device: str, dtype: str, **_kwargs: object) -> None: - backend_selectors.append((device, dtype)) - - def prepare_runtime(self) -> None: - return None - - monkeypatch.setattr( - "ser.runtime.medium_inference._resolve_medium_feature_runtime_policy", - lambda **_kwargs: FeatureRuntimePolicy( - device="mps", - dtype="float16", - reason="test_policy", - ), - ) - monkeypatch.setattr("ser.runtime.medium_inference.XLSRBackend", _BackendStub) - monkeypatch.setattr( - "ser.runtime.medium_inference._run_with_timeout", - lambda operation, timeout_seconds: operation(), - ) - - def fake_attempt( - *, - loaded_model: emotion_model.LoadedModel, - backend: object, - audio: np.ndarray, - sample_rate: int, - runtime_config: object, - ) -> InferenceResult: - del loaded_model, backend, audio, sample_rate, runtime_config - calls["attempts"] += 1 - if calls["attempts"] == 1: - raise MediumTransientBackendError( - "Input type (c10::Half) and bias type (float) should be the same" - ) - return expected - - monkeypatch.setattr( - "ser.runtime.medium_inference._run_medium_inference_once", - fake_attempt, - ) - - result = run_medium_inference( - InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings, - ) - - assert result == expected - assert calls["attempts"] == 2 - assert backend_selectors == [("mps", "float16"), ("cpu", "float32")] - - -def test_medium_process_timeout_applies_after_setup_phase( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Process timeout budget should start only after setup phase is acknowledged.""" - expected = InferenceResult(schema_version=OUTPUT_SCHEMA_VERSION, segments=[], frames=[]) - poll_calls: list[float | None] = [] - settings = config.reload_settings() - - class _ParentConnection: - def __init__(self) -> None: - self._messages: list[tuple[object, ...]] = [ - ("phase", "setup_complete"), - ("ok", expected), - ] - - def recv(self) -> tuple[object, ...]: - if not 
self._messages: - raise EOFError - return self._messages.pop(0) - - def poll(self, timeout: float | None = None) -> bool: - poll_calls.append(timeout) - return True - - def close(self) -> None: - return None - - class _ChildConnection: - def close(self) -> None: - return None - - class _Process: - def __init__(self) -> None: - self.exitcode = 0 - self._alive = True - - def start(self) -> None: - return None - - def join(self, timeout: float | None = None) -> None: - del timeout - self._alive = False - - def is_alive(self) -> bool: - return self._alive - - def terminate(self) -> None: - self._alive = False - - def kill(self) -> None: - self._alive = False - - class _Context: - def __init__(self) -> None: - self._parent = _ParentConnection() - self._child = _ChildConnection() - self._process = _Process() - - def Pipe( - self, duplex: bool = False - ) -> tuple[_ParentConnection, _ChildConnection]: # noqa: N802 - del duplex - return self._parent, self._child - - def Process(self, **_kwargs: object) -> _Process: # noqa: N802 - return self._process - - monkeypatch.setattr( - "ser.runtime.medium_inference.mp.get_context", - lambda _name: _Context(), - ) - payload = medium_inference.MediumProcessPayload( - request=InferenceRequest(file_path="sample.wav", language="en", save_transcript=False), - settings=settings, - expected_backend_model_id=settings.models.medium_model_id, - ) - - result = medium_inference._run_with_process_timeout(payload, timeout_seconds=11.0) - - assert result == expected - assert poll_calls == [11.0] diff --git a/tests/test_runtime_worker_lifecycle_delegation.py b/tests/test_runtime_worker_lifecycle_delegation.py deleted file mode 100644 index b5ae15a..0000000 --- a/tests/test_runtime_worker_lifecycle_delegation.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Delegation contracts for shared runtime worker lifecycle helpers.""" - -from __future__ import annotations - -from collections.abc import Callable -from multiprocessing.connection import Connection -from 
types import SimpleNamespace -from typing import cast - -import pytest - -import ser.runtime.accurate_inference as accurate_inference -import ser.runtime.fast_inference as fast_inference -import ser.runtime.medium_inference as medium_inference -from ser import config -from ser.runtime.contracts import InferenceRequest -from ser.runtime.schema import OUTPUT_SCHEMA_VERSION, InferenceResult - - -def test_fast_recv_worker_message_delegates_to_internal_service( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Fast wrapper should delegate with stable label/error wiring.""" - captured: dict[str, object] = {} - expected = ("phase", "setup_complete") - - def _fake_impl(**kwargs: object) -> tuple[str, str]: - captured.update(kwargs) - return expected - - monkeypatch.setattr(fast_inference, "_recv_worker_message_impl", _fake_impl) - connection = cast(Connection, SimpleNamespace()) - - resolved = fast_inference._recv_worker_message(connection, stage="setup") - - assert resolved == expected - assert captured["connection"] is connection - assert captured["stage"] == "setup" - assert captured["worker_label"] == "Fast inference" - assert captured["error_factory"] is fast_inference.FastInferenceExecutionError - - -def test_medium_recv_worker_message_delegates_to_internal_service( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Medium wrapper should delegate with stable label/error wiring.""" - captured: dict[str, object] = {} - expected = ("phase", "setup_complete") - - def _fake_impl(**kwargs: object) -> tuple[str, str]: - captured.update(kwargs) - return expected - - monkeypatch.setattr(medium_inference, "_recv_worker_message_impl", _fake_impl) - connection = cast(Connection, SimpleNamespace()) - - resolved = medium_inference._recv_worker_message(connection, stage="setup") - - assert resolved == expected - assert captured["connection"] is connection - assert captured["stage"] == "setup" - assert captured["worker_label"] == "Medium inference" - assert 
captured["error_factory"] is medium_inference.MediumInferenceExecutionError - - -def test_accurate_recv_worker_message_delegates_to_internal_service( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Accurate wrapper should delegate with stable label/error wiring.""" - captured: dict[str, object] = {} - expected = ("phase", "setup_complete") - - def _fake_impl(**kwargs: object) -> tuple[str, str]: - captured.update(kwargs) - return expected - - monkeypatch.setattr(accurate_inference, "_recv_worker_message_impl", _fake_impl) - connection = cast(Connection, SimpleNamespace()) - - resolved = accurate_inference._recv_worker_message(connection, stage="setup") - - assert resolved == expected - assert captured["connection"] is connection - assert captured["stage"] == "setup" - assert captured["worker_label"] == "Accurate inference" - assert captured["error_factory"] is accurate_inference.AccurateInferenceExecutionError - - -def test_medium_raise_worker_error_delegates_to_internal_service( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Medium worker-error wrapper should delegate with stable mapping wiring.""" - captured: dict[str, object] = {} - - def _fake_impl(**kwargs: object) -> None: - captured.update(kwargs) - - monkeypatch.setattr(medium_inference, "_raise_worker_error_impl", _fake_impl) - medium_inference._raise_worker_error("ValueError", "bad payload") - - assert captured["error_type"] == "ValueError" - assert captured["message"] == "bad payload" - assert captured["worker_label"] == "Medium inference" - assert captured["unknown_error_factory"] is medium_inference.MediumInferenceExecutionError - known_error_factories = cast(dict[str, object], captured["known_error_factories"]) - assert "MediumTransientBackendError" in known_error_factories - assert "MediumInferenceTimeoutError" in known_error_factories - - -def test_medium_process_timeout_delegates_setup_compute_handshake( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Medium process-timeout wrapper 
should delegate handshake orchestration.""" - captured: dict[str, object] = {} - sentinel_context = cast(object, SimpleNamespace()) - settings = config.reload_settings() - expected = InferenceResult( - schema_version=OUTPUT_SCHEMA_VERSION, - segments=[], - frames=[], - ) - - def _fake_impl(**kwargs: object) -> tuple[str, InferenceResult]: - captured.update(kwargs) - on_setup_complete = cast(Callable[[], None], kwargs["on_setup_complete"]) - on_setup_complete() - return ("ok", expected) - - monkeypatch.setattr( - medium_inference, - "_run_process_setup_compute_handshake_impl", - _fake_impl, - ) - monkeypatch.setattr( - medium_inference.mp, - "get_context", - lambda _name: sentinel_context, - ) - payload = medium_inference.MediumProcessPayload( - request=InferenceRequest( - file_path="sample.wav", - language="en", - save_transcript=False, - ), - settings=settings, - expected_backend_model_id=settings.models.medium_model_id, - ) - - resolved = medium_inference._run_with_process_timeout(payload, timeout_seconds=7.0) - - assert resolved == expected - assert captured["context"] is sentinel_context - assert captured["worker_label"] == "Medium inference" - assert captured["timeout_seconds"] == 7.0 - assert captured["timeout_error_factory"] is medium_inference.MediumInferenceTimeoutError - assert captured["execution_error_factory"] is medium_inference.MediumInferenceExecutionError - assert captured["worker_target"] is medium_inference._worker_entry - - -def test_accurate_process_timeout_delegates_setup_compute_handshake( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Accurate process-timeout wrapper should delegate handshake orchestration.""" - captured: dict[str, object] = {} - sentinel_context = cast(object, SimpleNamespace()) - settings = config.reload_settings() - expected = InferenceResult( - schema_version=OUTPUT_SCHEMA_VERSION, - segments=[], - frames=[], - ) - - def _fake_impl(**kwargs: object) -> tuple[str, InferenceResult]: - captured.update(kwargs) - 
on_setup_complete = cast(Callable[[], None], kwargs["on_setup_complete"]) - on_setup_complete() - return ("ok", expected) - - monkeypatch.setattr( - accurate_inference, - "_run_process_setup_compute_handshake_impl", - _fake_impl, - ) - monkeypatch.setattr( - accurate_inference.mp, - "get_context", - lambda _name: sentinel_context, - ) - payload = accurate_inference.AccurateProcessPayload( - request=InferenceRequest( - file_path="sample.wav", - language="en", - save_transcript=False, - ), - settings=settings, - expected_backend_id="hf_whisper", - expected_profile="accurate", - expected_backend_model_id=settings.models.accurate_model_id, - ) - - resolved = accurate_inference._run_with_process_timeout(payload, timeout_seconds=7.0) - - assert resolved == expected - assert captured["context"] is sentinel_context - assert captured["worker_label"] == "Accurate inference" - assert captured["timeout_seconds"] == 7.0 - assert captured["timeout_error_factory"] is accurate_inference.AccurateInferenceTimeoutError - assert captured["execution_error_factory"] is accurate_inference.AccurateInferenceExecutionError - assert captured["worker_target"] is accurate_inference._worker_entry diff --git a/tests/test_stable_whisper_mps_compat.py b/tests/test_stable_whisper_mps_compat.py index bf2df6c..9a971c8 100644 --- a/tests/test_stable_whisper_mps_compat.py +++ b/tests/test_stable_whisper_mps_compat.py @@ -137,14 +137,71 @@ def test_move_model_to_mps_with_alignment_placeholder_rolls_back_to_cpu_on_failu assert model.alignment_heads.device.type == "cpu" +def test_is_mps_log_mel_compatibility_error_detects_fft_gap() -> None: + """MPS FFT operator gaps should trigger CPU log-mel offload.""" + err = NotImplementedError( + "The operator 'aten::_fft_r2c' is not currently implemented for the MPS device." 
+ ) + + assert mps_compat._is_mps_log_mel_compatibility_error(err) is True + + +def test_is_mps_log_mel_compatibility_error_detects_complex_dtype_gap() -> None: + """Complex dtype materialization failures should trigger CPU log-mel offload.""" + err = TypeError( + "Trying to convert ComplexFloat to the MPS backend but it does not have support for that dtype." + ) + + assert mps_compat._is_mps_log_mel_compatibility_error(err) is True + + +def test_resolve_mps_log_mel_compatibility_decision_enables_cpu_offload( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Detected MPS log-mel frontend gaps should enable the CPU offload patch.""" + mps_compat._resolve_mps_log_mel_compatibility_decision.cache_clear() + monkeypatch.setattr(mps_compat, "_mps_backend_available", lambda: True) + monkeypatch.setattr( + mps_compat, + "_probe_mps_log_mel_frontend", + lambda: (_ for _ in ()).throw( + NotImplementedError( + "The operator 'aten::_fft_r2c' is not currently implemented for the MPS device." + ) + ), + ) + + decision = mps_compat._resolve_mps_log_mel_compatibility_decision() + + assert decision.enable_cpu_offload is True + assert decision.reason_code == "mps_log_mel_frontend_cpu_offload_required" + + +def test_resolve_mps_log_mel_compatibility_decision_skips_patch_when_supported( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Healthy MPS log-mel frontends should keep the patch disabled.""" + mps_compat._resolve_mps_log_mel_compatibility_decision.cache_clear() + monkeypatch.setattr(mps_compat, "_mps_backend_available", lambda: True) + monkeypatch.setattr(mps_compat, "_probe_mps_log_mel_frontend", lambda: None) + + decision = mps_compat._resolve_mps_log_mel_compatibility_decision() + + assert decision.enable_cpu_offload is False + assert decision.reason_code == "mps_log_mel_frontend_supported" + + def test_mps_timing_compatibility_context_patches_and_restores_aliases( monkeypatch: pytest.MonkeyPatch, ) -> None: - """Context should patch DTW/std_mean/_compute_qks aliases 
and restore on exit.""" + """Context should patch MPS compatibility aliases and restore them on exit.""" def _original_compute_qks(*_args: object, **_kwargs: object) -> None: return None + def _original_log_mel(*_args: object, **_kwargs: object) -> torch.Tensor: + return torch.ones((80, 8), dtype=torch.float32) + @contextmanager def _disable_sdpa() -> Any: yield @@ -156,7 +213,10 @@ def _disable_sdpa() -> Any: fake_compat = SimpleNamespace( dtw=lambda _x: "compat_original", disable_sdpa=_disable_sdpa, + log_mel_spectrogram=_original_log_mel, ) + fake_original_whisper = SimpleNamespace(log_mel_spectrogram=_original_log_mel) + fake_whisper_audio = SimpleNamespace(log_mel_spectrogram=_original_log_mel) fake_whisper_timing = SimpleNamespace(dtw_cpu=lambda x: ("cpu", x)) def _fake_import_module(name: str) -> object: @@ -164,19 +224,39 @@ def _fake_import_module(name: str) -> object: return fake_timing if name == "stable_whisper.whisper_compatibility": return fake_compat + if name == "stable_whisper.whisper_word_level.original_whisper": + return fake_original_whisper + if name == "whisper.audio": + return fake_whisper_audio if name == "whisper.timing": return fake_whisper_timing raise ModuleNotFoundError(name) monkeypatch.setattr(mps_compat.importlib, "import_module", _fake_import_module) + monkeypatch.setattr( + mps_compat, + "_resolve_mps_log_mel_compatibility_decision", + lambda: mps_compat._MpsLogMelCompatibilityDecision( + enable_cpu_offload=True, + reason_code="mps_log_mel_frontend_cpu_offload_required", + python_version="3.12.8", + torch_version="2.2.2", + ), + ) original_std_mean = torch.std_mean original_compute_qks = fake_timing._compute_qks + original_compat_log_mel = fake_compat.log_mel_spectrogram + original_original_whisper_log_mel = fake_original_whisper.log_mel_spectrogram + original_whisper_audio_log_mel = fake_whisper_audio.log_mel_spectrogram with mps_compat.stable_whisper_mps_timing_compatibility_context(): assert fake_timing.dtw is not None assert 
fake_timing.dtw is fake_compat.dtw assert torch.std_mean is not original_std_mean assert fake_timing._compute_qks is not original_compute_qks + assert fake_compat.log_mel_spectrogram is not original_compat_log_mel + assert fake_original_whisper.log_mel_spectrogram is not original_original_whisper_log_mel + assert fake_whisper_audio.log_mel_spectrogram is not original_whisper_audio_log_mel dtw_result = cast(object, fake_timing.dtw(torch.tensor([1.0]))) assert isinstance(dtw_result, tuple) assert dtw_result[0] == "cpu" @@ -185,6 +265,9 @@ def _fake_import_module(name: str) -> object: assert fake_compat.dtw(torch.tensor([1.0])) == "compat_original" assert torch.std_mean is original_std_mean assert fake_timing._compute_qks is original_compute_qks + assert fake_compat.log_mel_spectrogram is original_compat_log_mel + assert fake_original_whisper.log_mel_spectrogram is original_original_whisper_log_mel + assert fake_whisper_audio.log_mel_spectrogram is original_whisper_audio_log_mel def test_mps_timing_compatibility_context_offloads_compute_qks_to_cpu( @@ -219,6 +302,13 @@ def _disable_sdpa() -> Any: fake_compat = SimpleNamespace( dtw=lambda _x: "compat_original", disable_sdpa=_disable_sdpa, + log_mel_spectrogram=lambda *_args, **_kwargs: torch.ones((80, 8), dtype=torch.float32), + ) + fake_original_whisper = SimpleNamespace( + log_mel_spectrogram=lambda *_args, **_kwargs: torch.ones((80, 8), dtype=torch.float32) + ) + fake_whisper_audio = SimpleNamespace( + log_mel_spectrogram=lambda *_args, **_kwargs: torch.ones((80, 8), dtype=torch.float32) ) fake_whisper_timing = SimpleNamespace(dtw_cpu=lambda x: ("cpu", x)) @@ -227,6 +317,10 @@ def _fake_import_module(name: str) -> object: return fake_timing if name == "stable_whisper.whisper_compatibility": return fake_compat + if name == "stable_whisper.whisper_word_level.original_whisper": + return fake_original_whisper + if name == "whisper.audio": + return fake_whisper_audio if name == "whisper.timing": return 
fake_whisper_timing raise ModuleNotFoundError(name) diff --git a/tests/test_transcript_extractor.py b/tests/test_transcript_extractor.py deleted file mode 100644 index 970f8b7..0000000 --- a/tests/test_transcript_extractor.py +++ /dev/null @@ -1,2053 +0,0 @@ -"""Behavior tests for transcript extraction error handling.""" - -import logging -import os -import sys -from collections.abc import Callable -from multiprocessing.connection import Connection -from multiprocessing.reduction import ForkingPickler -from pathlib import Path -from types import ModuleType, SimpleNamespace -from typing import TYPE_CHECKING, Any, Never, cast - -import pytest - -from ser.domain import TranscriptWord -from ser.runtime.phase_contract import ( - PHASE_TRANSCRIPTION, - PHASE_TRANSCRIPTION_MODEL_LOAD, - PHASE_TRANSCRIPTION_SETUP, -) -from ser.transcript import transcript_extractor as te -from ser.transcript.backends import faster_whisper as faster_whisper_adapter -from ser.transcript.backends.base import CompatibilityIssue - -if TYPE_CHECKING: - from stable_whisper.result import WhisperResult - - -@pytest.fixture(autouse=True) -def _disable_openmp_conflict_detection(monkeypatch: pytest.MonkeyPatch) -> None: - """Keeps transcript extractor backend tests deterministic.""" - monkeypatch.setattr( - faster_whisper_adapter, - "has_known_faster_whisper_openmp_runtime_conflict", - lambda: False, - ) - - -class FailingModel: - """Fake model that always fails during transcription.""" - - def transcribe(self, **_kwargs: object) -> Never: - raise RuntimeError("transcribe failure") - - -class FakeResult: - """Whisper-like result object with configurable word payload.""" - - def __init__(self, words: list[SimpleNamespace]) -> None: - self._words = words - - def all_words(self) -> list[SimpleNamespace]: - return self._words - - -def test_extract_transcript_raises_transcription_error( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Operational failures should propagate as TranscriptionError.""" - 
settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) - monkeypatch.setattr( - te._boundary_support, - "load_whisper_model_for_settings", - lambda *_args, **_kwargs: object(), - ) - monkeypatch.setattr( - te._boundary_support, "transcription_setup_required", lambda **_kwargs: False - ) - monkeypatch.setattr( - te._boundary_support, - "transcribe_with_profile", - lambda *_args, **_kwargs: (_ for _ in ()).throw( - te.TranscriptionError("Failed to transcribe audio.") - ), - ) - - with pytest.raises(te.TranscriptionError, match="Failed to transcribe audio"): - te._extract_transcript( - "does-not-matter.wav", - "en", - te.TranscriptionProfile(backend_id="stable_whisper", model_name="large-v2"), - settings=settings, - ) - - -def test_extract_transcript_returns_empty_list_for_successful_empty_result( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """A successful call with no words should return an empty transcript.""" - settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) - monkeypatch.setattr( - te._boundary_support, - "load_whisper_model_for_settings", - lambda *_args, **_kwargs: object(), - ) - monkeypatch.setattr( - te._boundary_support, "transcription_setup_required", lambda **_kwargs: False - ) - monkeypatch.setattr( - te._boundary_support, - "transcribe_with_profile", - lambda *_args, **_kwargs: [], - ) - - assert ( - te._extract_transcript( - "empty.wav", - "en", - te.TranscriptionProfile(backend_id="stable_whisper", model_name="large-v2"), - settings=settings, - ) - == [] - ) - - -def test_extract_transcript_formats_word_timestamps( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Word-level timestamps should be preserved in formatted output.""" - settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) - monkeypatch.setattr( - te._boundary_support, - "load_whisper_model_for_settings", - lambda *_args, **_kwargs: object(), - ) - monkeypatch.setattr( - te._boundary_support, "transcription_setup_required", lambda 
**_kwargs: False - ) - monkeypatch.setattr( - te._boundary_support, - "transcribe_with_profile", - lambda *_args, **_kwargs: [TranscriptWord("hello", 0.1, 0.3)], - ) - - assert te._extract_transcript( - "sample.wav", - "en", - te.TranscriptionProfile(backend_id="stable_whisper", model_name="large-v2"), - settings=settings, - ) == [TranscriptWord("hello", 0.1, 0.3)] - - -def test_extract_transcript_releases_runtime_memory_on_success( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """In-process transcript extraction should release runtime memory on success.""" - settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) - loaded_model = object() - released_models: list[object] = [] - monkeypatch.setattr( - te._boundary_support, - "load_whisper_model_for_settings", - lambda *_args, **_kwargs: loaded_model, - ) - monkeypatch.setattr( - te._boundary_support, "transcription_setup_required", lambda **_kwargs: False - ) - monkeypatch.setattr( - te._boundary_support, - "transcribe_with_profile", - lambda *_args, **_kwargs: [], - ) - monkeypatch.setattr( - te, - "_release_transcription_runtime_memory", - lambda *, model: released_models.append(model), - ) - - result = te._extract_transcript( - "sample.wav", - "en", - te.TranscriptionProfile(backend_id="stable_whisper", model_name="large-v2"), - settings=settings, - ) - - assert result == [] - assert released_models == [loaded_model] - - -def test_extract_transcript_releases_runtime_memory_on_failure( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """In-process transcript extraction should release runtime memory on failures.""" - settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) - loaded_model = object() - released_models: list[object] = [] - monkeypatch.setattr( - te._boundary_support, - "load_whisper_model_for_settings", - lambda *_args, **_kwargs: loaded_model, - ) - monkeypatch.setattr( - te._boundary_support, "transcription_setup_required", lambda **_kwargs: False - ) - 
monkeypatch.setattr( - te._boundary_support, - "transcribe_with_profile", - lambda *_args, **_kwargs: (_ for _ in ()).throw( - te.TranscriptionError("Failed to transcribe audio.") - ), - ) - monkeypatch.setattr( - te, - "_release_transcription_runtime_memory", - lambda *, model: released_models.append(model), - ) - - with pytest.raises(te.TranscriptionError, match="Failed to transcribe audio"): - te._extract_transcript( - "sample.wav", - "en", - te.TranscriptionProfile(backend_id="stable_whisper", model_name="large-v2"), - settings=settings, - ) - - assert released_models == [loaded_model] - - -def test_release_transcription_runtime_memory_empties_available_torch_caches( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Torch cache cleanup should be best-effort and gated by availability checks.""" - from ser._internal.transcription import process_worker as process_worker_helpers - - calls: list[str] = [] - fake_torch = ModuleType("torch") - fake_mps = ModuleType("mps") - fake_cuda = ModuleType("cuda") - - cast(Any, fake_mps).is_available = lambda: True - cast(Any, fake_mps).empty_cache = lambda: calls.append("mps") - cast(Any, fake_cuda).is_available = lambda: True - cast(Any, fake_cuda).empty_cache = lambda: calls.append("cuda") - cast(Any, fake_torch).mps = fake_mps - cast(Any, fake_torch).cuda = fake_cuda - monkeypatch.setitem(sys.modules, "torch", fake_torch) - monkeypatch.setattr( - process_worker_helpers.gc, - "collect", - lambda: calls.append("gc"), - ) - - te._release_transcription_runtime_memory(model=object()) - - assert calls == ["gc", "mps", "cuda"] - - -def test_format_transcript_raises_for_invalid_result() -> None: - """Invalid result objects should raise a domain-level error.""" - with pytest.raises(te.TranscriptionError, match="Invalid Whisper result object"): - invalid_result = cast("WhisperResult", object()) - te.format_transcript(invalid_result) - - -def test_load_whisper_model_routes_downloads_to_model_cache_root( - monkeypatch: 
pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - """Whisper and torch-hub assets should route to SER model cache roots.""" - download_root = tmp_path / "model-cache" / "OpenAI" / "whisper" - torch_cache_root = tmp_path / "model-cache" / "torch" - huggingface_cache_root = tmp_path / "model-cache" / "huggingface" - modelscope_cache_root = tmp_path / "model-cache" / "modelscope" / "hub" - settings = SimpleNamespace( - models=SimpleNamespace( - whisper_download_root=download_root, - torch_cache_root=torch_cache_root, - huggingface_cache_root=huggingface_cache_root, - modelscope_cache_root=modelscope_cache_root, - ), - torch_runtime=SimpleNamespace(enable_mps_fallback=False), - ) - captured: dict[str, object] = {} - fake_model = object() - - def _fake_load_model(**kwargs: object) -> object: - captured["torch_home"] = os.getenv("TORCH_HOME") - captured.update(kwargs) - return fake_model - - monkeypatch.setattr(te, "reload_settings", lambda: settings) - monkeypatch.setitem( - sys.modules, - "stable_whisper", - SimpleNamespace(load_model=_fake_load_model), - ) - monkeypatch.setattr( - "ser.transcript.backends.stable_whisper." 
"enable_stable_whisper_mps_compatibility", - lambda model: model, - ) - monkeypatch.delenv("TORCH_HOME", raising=False) - - loaded = te.load_whisper_model( - profile=te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="tiny", - use_demucs=False, - use_vad=False, - ) - ) - - assert loaded is fake_model - assert captured["download_root"] == str(download_root) - assert captured["torch_home"] == str(torch_cache_root) - assert "TORCH_HOME" not in os.environ - assert download_root.is_dir() - assert torch_cache_root.is_dir() - - -def test_load_whisper_model_supports_faster_whisper_backend( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - """Faster-whisper backend should be loadable through the same facade.""" - download_root = tmp_path / "model-cache" / "OpenAI" / "whisper" - torch_cache_root = tmp_path / "model-cache" / "torch" - settings = SimpleNamespace( - models=SimpleNamespace( - whisper_download_root=download_root, - torch_cache_root=torch_cache_root, - ) - ) - captured: dict[str, object] = {} - - class _FakeWhisperModel: - def __init__( - self, - model_size_or_path: str, - *, - device: str, - compute_type: str, - download_root: str, - ) -> None: - captured["model_size_or_path"] = model_size_or_path - captured["device"] = device - captured["compute_type"] = compute_type - captured["download_root"] = download_root - - monkeypatch.setattr(te, "reload_settings", lambda: settings) - monkeypatch.setitem( - sys.modules, - "faster_whisper", - SimpleNamespace(WhisperModel=_FakeWhisperModel), - ) - - loaded = te.load_whisper_model( - profile=te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - use_demucs=False, - use_vad=True, - ) - ) - - assert isinstance(loaded, _FakeWhisperModel) - assert captured["model_size_or_path"] == "distil-large-v3" - assert captured["device"] == "cpu" - assert captured["compute_type"] == "int8" - assert captured["download_root"] == str(download_root) - - -def 
test_load_whisper_model_uses_runtime_policy_device( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - """Stable-whisper should stage MPS loads through CPU per compatibility flow.""" - download_root = tmp_path / "model-cache" / "OpenAI" / "whisper" - torch_cache_root = tmp_path / "model-cache" / "torch" - huggingface_cache_root = tmp_path / "model-cache" / "huggingface" - modelscope_cache_root = tmp_path / "model-cache" / "modelscope" / "hub" - settings = SimpleNamespace( - models=SimpleNamespace( - whisper_download_root=download_root, - torch_cache_root=torch_cache_root, - huggingface_cache_root=huggingface_cache_root, - modelscope_cache_root=modelscope_cache_root, - ), - torch_runtime=SimpleNamespace( - device="auto", - dtype="auto", - enable_mps_fallback=False, - ), - ) - captured: dict[str, object] = {} - fake_model = object() - - def _fake_load_model(**kwargs: object) -> object: - captured["torch_home"] = os.getenv("TORCH_HOME") - captured.update(kwargs) - return fake_model - - monkeypatch.setattr(te, "reload_settings", lambda: settings) - monkeypatch.setattr( - te, - "resolve_transcription_runtime_policy", - lambda **_kwargs: SimpleNamespace( - device_spec="mps", - device_type="mps", - precision_candidates=("float16", "float32"), - memory_tier="low", - ), - ) - monkeypatch.setitem( - sys.modules, - "stable_whisper", - SimpleNamespace(load_model=_fake_load_model), - ) - monkeypatch.setattr( - "ser.transcript.backends.stable_whisper." 
"enable_stable_whisper_mps_compatibility", - lambda model: model, - ) - monkeypatch.delenv("TORCH_HOME", raising=False) - - loaded = te.load_whisper_model( - profile=te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v3", - use_demucs=False, - use_vad=True, - ) - ) - - assert loaded is fake_model - assert captured["device"] == "cpu" - assert captured["download_root"] == str(download_root) - - -def test_load_whisper_model_uses_explicit_settings_without_ambient_lookup( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - """Explicit settings should fully drive model loading without ambient fallbacks.""" - download_root = tmp_path / "model-cache" / "OpenAI" / "whisper" - torch_cache_root = tmp_path / "model-cache" / "torch" - huggingface_cache_root = tmp_path / "model-cache" / "huggingface" - modelscope_cache_root = tmp_path / "model-cache" / "modelscope" / "hub" - settings = cast( - te.AppConfig, - SimpleNamespace( - models=SimpleNamespace( - whisper_download_root=download_root, - torch_cache_root=torch_cache_root, - huggingface_cache_root=huggingface_cache_root, - modelscope_cache_root=modelscope_cache_root, - whisper_model=SimpleNamespace(name="large-v3"), - ), - transcription=SimpleNamespace( - backend_id="stable_whisper", - use_demucs=False, - use_vad=True, - ), - torch_runtime=SimpleNamespace( - device="auto", - dtype="auto", - enable_mps_fallback=False, - ), - ), - ) - captured: dict[str, object] = {} - fake_model = object() - - def _fake_load_model(**kwargs: object) -> object: - captured["torch_home"] = os.getenv("TORCH_HOME") - captured.update(kwargs) - return fake_model - - monkeypatch.setattr( - te, - "reload_settings", - lambda: (_ for _ in ()).throw( - AssertionError("explicit settings must bypass ambient resolution") - ), - ) - monkeypatch.setattr( - te, - "resolve_transcription_runtime_policy", - lambda **_kwargs: SimpleNamespace( - device_spec="cpu", - device_type="cpu", - precision_candidates=("float32",), - 
memory_tier="low", - ), - ) - monkeypatch.setitem( - sys.modules, - "stable_whisper", - SimpleNamespace(load_model=_fake_load_model), - ) - monkeypatch.setattr( - "ser.transcript.backends.stable_whisper." "enable_stable_whisper_mps_compatibility", - lambda model: model, - ) - monkeypatch.delenv("TORCH_HOME", raising=False) - - loaded = te.load_whisper_model(settings=settings) - - assert loaded is fake_model - assert captured["name"] == "large-v3" - assert captured["download_root"] == str(download_root) - assert captured["torch_home"] == str(torch_cache_root) - assert "TORCH_HOME" not in os.environ - - -def test_transcribe_with_model_supports_faster_whisper_word_segments() -> None: - """Faster-whisper segment word payloads should map to TranscriptWord rows.""" - words = [ - SimpleNamespace(word="hello", start=0.0, end=0.2), - SimpleNamespace(word="world", start=0.2, end=0.5), - ] - - class _FakeFasterModel: - def transcribe(self, *_args: object, **_kwargs: object) -> tuple[object, object]: - return iter([SimpleNamespace(words=words)]), object() - - transcript = te.transcribe_with_model( - model=_FakeFasterModel(), - file_path="sample.wav", - language="en", - profile=te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - use_demucs=False, - use_vad=True, - ), - ) - - assert transcript == [ - TranscriptWord("hello", 0.0, 0.2), - TranscriptWord("world", 0.2, 0.5), - ] - - -def test_faster_whisper_info_logs_are_demoted_to_debug_during_transcription() -> None: - """faster-whisper INFO entries should be demoted to DEBUG in transcription scope.""" - words = [SimpleNamespace(word="hello", start=0.0, end=0.2)] - captured: list[logging.LogRecord] = [] - root_logger = logging.getLogger() - original_level = root_logger.level - - class _CaptureHandler(logging.Handler): - def emit(self, record: logging.LogRecord) -> None: - captured.append(record) - - handler = _CaptureHandler(level=logging.DEBUG) - root_logger.setLevel(logging.DEBUG) - 
root_logger.addHandler(handler) - try: - - class _FakeFasterModel: - def transcribe(self, *_args: object, **_kwargs: object) -> tuple[object, object]: - logging.getLogger("faster_whisper").info("Processing audio sample") - return iter([SimpleNamespace(words=words)]), object() - - transcript = te.transcribe_with_model( - model=_FakeFasterModel(), - file_path="sample.wav", - language="en", - profile=te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - use_demucs=False, - use_vad=True, - ), - ) - finally: - root_logger.removeHandler(handler) - root_logger.setLevel(original_level) - - assert transcript == [TranscriptWord("hello", 0.0, 0.2)] - faster_records = [record for record in captured if record.name.startswith("faster_whisper")] - assert faster_records, "Expected faster_whisper logs to be captured." - assert all(record.levelno == logging.DEBUG for record in faster_records) - - -def test_check_adapter_compatibility_logs_non_blocking_issues_once( - monkeypatch: pytest.MonkeyPatch, - caplog: pytest.LogCaptureFixture, -) -> None: - """Repeated compatibility checks should emit non-blocking issues once.""" - compatibility_report = te.CompatibilityReport( - backend_id="stable_whisper", - operational_issues=( - CompatibilityIssue( - code="torio_ffmpeg_abi_mismatch", - message="torchaudio FFmpeg extension ABI mismatch", - ), - ), - noise_issues=( - CompatibilityIssue( - code="stable_whisper_invalid_escape_sequence", - message="stable-whisper import warning noise", - ), - ), - ) - - class _FakeAdapter: - def check_compatibility( - self, - *, - runtime_request: te.BackendRuntimeRequest, - settings: object, - ) -> te.CompatibilityReport: - del runtime_request - del settings - return compatibility_report - - monkeypatch.setattr( - te._boundary_support, - "resolve_transcription_backend_adapter", - lambda _backend_id: _FakeAdapter(), - ) - monkeypatch.setattr(te, "_EMITTED_COMPATIBILITY_ISSUE_KEYS", set()) - runtime_request = 
te.BackendRuntimeRequest( - model_name="large-v2", - use_demucs=False, - use_vad=True, - ) - profile = te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - ) - settings = cast(te.AppConfig, SimpleNamespace()) - - with caplog.at_level(logging.DEBUG): - _ = te._check_adapter_compatibility( - active_profile=profile, - settings=settings, - runtime_request=runtime_request, - ) - _ = te._check_adapter_compatibility( - active_profile=profile, - settings=settings, - runtime_request=runtime_request, - ) - - noise_records = [ - record - for record in caplog.records - if "noise issue [stable_whisper_invalid_escape_sequence]" in record.getMessage() - ] - operational_records = [ - record - for record in caplog.records - if "operational issue [torio_ffmpeg_abi_mismatch]" in record.getMessage() - ] - assert len(noise_records) == 1 - assert len(operational_records) == 1 - - -def test_check_adapter_compatibility_delegates_to_internal_service( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Extractor compatibility wrapper should delegate with injected dependencies.""" - captured: dict[str, object] = {} - emitted_issue_keys: set[tuple[str, str, str]] = set() - report = te.CompatibilityReport(backend_id="stable_whisper") - - def _fake_impl(**kwargs: object) -> te.CompatibilityReport: - captured.update(kwargs) - return report - - monkeypatch.setattr(te, "_EMITTED_COMPATIBILITY_ISSUE_KEYS", emitted_issue_keys) - monkeypatch.setattr(te._boundary_support, "_check_adapter_compatibility_impl", _fake_impl) - runtime_request = te.BackendRuntimeRequest( - model_name="large-v2", - use_demucs=False, - use_vad=True, - ) - profile = te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - ) - settings = cast(te.AppConfig, SimpleNamespace()) - - resolved = te._check_adapter_compatibility( - active_profile=profile, - settings=settings, - runtime_request=runtime_request, - ) - - assert resolved is report - assert captured["active_profile"] 
== profile - assert captured["settings"] is settings - assert captured["runtime_request"] == runtime_request - assert ( - captured["runtime_request_resolver"] is te._boundary_support._runtime_request_from_profile - ) - assert ( - captured["adapter_resolver"] is te._boundary_support.resolve_transcription_backend_adapter - ) - assert captured["error_factory"] is te.TranscriptionError - assert captured["emitted_issue_keys"] is emitted_issue_keys - assert captured["logger"] is te.logger - - -def test_check_adapter_compatibility_delegates_to_boundary_owner( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Compatibility wrapper should delegate boundary assembly to the internal owner.""" - captured: dict[str, object] = {} - report = te.CompatibilityReport(backend_id="stable_whisper") - profile = te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - ) - settings = cast(te.AppConfig, SimpleNamespace()) - emitted_issue_keys: set[tuple[str, str, str]] = set() - - def _fake_boundary_impl(**kwargs: object) -> te.CompatibilityReport: - captured.update(kwargs) - return report - - monkeypatch.setattr(te, "_EMITTED_COMPATIBILITY_ISSUE_KEYS", emitted_issue_keys) - monkeypatch.setattr( - te._boundary_support, - "_check_adapter_compatibility_boundary_impl", - _fake_boundary_impl, - ) - - resolved = te._check_adapter_compatibility( - active_profile=profile, - settings=settings, - ) - - assert resolved is report - assert captured["active_profile"] == profile - assert captured["settings"] is settings - assert captured["runtime_request"] is None - assert ( - captured["check_adapter_compatibility_impl"] - is te._boundary_support._check_adapter_compatibility_impl - ) - assert ( - captured["runtime_request_resolver"] is te._boundary_support._runtime_request_from_profile - ) - assert ( - captured["adapter_resolver"] is te._boundary_support.resolve_transcription_backend_adapter - ) - assert captured["emitted_issue_keys"] is emitted_issue_keys - assert 
captured["logger"] is te.logger - assert captured["error_factory"] is te.TranscriptionError - - -def test_runtime_request_from_profile_delegates_to_internal_service( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Runtime-request wrapper should delegate with policy/default injections.""" - captured: dict[str, object] = {} - expected = te.BackendRuntimeRequest( - model_name="large-v2", - use_demucs=True, - use_vad=True, - device_spec="cpu", - device_type="cpu", - precision_candidates=("float32",), - memory_tier="unknown", - ) - - def _fake_impl(**kwargs: object) -> te.BackendRuntimeRequest: - captured.update(kwargs) - return expected - - monkeypatch.setattr(te, "_runtime_request_from_profile_impl", _fake_impl) - profile = te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - ) - settings = cast(te.AppConfig, SimpleNamespace()) - - resolved = te._runtime_request_from_profile(profile, settings) - - assert resolved is expected - assert captured["active_profile"] == profile - assert captured["settings"] is settings - assert captured["runtime_policy_resolver"] is te.resolve_transcription_runtime_policy - assert captured["default_mps_low_memory_threshold_gb"] == te.DEFAULT_MPS_LOW_MEMORY_THRESHOLD_GB - - -def test_run_faster_whisper_process_isolated_delegates_to_internal_service( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Process-isolated wrapper should delegate with injected dependencies.""" - captured: dict[str, object] = {} - expected = [TranscriptWord("hello", 0.0, 0.5)] - settings = cast( - te.AppConfig, - SimpleNamespace(torch_runtime=SimpleNamespace(device="cpu", dtype="auto")), - ) - - def _fake_impl(**kwargs: object) -> list[TranscriptWord]: - captured.update(kwargs) - return expected - - monkeypatch.setattr( - te._boundary_support, - "_run_faster_whisper_process_isolated_impl", - _fake_impl, - ) - profile = te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - use_demucs=False, - 
use_vad=True, - ) - - resolved = te._run_faster_whisper_process_isolated( - file_path="sample.wav", - language="en", - profile=profile, - settings=settings, - ) - - assert resolved == expected - assert captured["file_path"] == "sample.wav" - assert captured["language"] == "en" - assert captured["profile"] == profile - settings_resolver = cast(Callable[[], te.AppConfig], captured["settings_resolver"]) - assert settings_resolver() is settings - runtime_request_resolver = cast( - Callable[[te.TranscriptionProfile, te.AppConfig], te.BackendRuntimeRequest], - captured["runtime_request_resolver"], - ) - runtime_request = runtime_request_resolver( - profile, - cast( - te.AppConfig, - SimpleNamespace(torch_runtime=SimpleNamespace(device="cpu", dtype="auto")), - ), - ) - assert runtime_request.device_type == "cpu" - assert runtime_request.precision_candidates == ("float32",) - assert captured["payload_factory"] is te._boundary_support._build_transcription_process_payload - payload_factory = cast(Callable[..., object], captured["payload_factory"]) - payload = cast( - te._TranscriptionProcessPayload, - payload_factory( - file_path="sample.wav", - language="en", - profile=profile, - runtime_request=te.BackendRuntimeRequest( - model_name="distil-large-v3", - use_demucs=False, - use_vad=True, - device_spec="cpu", - device_type="cpu", - precision_candidates=("float32",), - memory_tier="not_applicable", - ), - settings=cast( - te.AppConfig, - SimpleNamespace( - models=SimpleNamespace(whisper_download_root=Path("/tmp/whisper-cache")) - ), - ), - ), - ) - assert payload.settings.models.whisper_download_root == Path("/tmp/whisper-cache") - assert captured["get_spawn_context"] is te._spawn_context - get_spawn_context = cast(Callable[[], object], captured["get_spawn_context"]) - spawn_context = get_spawn_context() - assert hasattr(spawn_context, "Pipe") - assert hasattr(spawn_context, "Process") - worker_entry = cast(Callable[[object, object], None], captured["worker_entry"]) - assert 
worker_entry is te._transcription_worker_entry - assert ForkingPickler.dumps(worker_entry) - terminate_worker_process = cast( - Callable[[object], None], - captured["terminate_worker_process_fn"], - ) - assert terminate_worker_process is te._terminate_worker_process - assert callable(captured["recv_worker_message_fn"]) - assert captured["raise_worker_error_fn"] is te._raise_worker_error - assert captured["transcript_word_factory"] is te.TranscriptWord - assert captured["logger"] is te.logger - assert captured["error_factory"] is te.TranscriptionError - assert captured["terminate_grace_seconds"] == te._TERMINATE_GRACE_SECONDS - - -def test_run_faster_whisper_process_isolated_delegates_to_boundary_owner( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Public isolated-run wrapper should delegate assembly to the internal boundary owner.""" - captured: dict[str, object] = {} - expected = [TranscriptWord("hello", 0.0, 0.5)] - profile = te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - use_demucs=False, - use_vad=True, - ) - settings = cast(te.AppConfig, SimpleNamespace()) - - def _fake_boundary_impl(**kwargs: object) -> list[TranscriptWord]: - captured.update(kwargs) - return expected - - monkeypatch.setattr( - te._boundary_support, - "_run_faster_whisper_process_isolated_boundary_impl", - _fake_boundary_impl, - ) - - resolved = te._run_faster_whisper_process_isolated( - file_path="sample.wav", - language="en", - profile=profile, - settings=settings, - ) - - assert resolved == expected - assert captured["file_path"] == "sample.wav" - assert captured["language"] == "en" - assert captured["profile"] == profile - assert captured["settings"] is settings - assert ( - captured["run_faster_whisper_process_isolated_impl"] - is te._boundary_support._run_faster_whisper_process_isolated_impl - ) - assert callable(captured["runtime_request_resolver"]) - assert captured["payload_factory"] is 
te._boundary_support._build_transcription_process_payload - assert captured["spawn_context_resolver"] is te._spawn_context - assert captured["worker_entry"] is te._transcription_worker_entry - assert callable(captured["recv_worker_message_fn"]) - assert captured["raise_worker_error_fn"] is te._raise_worker_error - assert captured["terminate_worker_process_fn"] is te._terminate_worker_process - assert captured["logger"] is te.logger - assert captured["error_factory"] is te.TranscriptionError - assert captured["terminate_grace_seconds"] == te._TERMINATE_GRACE_SECONDS - - -def test_extract_transcript_in_process_delegates_to_internal_service( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """In-process transcript wrapper should delegate with injected dependencies.""" - captured: dict[str, object] = {} - expected = [TranscriptWord("hello", 0.0, 0.5)] - settings = cast(te.AppConfig, SimpleNamespace()) - - def _fake_impl(**kwargs: object) -> list[TranscriptWord]: - captured.update(kwargs) - return expected - - monkeypatch.setattr(te._boundary_support, "_extract_transcript_in_process_impl", _fake_impl) - profile = te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - use_demucs=True, - use_vad=True, - ) - - resolved = te._extract_transcript_in_process( - file_path="sample.wav", - language="en", - profile=profile, - settings=settings, - ) - - assert resolved == expected - assert captured["file_path"] == "sample.wav" - assert captured["language"] == "en" - assert captured["profile"] == profile - settings_resolver = cast(Callable[[], te.AppConfig], captured["settings_resolver"]) - assert settings_resolver() is settings - assert callable(captured["setup_required_checker"]) - assert callable(captured["prepare_assets_runner"]) - assert callable(captured["load_model_fn"]) - assert callable(captured["transcribe_with_profile_fn"]) - assert captured["release_memory_fn"] is te._release_transcription_runtime_memory - assert captured["logger"] is 
te.logger - - -def test_extract_transcript_in_process_delegates_to_boundary_owner( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """In-process transcript wrapper should delegate boundary assembly to the internal owner.""" - captured: dict[str, object] = {} - expected = [TranscriptWord("hello", 0.0, 0.5)] - settings = cast(te.AppConfig, SimpleNamespace()) - profile = te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - use_demucs=True, - use_vad=True, - ) - - def _fake_boundary_impl(**kwargs: object) -> list[TranscriptWord]: - captured.update(kwargs) - return expected - - monkeypatch.setattr( - te._boundary_support, - "_extract_transcript_in_process_boundary_impl", - _fake_boundary_impl, - ) - - resolved = te._extract_transcript_in_process( - file_path="sample.wav", - language="en", - profile=profile, - settings=settings, - ) - - assert resolved == expected - assert captured["file_path"] == "sample.wav" - assert captured["language"] == "en" - assert captured["profile"] == profile - assert captured["settings"] is settings - assert ( - captured["extract_transcript_in_process_impl"] - is te._boundary_support._extract_transcript_in_process_impl - ) - assert callable(captured["setup_required_checker"]) - assert callable(captured["prepare_assets_runner"]) - assert callable(captured["load_whisper_model_fn"]) - assert callable(captured["transcribe_with_profile_fn"]) - assert captured["release_memory_fn"] is te._release_transcription_runtime_memory - assert captured["phase_started_fn"] is te.log_phase_started - assert captured["phase_completed_fn"] is te.log_phase_completed - assert captured["phase_failed_fn"] is te.log_phase_failed - assert captured["logger"] is te.logger - - -def test_transcription_setup_required_delegates_to_boundary_owner( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Setup-required wrapper should delegate boundary assembly to the runtime owner.""" - captured: dict[str, object] = {} - profile = 
te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - ) - settings = cast(te.AppConfig, SimpleNamespace()) - - def _fake_boundary_impl(**kwargs: object) -> bool: - captured.update(kwargs) - return True - - monkeypatch.setattr( - te._boundary_support, - "_transcription_setup_required_boundary_impl", - _fake_boundary_impl, - ) - - required = te._transcription_setup_required( - active_profile=profile, - settings=settings, - ) - - assert required is True - assert captured["active_profile"] == profile - assert captured["settings"] is settings - assert ( - captured["transcription_setup_required_impl"] - is te._boundary_support._transcription_setup_required_impl - ) - assert ( - captured["runtime_request_resolver"] is te._boundary_support._runtime_request_from_profile - ) - assert callable(captured["compatibility_checker"]) - assert ( - captured["adapter_resolver"] is te._boundary_support.resolve_transcription_backend_adapter - ) - - -def test_prepare_transcription_assets_delegates_to_boundary_owner( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Asset-preparation wrapper should delegate boundary assembly to the runtime owner.""" - captured: dict[str, object] = {} - profile = te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - ) - settings = cast(te.AppConfig, SimpleNamespace()) - - def _fake_boundary_impl(**kwargs: object) -> None: - captured.update(kwargs) - - monkeypatch.setattr( - te._boundary_support, - "_prepare_transcription_assets_boundary_impl", - _fake_boundary_impl, - ) - - te._prepare_transcription_assets( - active_profile=profile, - settings=settings, - ) - - assert captured["active_profile"] == profile - assert captured["settings"] is settings - assert ( - captured["prepare_transcription_assets_impl"] - is te._boundary_support._prepare_transcription_assets_impl - ) - assert ( - captured["runtime_request_resolver"] is te._boundary_support._runtime_request_from_profile - ) - assert 
callable(captured["compatibility_checker"]) - assert ( - captured["adapter_resolver"] is te._boundary_support.resolve_transcription_backend_adapter - ) - - -def test_mark_compatibility_issues_as_emitted_suppresses_duplicate_operational_logs( - monkeypatch: pytest.MonkeyPatch, - caplog: pytest.LogCaptureFixture, -) -> None: - """Pre-emitted compatibility issues should not be logged again.""" - monkeypatch.setattr(te, "_EMITTED_COMPATIBILITY_ISSUE_KEYS", set()) - te.mark_compatibility_issues_as_emitted( - backend_id="stable_whisper", - issue_kind="operational", - issue_codes=("torio_ffmpeg_abi_mismatch",), - ) - - class _Adapter: - def check_compatibility( - self, - *, - runtime_request: te.BackendRuntimeRequest, - settings: te.AppConfig, - ) -> te.CompatibilityReport: - del runtime_request, settings - return te.CompatibilityReport( - backend_id="stable_whisper", - operational_issues=( - CompatibilityIssue( - code="torio_ffmpeg_abi_mismatch", - message="torchaudio FFmpeg extension ABI mismatch", - impact="degraded", - ), - ), - ) - - monkeypatch.setattr( - te._boundary_support, - "resolve_transcription_backend_adapter", - lambda _backend_id: cast(object, _Adapter()), - ) - profile = te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - ) - runtime_request = te.BackendRuntimeRequest( - model_name="large-v2", - use_demucs=False, - use_vad=True, - ) - - with caplog.at_level(logging.WARNING, logger=te.logger.name): - te._check_adapter_compatibility( - active_profile=profile, - settings=cast(te.AppConfig, SimpleNamespace()), - runtime_request=runtime_request, - ) - - records = [ - record - for record in caplog.records - if "operational issue [torio_ffmpeg_abi_mismatch]" in record.getMessage() - ] - assert records == [] - - -def test_extract_transcript_logs_setup_before_model_load_when_required( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Transcription setup phase should precede model load when download is needed.""" - phase_events: 
list[tuple[str, str]] = [] - settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) - - def _fake_phase_started(_logger: object, *, phase_name: str) -> float: - phase_events.append(("start", phase_name)) - return 1.0 - - def _fake_phase_completed( - _logger: object, - *, - phase_name: str, - started_at: float, - ) -> None: - phase_events.append(("completed", phase_name)) - assert started_at == 1.0 - - monkeypatch.setattr(te, "log_phase_started", _fake_phase_started) - monkeypatch.setattr(te, "log_phase_completed", _fake_phase_completed) - monkeypatch.setattr(te, "log_phase_failed", lambda *_a, **_k: None) - monkeypatch.setattr(te._boundary_support, "transcription_setup_required", lambda **_k: True) - monkeypatch.setattr(te._boundary_support, "prepare_transcription_assets", lambda **_k: None) - monkeypatch.setattr( - te._boundary_support, - "load_whisper_model_for_settings", - lambda *_args, **_kwargs: object(), - ) - monkeypatch.setattr( - te._boundary_support, - "transcribe_with_profile", - lambda *_args, **_kwargs: [], - ) - - _ = te._extract_transcript( - "sample.wav", - "en", - te.TranscriptionProfile(backend_id="stable_whisper", model_name="large-v2"), - settings=settings, - ) - - assert phase_events == [ - ("start", PHASE_TRANSCRIPTION_SETUP), - ("completed", PHASE_TRANSCRIPTION_SETUP), - ("start", PHASE_TRANSCRIPTION_MODEL_LOAD), - ("completed", PHASE_TRANSCRIPTION_MODEL_LOAD), - ("start", PHASE_TRANSCRIPTION), - ("completed", PHASE_TRANSCRIPTION), - ] - - -def test_extract_transcript_skips_setup_phase_when_not_required( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Transcription setup phase should be omitted when assets are already present.""" - phase_events: list[tuple[str, str]] = [] - settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) - - def _fake_phase_started(_logger: object, *, phase_name: str) -> float: - phase_events.append(("start", phase_name)) - return 1.0 - - def _fake_phase_completed( - _logger: object, 
- *, - phase_name: str, - started_at: float, - ) -> None: - phase_events.append(("completed", phase_name)) - assert started_at == 1.0 - - def _fail_prepare(**_kwargs: object) -> None: - raise AssertionError("setup should not run") - - monkeypatch.setattr(te, "log_phase_started", _fake_phase_started) - monkeypatch.setattr(te, "log_phase_completed", _fake_phase_completed) - monkeypatch.setattr(te, "log_phase_failed", lambda *_a, **_k: None) - monkeypatch.setattr(te._boundary_support, "transcription_setup_required", lambda **_k: False) - monkeypatch.setattr(te._boundary_support, "prepare_transcription_assets", _fail_prepare) - monkeypatch.setattr( - te._boundary_support, - "load_whisper_model_for_settings", - lambda *_args, **_kwargs: object(), - ) - monkeypatch.setattr( - te._boundary_support, - "transcribe_with_profile", - lambda *_args, **_kwargs: [], - ) - - _ = te._extract_transcript( - "sample.wav", - "en", - te.TranscriptionProfile(backend_id="stable_whisper", model_name="large-v2"), - settings=settings, - ) - - assert phase_events == [ - ("start", PHASE_TRANSCRIPTION_MODEL_LOAD), - ("completed", PHASE_TRANSCRIPTION_MODEL_LOAD), - ("start", PHASE_TRANSCRIPTION), - ("completed", PHASE_TRANSCRIPTION), - ] - - -def test_extract_transcript_uses_process_isolation_for_faster_whisper( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """faster-whisper profiles should route to process-isolated execution path.""" - captured: dict[str, object] = {} - expected = [TranscriptWord("hello", 0.0, 0.5)] - profile = te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - use_demucs=False, - use_vad=True, - ) - - def _fake_isolated_runner(**kwargs: object) -> list[TranscriptWord]: - captured.update(kwargs) - return expected - - def _fail_in_process(**_kwargs: object) -> list[TranscriptWord]: - raise AssertionError("in-process path should not be used for faster-whisper") - - monkeypatch.setattr( - te._boundary_support, 
"run_faster_whisper_process_isolated", _fake_isolated_runner - ) - monkeypatch.setattr(te._boundary_support, "extract_transcript_in_process", _fail_in_process) - monkeypatch.setattr( - te, - "reload_settings", - lambda: (_ for _ in ()).throw(AssertionError("private helper must use explicit settings")), - ) - - settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) - transcript = te._extract_transcript("sample.wav", "en", profile, settings=settings) - - assert transcript == expected - assert captured["file_path"] == "sample.wav" - assert captured["language"] == "en" - assert captured["profile"] == profile - assert captured["settings"] is settings - - -def test_extract_transcript_routes_non_faster_profiles_in_process( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Non faster-whisper profiles should route to in-process execution.""" - profile = te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - use_demucs=True, - use_vad=True, - ) - expected = [TranscriptWord("hello", 0.0, 0.5)] - captured: dict[str, object] = {} - - def _fail_isolated_runner(**_kwargs: object) -> list[TranscriptWord]: - raise AssertionError("process-isolated path should not be used") - - def _fake_in_process_runner(**kwargs: object) -> list[TranscriptWord]: - captured.update(kwargs) - return expected - - monkeypatch.setattr( - te._boundary_support, "run_faster_whisper_process_isolated", _fail_isolated_runner - ) - monkeypatch.setattr( - te._boundary_support, "extract_transcript_in_process", _fake_in_process_runner - ) - monkeypatch.setattr( - te, - "reload_settings", - lambda: (_ for _ in ()).throw(AssertionError("private helper must use explicit settings")), - ) - - settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) - transcript = te._extract_transcript("sample.wav", "en", profile, settings=settings) - - assert transcript == expected - assert captured["file_path"] == "sample.wav" - assert captured["language"] == "en" - assert 
captured["profile"] == profile - assert captured["settings"] is settings - - -def test_transcribe_with_profile_uses_explicit_settings_without_ambient_lookup( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Private transcription helper should honor caller-provided settings only.""" - settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) - profile = te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - use_demucs=False, - use_vad=True, - ) - runtime_request = cast(object, SimpleNamespace()) - captured: dict[str, object] = {} - - def _fake_transcribe(**kwargs: object) -> list[TranscriptWord]: - captured.update(kwargs) - return [TranscriptWord("hello", 0.0, 0.5)] - - monkeypatch.setattr( - te, - "reload_settings", - lambda: (_ for _ in ()).throw(AssertionError("private helper must use explicit settings")), - ) - monkeypatch.setattr( - te._boundary_support, "_runtime_request_from_profile", lambda *_a: runtime_request - ) - monkeypatch.setattr(te._boundary_support, "check_adapter_compatibility", lambda **_kwargs: None) - monkeypatch.setattr( - te._boundary_support, - "resolve_transcription_backend_adapter", - lambda _backend_id: SimpleNamespace(transcribe=_fake_transcribe), - ) - - transcript = te._transcribe_file_with_profile( - object(), - "en", - "sample.wav", - profile, - settings=settings, - ) - - assert transcript == [TranscriptWord("hello", 0.0, 0.5)] - assert captured["runtime_request"] is runtime_request - assert captured["file_path"] == "sample.wav" - assert captured["language"] == "en" - assert captured["settings"] is settings - - -def test_transcribe_with_profile_resolves_defaults_from_explicit_settings( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Default transcription profile should resolve from explicit settings only.""" - settings = cast( - te.AppConfig, - SimpleNamespace( - default_language="en", - models=SimpleNamespace(whisper_model=SimpleNamespace(name="large-v3")), - 
transcription=SimpleNamespace( - backend_id="stable_whisper", - use_demucs=False, - use_vad=True, - ), - torch_runtime=SimpleNamespace(device="cpu", dtype="auto"), - ), - ) - runtime_request = cast(object, SimpleNamespace()) - captured: dict[str, object] = {} - - def _fake_runtime_request( - active_profile: te.TranscriptionProfile, - active_settings: te.AppConfig, - ) -> object: - captured["active_profile"] = active_profile - captured["runtime_settings"] = active_settings - return runtime_request - - def _fake_transcribe(**kwargs: object) -> list[TranscriptWord]: - captured.update(kwargs) - return [TranscriptWord("hello", 0.0, 0.5)] - - monkeypatch.setattr( - te, - "reload_settings", - lambda: (_ for _ in ()).throw(AssertionError("private helper must use explicit settings")), - ) - monkeypatch.setattr( - te._boundary_support, "_runtime_request_from_profile", _fake_runtime_request - ) - monkeypatch.setattr(te._boundary_support, "check_adapter_compatibility", lambda **_kwargs: None) - monkeypatch.setattr( - te._boundary_support, - "resolve_transcription_backend_adapter", - lambda _backend_id: SimpleNamespace(transcribe=_fake_transcribe), - ) - - transcript = te._transcribe_file_with_profile( - object(), - "en", - "sample.wav", - None, - settings=settings, - ) - - assert transcript == [TranscriptWord("hello", 0.0, 0.5)] - active_profile = cast(te.TranscriptionProfile, captured["active_profile"]) - assert active_profile.backend_id == "stable_whisper" - assert active_profile.model_name == "large-v3" - assert active_profile.use_demucs is False - assert active_profile.use_vad is True - assert captured["runtime_settings"] is settings - assert captured["runtime_request"] is runtime_request - assert captured["settings"] is settings - - -def test_transcribe_with_profile_delegates_to_boundary_owner( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Private transcription helper should delegate boundary assembly to the internal owner.""" - captured: dict[str, object] = {} - 
settings = cast(te.AppConfig, SimpleNamespace(default_language="en")) - expected = [TranscriptWord("hello", 0.0, 0.5)] - model = object() - profile = te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - use_demucs=False, - use_vad=True, - ) - - def _fake_boundary_impl(*args: object, **kwargs: object) -> list[TranscriptWord]: - captured["args"] = args - captured.update(kwargs) - return expected - - monkeypatch.setattr( - te._boundary_support, - "_transcribe_with_profile_boundary_impl", - _fake_boundary_impl, - ) - - transcript = te._transcribe_file_with_profile( - model, - "en", - "sample.wav", - profile, - settings=settings, - ) - - assert transcript == expected - assert captured["args"] == (model, "en", "sample.wav", profile) - assert captured["settings"] is settings - assert ( - captured["transcribe_with_profile_entrypoint"] - is te._boundary_support._transcribe_with_profile_entrypoint - ) - assert callable(captured["resolve_profile_for_settings"]) - assert ( - captured["runtime_request_resolver"] is te._boundary_support._runtime_request_from_profile - ) - assert callable(captured["compatibility_checker"]) - assert ( - captured["adapter_resolver"] is te._boundary_support.resolve_transcription_backend_adapter - ) - assert captured["passthrough_error_cls"] is te.TranscriptionError - assert captured["logger"] is te.logger - assert captured["error_factory"] is te.TranscriptionError - - -class _FakeIsolatedParentConnection: - """Parent pipe endpoint with deterministic message queue.""" - - def __init__(self, messages: list[tuple[object, ...]]) -> None: - self._messages = messages - self.closed = False - - def recv(self) -> tuple[object, ...]: - if not self._messages: - raise EOFError - return self._messages.pop(0) - - def close(self) -> None: - self.closed = True - - -class _FakeIsolatedChildConnection: - """Child pipe endpoint for process-isolated transcript tests.""" - - def __init__(self) -> None: - self.closed = False - - def 
close(self) -> None: - self.closed = True - - -class _FakeIsolatedProcess: - """Fake process supporting join-first and terminate-fallback scenarios.""" - - def __init__(self, *, exit_on_join: bool) -> None: - self._alive = False - self.closed = False - self.join_timeouts: list[float | None] = [] - self._exit_on_join = exit_on_join - - def start(self) -> None: - self._alive = True - - def join(self, timeout: float | None = None) -> None: - self.join_timeouts.append(timeout) - if self._exit_on_join: - self._alive = False - - def is_alive(self) -> bool: - return self._alive - - def close(self) -> None: - self.closed = True - - -class _FakeIsolatedContext: - """Fake spawn context for deterministic process-isolated cleanup tests.""" - - def __init__( - self, - *, - parent_conn: _FakeIsolatedParentConnection, - child_conn: _FakeIsolatedChildConnection, - process: _FakeIsolatedProcess, - ) -> None: - self.parent_conn = parent_conn - self.child_conn = child_conn - self.process = process - - def Pipe( - self, duplex: bool = False - ) -> tuple[_FakeIsolatedParentConnection, _FakeIsolatedChildConnection]: - assert duplex is False - return self.parent_conn, self.child_conn - - def Process( - self, - *, - target: object, - args: tuple[object, ...], - daemon: bool, - ) -> _FakeIsolatedProcess: - del target, args - assert daemon is False - return self.process - - -def _build_fake_isolated_context( - *, - messages: list[tuple[object, ...]], - exit_on_join: bool, -) -> _FakeIsolatedContext: - """Builds deterministic fake multiprocessing context for isolated runs.""" - return _FakeIsolatedContext( - parent_conn=_FakeIsolatedParentConnection(messages), - child_conn=_FakeIsolatedChildConnection(), - process=_FakeIsolatedProcess(exit_on_join=exit_on_join), - ) - - -def test_faster_whisper_isolated_run_joins_worker_before_terminate_on_success( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Successful isolated runs should allow worker shutdown before terminate fallback.""" - 
messages: list[tuple[object, ...]] = [ - ("phase", "setup_complete"), - ("phase", "model_loaded"), - ("ok", [("hello", 0.0, 0.5)]), - ] - terminate_calls: list[object] = [] - - context = _build_fake_isolated_context(messages=messages, exit_on_join=True) - profile = te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - use_demucs=False, - use_vad=True, - ) - monkeypatch.setattr(te.mp, "get_context", lambda _method: context) - monkeypatch.setattr(te, "_terminate_worker_process", terminate_calls.append) - - result = te._run_faster_whisper_process_isolated( - file_path="sample.wav", - language="en", - profile=profile, - settings=cast( - te.AppConfig, - SimpleNamespace( - torch_runtime=SimpleNamespace(device="cpu", dtype="auto"), - models=SimpleNamespace(whisper_download_root=Path("/tmp/whisper-cache")), - ), - ), - ) - - assert result == [TranscriptWord("hello", 0.0, 0.5)] - assert context.process.join_timeouts == [te._TERMINATE_GRACE_SECONDS] - assert terminate_calls == [] - assert context.parent_conn.closed is True - assert context.child_conn.closed is True - assert context.process.closed is True - - -def test_faster_whisper_isolated_run_terminates_worker_after_join_timeout( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Isolated runs should still terminate workers that remain alive after join timeout.""" - messages: list[tuple[object, ...]] = [ - ("phase", "setup_complete"), - ("phase", "model_loaded"), - ("ok", [("hello", 0.0, 0.5)]), - ] - terminate_calls: list[object] = [] - - def _fake_terminate_worker_process(process: object) -> None: - terminate_calls.append(process) - cast(_FakeIsolatedProcess, process)._alive = False - - context = _build_fake_isolated_context(messages=messages, exit_on_join=False) - profile = te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - use_demucs=False, - use_vad=True, - ) - monkeypatch.setattr(te.mp, "get_context", lambda _method: context) - 
monkeypatch.setattr(te, "_terminate_worker_process", _fake_terminate_worker_process) - - result = te._run_faster_whisper_process_isolated( - file_path="sample.wav", - language="en", - profile=profile, - settings=cast( - te.AppConfig, - SimpleNamespace( - torch_runtime=SimpleNamespace(device="cpu", dtype="auto"), - models=SimpleNamespace(whisper_download_root=Path("/tmp/whisper-cache")), - ), - ), - ) - - assert result == [TranscriptWord("hello", 0.0, 0.5)] - assert context.process.join_timeouts == [te._TERMINATE_GRACE_SECONDS] - assert terminate_calls == [context.process] - assert context.parent_conn.closed is True - assert context.child_conn.closed is True - assert context.process.closed is True - - -def test_runtime_request_for_isolated_faster_whisper_defaults_to_cpu() -> None: - """Process-isolated faster runtime request should avoid torch dependency on CPU.""" - settings = cast( - te.AppConfig, - SimpleNamespace(torch_runtime=SimpleNamespace(device="auto", dtype="auto")), - ) - profile = te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - use_demucs=False, - use_vad=True, - ) - - runtime_request = te._runtime_request_for_isolated_faster_whisper( - profile=profile, - settings=settings, - ) - - assert runtime_request.device_spec == "cpu" - assert runtime_request.device_type == "cpu" - assert runtime_request.precision_candidates == ("float32",) - - -def test_runtime_request_for_isolated_faster_whisper_honors_cuda_request() -> None: - """Process-isolated faster runtime request should preserve explicit CUDA selectors.""" - settings = cast( - te.AppConfig, - SimpleNamespace(torch_runtime=SimpleNamespace(device="cuda:0", dtype="float16")), - ) - profile = te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - use_demucs=False, - use_vad=True, - ) - - runtime_request = te._runtime_request_for_isolated_faster_whisper( - profile=profile, - settings=settings, - ) - - assert 
runtime_request.device_spec == "cuda:0" - assert runtime_request.device_type == "cuda" - assert runtime_request.precision_candidates == ("float16",) - - -def test_runtime_request_for_isolated_faster_whisper_logs_non_cuda_fallback( - caplog: pytest.LogCaptureFixture, -) -> None: - """Unsupported isolated-device selectors should fall back to CPU with one info log.""" - settings = cast( - te.AppConfig, - SimpleNamespace(torch_runtime=SimpleNamespace(device="mps", dtype="float16")), - ) - profile = te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - use_demucs=False, - use_vad=True, - ) - caplog.set_level(logging.INFO, logger=te.logger.name) - - runtime_request = te._runtime_request_for_isolated_faster_whisper( - profile=profile, - settings=settings, - ) - - assert runtime_request.device_spec == "cpu" - assert runtime_request.device_type == "cpu" - assert runtime_request.precision_candidates == ("float32",) - assert any( - "requested device 'mps' is unsupported; using cpu/float32" in record.getMessage() - for record in caplog.records - ) - - -def test_runtime_request_for_isolated_faster_whisper_rejects_non_faster_backend() -> None: - """Isolated runtime request helper should fail fast for non-faster backends.""" - settings = cast( - te.AppConfig, - SimpleNamespace(torch_runtime=SimpleNamespace(device="cpu", dtype="auto")), - ) - profile = te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - ) - - with pytest.raises( - te.TranscriptionError, - match="only supports faster-whisper backend", - ): - te._runtime_request_for_isolated_faster_whisper( - profile=profile, - settings=settings, - ) - - -def test_run_faster_whisper_process_isolated_rejects_non_faster_backend_before_spawn( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Process-isolated entrypoint should reject unsupported backend before spawning.""" - profile = te.TranscriptionProfile( - backend_id="stable_whisper", - model_name="large-v2", - ) 
- monkeypatch.setattr( - te.mp, - "get_context", - lambda _method: (_ for _ in ()).throw(AssertionError("must not spawn")), - ) - - with pytest.raises( - te.TranscriptionError, - match="only supports faster-whisper backend", - ): - te._run_faster_whisper_process_isolated( - file_path="sample.wav", - language="en", - profile=profile, - settings=cast( - te.AppConfig, - SimpleNamespace(torch_runtime=SimpleNamespace(device="cpu", dtype="auto")), - ), - ) - - -def test_transcription_worker_entry_blocks_torch_for_faster_whisper( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Worker should disable torch import path before faster-whisper adapter operations.""" - observed: dict[str, object] = {} - messages: list[tuple[object, ...]] = [] - - class _FakeConnection: - def send(self, message: tuple[object, ...]) -> None: - messages.append(message) - - def close(self) -> None: - observed["closed"] = True - - class _FakeAdapter: - def setup_required(self, *, runtime_request: object, settings: object) -> bool: - del runtime_request - observed["torch_none_setup"] = sys.modules.get("torch") is None - assert isinstance(settings, te._TranscriptionWorkerSettings) - observed["whisper_download_root"] = settings.models.whisper_download_root - return False - - def prepare_assets(self, *, runtime_request: object, settings: object) -> None: - del runtime_request, settings - raise AssertionError("prepare_assets should not run when setup is not required") - - def load_model(self, *, runtime_request: object, settings: object) -> object: - del runtime_request - observed["torch_none_load"] = sys.modules.get("torch") is None - assert isinstance(settings, te._TranscriptionWorkerSettings) - assert settings.models.whisper_download_root == Path("/tmp/whisper-cache") - return object() - - def transcribe( - self, - *, - model: object, - runtime_request: object, - file_path: str, - language: str, - settings: object, - ) -> list[TranscriptWord]: - del model, runtime_request - 
observed["torch_none_transcribe"] = sys.modules.get("torch") is None - observed["file_path"] = file_path - observed["language"] = language - assert isinstance(settings, te._TranscriptionWorkerSettings) - assert settings.models.whisper_download_root == Path("/tmp/whisper-cache") - return [TranscriptWord("hello", 0.0, 0.5)] - - monkeypatch.setattr( - te, - "reload_settings", - lambda: (_ for _ in ()).throw(AssertionError("worker must not use ambient settings")), - ) - monkeypatch.setattr( - te, - "resolve_transcription_backend_adapter", - lambda _backend_id: cast(object, _FakeAdapter()), - ) - payload = te._TranscriptionProcessPayload( - file_path="sample.wav", - language="en", - profile=te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - use_demucs=False, - use_vad=True, - ), - runtime_request=te.BackendRuntimeRequest( - model_name="distil-large-v3", - use_demucs=False, - use_vad=True, - device_spec="cpu", - device_type="cpu", - precision_candidates=("float32",), - memory_tier="not_applicable", - ), - settings=te._TranscriptionWorkerSettings( - models=te._TranscriptionWorkerModelsConfig( - whisper_download_root=Path("/tmp/whisper-cache") - ) - ), - ) - original_torch = sys.modules.pop("torch", None) - try: - te._transcription_worker_entry( - payload, - cast(Connection, _FakeConnection()), - ) - finally: - if original_torch is not None: - sys.modules["torch"] = original_torch - elif "torch" in sys.modules: - del sys.modules["torch"] - - assert observed["torch_none_setup"] is True - assert observed["torch_none_load"] is True - assert observed["torch_none_transcribe"] is True - assert observed["whisper_download_root"] == Path("/tmp/whisper-cache") - assert observed["file_path"] == "sample.wav" - assert observed["language"] == "en" - assert observed["closed"] is True - assert messages[0] == ("phase", "setup_complete") - assert messages[1] == ("phase", "model_loaded") - assert messages[2] == ("ok", [("hello", 0.0, 0.5)]) - - -def 
test_transcription_worker_entry_delegates_to_boundary_owner( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Public worker-entry wrapper should delegate casting/setup to the internal owner.""" - captured: dict[str, object] = {} - payload = object() - connection = object() - - def _fake_boundary_impl(*args: object, **kwargs: object) -> None: - captured["args"] = args - captured.update(kwargs) - - monkeypatch.setattr( - te, - "_transcription_worker_entry_boundary_impl", - _fake_boundary_impl, - ) - - te._transcription_worker_entry(payload, connection) - - assert captured["args"] == (payload, connection) - assert captured["transcription_worker_entry_impl"] is te._transcription_worker_entry_impl - assert captured["adapter_resolver"] is te._resolve_transcription_adapter - - -def test_faster_whisper_setup_required_when_cache_missing( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - """faster-whisper setup should run when local cache lookup misses.""" - settings = cast( - te.AppConfig, - SimpleNamespace(models=SimpleNamespace(whisper_download_root=tmp_path / "model-cache")), - ) - captured: dict[str, object] = {} - - def _fake_download_model( - model_name: str, - *, - local_files_only: bool, - cache_dir: str, - ) -> str: - captured["model_name"] = model_name - captured["local_files_only"] = local_files_only - captured["cache_dir"] = cache_dir - raise RuntimeError("cache miss") - - monkeypatch.setattr( - faster_whisper_adapter.importlib, - "import_module", - lambda name: ( - SimpleNamespace(download_model=_fake_download_model) - if name == "faster_whisper.utils" - else __import__(name) - ), - ) - - required = te._transcription_setup_required( - active_profile=te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - ), - settings=settings, - ) - - assert required is True - assert captured["model_name"] == "distil-large-v3" - assert captured["local_files_only"] is True - assert captured["cache_dir"] == 
str(settings.models.whisper_download_root) - - -def test_faster_whisper_prepare_transcription_assets_downloads( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - """faster-whisper setup should trigger download in non-local-files mode.""" - settings = cast( - te.AppConfig, - SimpleNamespace(models=SimpleNamespace(whisper_download_root=tmp_path / "model-cache")), - ) - captured: dict[str, object] = {} - - def _fake_download_model( - model_name: str, - *, - local_files_only: bool, - cache_dir: str, - ) -> str: - captured["model_name"] = model_name - captured["local_files_only"] = local_files_only - captured["cache_dir"] = cache_dir - return str(tmp_path / "snapshot") - - monkeypatch.setattr( - faster_whisper_adapter.importlib, - "import_module", - lambda name: ( - SimpleNamespace(download_model=_fake_download_model) - if name == "faster_whisper.utils" - else __import__(name) - ), - ) - - te._prepare_transcription_assets( - active_profile=te.TranscriptionProfile( - backend_id="faster_whisper", - model_name="distil-large-v3", - ), - settings=settings, - ) - - assert captured["model_name"] == "distil-large-v3" - assert captured["local_files_only"] is False - assert captured["cache_dir"] == str(settings.models.whisper_download_root) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..293feb7 --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1 @@ +"""Shared test utilities.""" diff --git a/tests/utils/helpers/__init__.py b/tests/utils/helpers/__init__.py new file mode 100644 index 0000000..e577e8d --- /dev/null +++ b/tests/utils/helpers/__init__.py @@ -0,0 +1 @@ +"""Helper utilities for tests.""" diff --git a/tests/utils/helpers/process_spawn_support.py b/tests/utils/helpers/process_spawn_support.py new file mode 100644 index 0000000..b3efa05 --- /dev/null +++ b/tests/utils/helpers/process_spawn_support.py @@ -0,0 +1,115 @@ +"""Importable module-level helpers for spawned-process test coverage.""" + +from 
__future__ import annotations + +import time +from dataclasses import dataclass +from multiprocessing.connection import Connection +from pathlib import Path + +from ser.profiles import TranscriptionBackendId +from ser.transcript.backends import BackendRuntimeRequest + + +@dataclass(frozen=True, slots=True) +class RuntimeWorkerPayload: + """Serializable runtime payload for spawned worker contract tests.""" + + result: str = "ok" + compute_delay_seconds: float = 0.0 + error_type: str | None = None + error_message: str = "" + + +def runtime_worker_entry(payload: RuntimeWorkerPayload, connection: Connection) -> None: + """Emits the standard runtime worker protocol from a spawned helper worker.""" + try: + connection.send(("phase", "setup_complete")) + if payload.compute_delay_seconds > 0.0: + time.sleep(payload.compute_delay_seconds) + if isinstance(payload.error_type, str): + connection.send(("err", payload.error_type, payload.error_message)) + else: + connection.send(("ok", payload.result)) + finally: + connection.close() + + +@dataclass(frozen=True, slots=True) +class FakeTranscriptionProfile: + """Serializable transcription profile for spawned helper workers.""" + + backend_id: TranscriptionBackendId = "faster_whisper" + model_name: str = "tiny" + use_demucs: bool = False + use_vad: bool = True + + +@dataclass(frozen=True, slots=True) +class FakeModelsConfig: + """Serializable models config for transcription worker settings.""" + + whisper_download_root: Path + + +@dataclass(frozen=True, slots=True) +class FakeSettings: + """Serializable settings snapshot for spawned transcription workers.""" + + models: FakeModelsConfig + + +@dataclass(frozen=True, slots=True) +class FakeTranscriptionPayload: + """Serializable payload for process-isolated transcription helper tests.""" + + file_path: str + language: str + profile: FakeTranscriptionProfile + runtime_request: BackendRuntimeRequest + settings: FakeSettings + + +def build_transcription_payload( + *, + file_path: str, 
+ language: str, + profile: FakeTranscriptionProfile, + runtime_request: BackendRuntimeRequest, + settings: FakeSettings, +) -> FakeTranscriptionPayload: + """Builds one serializable payload for spawned transcription tests.""" + return FakeTranscriptionPayload( + file_path=file_path, + language=language, + profile=profile, + runtime_request=runtime_request, + settings=settings, + ) + + +def transcription_success_worker( + payload: FakeTranscriptionPayload, + connection: Connection, +) -> None: + """Emits the standard transcription worker success protocol.""" + del payload + try: + connection.send(("phase", "setup_complete")) + connection.send(("phase", "model_loaded")) + connection.send(("ok", [("hello", 0.0, 0.5), ("world", 0.5, 1.0)])) + finally: + connection.close() + + +def transcription_error_worker( + payload: FakeTranscriptionPayload, + connection: Connection, +) -> None: + """Emits the standard transcription worker error protocol.""" + del payload + try: + connection.send(("phase", "setup_complete")) + connection.send(("err", "model_load", "RuntimeError", "boom")) + finally: + connection.close()