diff --git a/.changeset/openai-transcription-diarization.md b/.changeset/openai-transcription-diarization.md new file mode 100644 index 000000000..67769277c --- /dev/null +++ b/.changeset/openai-transcription-diarization.md @@ -0,0 +1,7 @@ +--- +'@tanstack/ai': minor +'@tanstack/ai-client': minor +'@tanstack/ai-openai': minor +--- + +Add OpenAI transcription diarization support with `diarized_json` output, speaker-labeled segments, diarization model validation, chunking strategy options, and docs. diff --git a/docs/adapters/openai.md b/docs/adapters/openai.md index 1eed1a8ef..d275ad260 100644 --- a/docs/adapters/openai.md +++ b/docs/adapters/openai.md @@ -313,8 +313,11 @@ import { audioFile } from "./audio"; const result = await generateTranscription({ adapter: openaiTranscription("whisper-1"), audio: audioFile, + responseFormat: "verbose_json", + prompt: "Technical terms: API, SDK", modelOptions: { temperature: 0, + timestamp_granularities: ["word", "segment"], }, }); @@ -322,6 +325,34 @@ const result = await generateTranscription({ console.log(result.text); ``` +### Speaker Diarization + +Use `gpt-4o-transcribe-diarize` for speaker-labeled transcripts: + +```typescript +import { generateTranscription } from "@tanstack/ai"; +import { openaiTranscription } from "@tanstack/ai-openai"; +import { meetingAudioFile } from "./audio"; + +const result = await generateTranscription({ + adapter: openaiTranscription("gpt-4o-transcribe-diarize"), + audio: meetingAudioFile, + modelOptions: { + known_speaker_names: ["agent", "customer"], + known_speaker_references: [ + "data:audio/wav;base64,...", + "data:audio/wav;base64,...", + ], + }, +}); + +for (const segment of result.segments ?? []) { + console.log(segment.speaker, segment.start, segment.end, segment.text); +} +``` + +When no response format is specified, `gpt-4o-transcribe-diarize` requests default to `response_format: "diarized_json"` and `chunking_strategy: "auto"`; passing a top-level `responseFormat` of `"json"` or `"text"` opts out of speaker segments. `known_speaker_names` and `known_speaker_references` must be provided together (up to 4, matching lengths). OpenAI does not support `prompt`, `include`, or `timestamp_granularities` with diarized transcription. + ## Environment Variables Set your API key in environment variables: @@ -370,7 +401,7 @@ Creates an OpenAI text-to-speech adapter. ### `openaiTranscription(model, config?)` / `createOpenaiTranscription(model, apiKey, config?)` -Creates an OpenAI transcription adapter (Whisper). +Creates an OpenAI transcription adapter for Whisper, GPT-4o transcription, and GPT-4o diarized transcription models. ### `openaiVideo(model, config?)` / `createOpenaiVideo(model, apiKey, config?)` diff --git a/docs/comparison/vercel-ai-sdk.md b/docs/comparison/vercel-ai-sdk.md index 8f32d0acd..2d3c68bd5 100644 --- a/docs/comparison/vercel-ai-sdk.md +++ b/docs/comparison/vercel-ai-sdk.md @@ -541,7 +541,7 @@ const result = await generateSpeech({ }) ``` -**Transcription** - `generateTranscription()` supports 5 output formats (json, text, srt, verbose_json, vtt), word-level timestamps with confidence scores, and four providers (OpenAI, Grok, ElevenLabs, fal.ai), with speaker diarization via OpenAI's `gpt-4o-transcribe-diarize` model. +**Transcription** - `generateTranscription()` supports common output formats (json, text, srt, verbose_json, vtt), word-level timestamps with confidence scores, and four providers (OpenAI, Grok, ElevenLabs, fal.ai), with speaker diarization via OpenAI's `gpt-4o-transcribe-diarize` model. ```ts import { generateTranscription } from '@tanstack/ai' diff --git a/docs/config.json b/docs/config.json index 770c6fc7f..8336c1778 100644 --- a/docs/config.json +++ b/docs/config.json @@ -259,7 +259,7 @@ "label": "Transcription", "to": "media/transcription", "addedAt": "2026-04-15", - "updatedAt": "2026-07-01" + "updatedAt": "2026-07-03" }, { "label": "Audio Recording", @@ -503,7 +503,8 @@ { "label": "OpenAI", "to": "adapters/openai", - "addedAt": "2026-04-15" + "addedAt": "2026-04-15", + "updatedAt": "2026-07-03" }, { "label": "Anthropic", diff --git a/docs/media/generation-hooks.md b/docs/media/generation-hooks.md index 273ea72ab..69cb7b7c9 100644 --- a/docs/media/generation-hooks.md +++ b/docs/media/generation-hooks.md @@ -214,7 +214,7 @@ The `generate` function accepts a `TranscriptionGenerateInput`: | `audio` | `string \| File \| Blob \| ArrayBuffer` | Audio data -- base64 string, File, Blob, or ArrayBuffer (required) | | `language` | `string` | Language in ISO-639-1 format (e.g., `"en"`) | | `prompt` | `string` | Optional prompt to guide the transcription | -| `responseFormat` | `'json' \| 'text' \| 'srt' \| 'verbose_json' \| 'vtt'` | Output format | +| `responseFormat` | `'json' \| 'text' \| 'srt' \| 'verbose_json' \| 'vtt'` | Common output format | | `modelOptions` | `Record` | Model-specific options | ## useSummarize diff --git a/docs/media/transcription.md b/docs/media/transcription.md index cace99e3a..bd92c633a 100644 --- a/docs/media/transcription.md +++ b/docs/media/transcription.md @@ -2,7 +2,7 @@ title: Transcription id: transcription order: 4 -description: "Transcribe audio to text with OpenAI Whisper, GPT-4o-transcribe, Groq Whisper, and fal.ai STT models via TanStack AI's generateTranscription() API." +description: "Transcribe audio to text with OpenAI Whisper and GPT-4o transcription models (including speaker diarization), Groq Whisper, and fal.ai STT models via TanStack AI's generateTranscription() API." keywords: - tanstack ai - transcription @@ -24,7 +24,7 @@ TanStack AI provides support for audio transcription (speech-to-text) through de Audio transcription is handled by transcription adapters that follow the same tree-shakeable architecture as other adapters in TanStack AI. Currently supported: -- **OpenAI**: Whisper-1, GPT-4o-transcribe, GPT-4o-mini-transcribe +- **OpenAI**: Whisper-1, GPT-4o-transcribe, GPT-4o-mini-transcribe, GPT-4o-transcribe-diarize - **Groq**: whisper-large-v3-turbo, whisper-large-v3 - **fal.ai**: Whisper, Wizper, speech-to-text turbo, ElevenLabs speech-to-text @@ -139,6 +139,8 @@ for (const segment of result.segments ?? []) { |--------|------|-------------| | `audio` | `File \| string` | Audio data (File object or base64 string) - required | | `language` | `string` | Language code (e.g., "en", "es", "fr") | +| `prompt` | `string` | Optional prompt to guide transcription style or terms. Not supported with `gpt-4o-transcribe-diarize`. | +| `responseFormat` | `'json' \| 'text' \| 'srt' \| 'verbose_json' \| 'vtt'` | Common output format | ### Supported Languages @@ -175,6 +177,7 @@ const result = await generateTranscription({ prompt: 'Technical terms: API, SDK, CLI', // Top-level: guide transcription modelOptions: { temperature: 0, // Lower = more deterministic (provider option) + timestamp_granularities: ['word', 'segment'], }, }) ``` @@ -182,8 +185,12 @@ const result = await generateTranscription({ | Option | Type | Description | |--------|------|-------------| | `temperature` | `number` | Sampling temperature (0 to 1) | -| `timestamp_granularities` | `Array<'word' \| 'segment'>` | Timestamp granularity to populate (requires top-level `responseFormat: 'verbose_json'`) | +| `timestamp_granularities` | `Array<'word' \| 'segment'>` | Timestamp granularity to populate (`whisper-1` only; requires top-level `responseFormat: 'verbose_json'`) | | `include` | `string[]` | Additional values to include in the response (e.g., `logprobs`) | +| `response_format` | `'json' \| 'text' \| 'srt' \| 'verbose_json' \| 'vtt' \| 'diarized_json'` | Raw OpenAI response format. Use `diarized_json` here for speaker-labeled diarization output. | +| `chunking_strategy` | `'auto' \| { type: 'server_vad', ... } \| null` | Audio chunking strategy (any model; unset transcribes the audio as a single block). Required by OpenAI for `gpt-4o-transcribe-diarize` inputs longer than 30 seconds — the adapter defaults it to `'auto'` for that model | +| `known_speaker_names` | `string[]` | Up to four speaker labels for diarization | +| `known_speaker_references` | `string[]` | 2-10 second data URL audio samples matching `known_speaker_names` | > `responseFormat` and `prompt` are **top-level** options on `generateTranscription`, not `modelOptions` keys. @@ -197,6 +204,36 @@ const result = await generateTranscription({ | `verbose_json` | Detailed JSON with timestamps and segments | | `vtt` | WebVTT subtitle format | +OpenAI's `gpt-4o-transcribe-diarize` also supports `modelOptions.response_format: 'diarized_json'` for speaker-labeled segments. + +### Speaker Diarization + +Use `gpt-4o-transcribe-diarize` when you need speaker labels. When no response format is specified, TanStack AI defaults the request to `response_format: 'diarized_json'` and sends `chunking_strategy: 'auto'` unless you provide a chunking strategy yourself. Passing a top-level `responseFormat: 'json'` or `'text'` opts out of speaker segments. + +```typescript +import { generateTranscription } from '@tanstack/ai' +import { openaiTranscription } from '@tanstack/ai-openai' +import { meetingAudioFile } from './audio' + +const result = await generateTranscription({ + adapter: openaiTranscription('gpt-4o-transcribe-diarize'), + audio: meetingAudioFile, + modelOptions: { + known_speaker_names: ['agent', 'customer'], + known_speaker_references: [ + 'data:audio/wav;base64,...', + 'data:audio/wav;base64,...', + ], + }, +}) + +for (const segment of result.segments ?? []) { + console.log(segment.speaker, segment.start, segment.end, segment.text) +} +``` + +OpenAI accepts up to four known speaker references; `known_speaker_names` and `known_speaker_references` must be provided together with matching lengths. The diarization model does not support `prompt`, `include`, or `timestamp_granularities`; the adapter rejects those combinations before making the API request. + ## Response Format The transcription result includes: @@ -499,9 +536,14 @@ import { transcribeStreamFn } from '../lib/server-functions' function AudioTranscriber() { const { generate, result, isLoading } = useTranscription({ - fetcher: (input) => transcribeStreamFn({ - data: { ...input, audio: input.audio as string }, - }), + fetcher: (input) => { + if (typeof input.audio !== 'string') { + throw new Error('Expected base64 or data URL audio') + } + return transcribeStreamFn({ + data: { ...input, audio: input.audio }, + }) + }, }) // ... same UI as above } @@ -586,5 +628,6 @@ const adapter = createOpenaiTranscription('whisper-1', 'your-openai-api-key') 5. **Prompting**: Use the `prompt` option to provide context or expected vocabulary (e.g., technical terms, names). -6. **Timestamps**: Request `verbose_json` format and enable `timestamp_granularities: ['word', 'segment']` when you need timing information for captions or synchronization. +6. **Timestamps**: Request `responseFormat: 'verbose_json'` and set `modelOptions.timestamp_granularities` when you need timing information for captions or synchronization. +7. **Diarization**: Use `gpt-4o-transcribe-diarize` with `modelOptions.response_format: 'diarized_json'` output for multi-speaker audio. Keep `chunking_strategy: 'auto'` unless you need custom VAD tuning. diff --git a/examples/ts-react-chat/src/lib/audio-providers.ts b/examples/ts-react-chat/src/lib/audio-providers.ts index 5ff72fae2..283f21779 100644 --- a/examples/ts-react-chat/src/lib/audio-providers.ts +++ b/examples/ts-react-chat/src/lib/audio-providers.ts @@ -6,6 +6,8 @@ * and audio generation flows. */ +import type { TranscriptionGenerateInput } from '@tanstack/ai-client' + export type SpeechProviderId = | 'openai' | 'gemini' @@ -87,13 +89,22 @@ export const SPEECH_PROVIDERS: ReadonlyArray = [ }, ] -export type TranscriptionProviderId = 'openai' | 'fal' | 'grok' | 'elevenlabs' +export type TranscriptionProviderId = + | 'openai' + | 'openai-diarize' + | 'fal' + | 'grok' + | 'elevenlabs' export interface TranscriptionProviderConfig { id: TranscriptionProviderId label: string model: string description: string + transcriptionOptions?: Pick< + TranscriptionGenerateInput, + 'responseFormat' | 'modelOptions' + > } export const TRANSCRIPTION_PROVIDERS: ReadonlyArray = @@ -104,6 +115,19 @@ export const TRANSCRIPTION_PROVIDERS: ReadonlyArray model: 'whisper-1', description: 'OpenAI Whisper transcription with optional streaming.', }, + { + id: 'openai-diarize', + label: 'OpenAI Diarize', + model: 'gpt-4o-transcribe-diarize', + description: + 'OpenAI diarized transcription with speaker-labeled segments.', + transcriptionOptions: { + modelOptions: { + response_format: 'diarized_json', + chunking_strategy: 'auto', + }, + }, + }, { id: 'fal', label: 'Fal Whisper', diff --git a/examples/ts-react-chat/src/lib/server-audio-adapters.ts b/examples/ts-react-chat/src/lib/server-audio-adapters.ts index ff9d8f47c..46cbb7c42 100644 --- a/examples/ts-react-chat/src/lib/server-audio-adapters.ts +++ b/examples/ts-react-chat/src/lib/server-audio-adapters.ts @@ -65,6 +65,8 @@ export function buildTranscriptionAdapter( switch (config.id) { case 'openai': return openaiTranscription(config.model as 'whisper-1') + case 'openai-diarize': + return openaiTranscription(config.model as 'gpt-4o-transcribe-diarize') case 'fal': return falTranscription(config.model) case 'grok': diff --git a/examples/ts-react-chat/src/lib/server-fns.ts b/examples/ts-react-chat/src/lib/server-fns.ts index b9b8ef62b..1c8109be4 100644 --- a/examples/ts-react-chat/src/lib/server-fns.ts +++ b/examples/ts-react-chat/src/lib/server-fns.ts @@ -78,7 +78,11 @@ const SPEECH_PROVIDER_SCHEMA = z .optional() const TRANSCRIPTION_PROVIDER_SCHEMA = z - .enum(['openai', 'fal', 'grok', 'elevenlabs']) + .enum(['openai', 'openai-diarize', 'fal', 'grok', 'elevenlabs']) + .optional() + +const TRANSCRIPTION_RESPONSE_FORMAT_SCHEMA = z + .enum(['json', 'text', 'srt', 'verbose_json', 'vtt']) .optional() const AUDIO_PROVIDER_SCHEMA = z @@ -144,6 +148,8 @@ export const transcribeFn = createServerFn({ method: 'POST' }) z.object({ audio: z.string(), language: z.string().optional(), + responseFormat: TRANSCRIPTION_RESPONSE_FORMAT_SCHEMA, + modelOptions: z.record(z.string(), z.any()).optional(), provider: TRANSCRIPTION_PROVIDER_SCHEMA, }), ) @@ -162,6 +168,8 @@ export const transcribeFn = createServerFn({ method: 'POST' }) adapter, audio: data.audio, language: data.language, + responseFormat: data.responseFormat, + modelOptions: data.modelOptions, }) }) @@ -316,6 +324,8 @@ export const transcribeStreamFn = createServerFn({ method: 'POST' }) z.object({ audio: z.string(), language: z.string().optional(), + responseFormat: TRANSCRIPTION_RESPONSE_FORMAT_SCHEMA, + modelOptions: z.record(z.string(), z.any()).optional(), provider: TRANSCRIPTION_PROVIDER_SCHEMA, }), ) @@ -335,6 +345,8 @@ export const transcribeStreamFn = createServerFn({ method: 'POST' }) adapter, audio: data.audio, language: data.language, + responseFormat: data.responseFormat, + modelOptions: data.modelOptions, stream: true, }), ) diff --git a/examples/ts-react-chat/src/routes/api.transcribe.ts b/examples/ts-react-chat/src/routes/api.transcribe.ts index b841ea904..a26800547 100644 --- a/examples/ts-react-chat/src/routes/api.transcribe.ts +++ b/examples/ts-react-chat/src/routes/api.transcribe.ts @@ -8,12 +8,18 @@ import { } from '../lib/server-audio-adapters' const TRANSCRIPTION_PROVIDER_SCHEMA = z - .enum(['openai', 'fal', 'grok', 'elevenlabs']) + .enum(['openai', 'openai-diarize', 'fal', 'grok', 'elevenlabs']) + .optional() + +const TRANSCRIPTION_RESPONSE_FORMAT_SCHEMA = z + .enum(['json', 'text', 'srt', 'verbose_json', 'vtt']) .optional() const TRANSCRIBE_BODY_SCHEMA = z.object({ audio: z.string().min(1), language: z.string().optional(), + responseFormat: TRANSCRIPTION_RESPONSE_FORMAT_SCHEMA, + modelOptions: z.record(z.string(), z.any()).optional(), provider: TRANSCRIPTION_PROVIDER_SCHEMA, }) @@ -55,7 +61,8 @@ export const Route = createFileRoute('/api/transcribe')({ }) } - const { audio, language, provider } = parsed.data + const { audio, language, responseFormat, modelOptions, provider } = + parsed.data try { const adapter = buildTranscriptionAdapter(provider ?? 'openai') @@ -64,6 +71,8 @@ export const Route = createFileRoute('/api/transcribe')({ adapter, audio, language, + responseFormat, + modelOptions, stream: true, }) diff --git a/examples/ts-react-chat/src/routes/generations.transcription.tsx b/examples/ts-react-chat/src/routes/generations.transcription.tsx index b03f8838f..d889cd452 100644 --- a/examples/ts-react-chat/src/routes/generations.transcription.tsx +++ b/examples/ts-react-chat/src/routes/generations.transcription.tsx @@ -34,6 +34,8 @@ function TranscriptionForm({ data: { audio: input.audio as string, language: input.language, + responseFormat: input.responseFormat, + modelOptions: input.modelOptions, provider: config.id, }, }), @@ -45,6 +47,8 @@ function TranscriptionForm({ data: { audio: input.audio as string, language: input.language, + responseFormat: input.responseFormat, + modelOptions: input.modelOptions, provider: config.id, }, }), @@ -75,7 +79,11 @@ function TranscriptionUI({ ) const dataUrl = `data:${file.type};base64,${base64}` - await generate({ audio: dataUrl, language: 'en' }) + await generate({ + audio: dataUrl, + language: 'en', + ...config.transcriptionOptions, + }) if (fileInputRef.current) { fileInputRef.current.value = '' @@ -159,6 +167,11 @@ function TranscriptionUI({ {seg.start.toFixed(1)}s - {seg.end.toFixed(1)}s + {seg.speaker && ( + + {seg.speaker} + + )} {seg.text} ))} diff --git a/knip.json b/knip.json index b977ce034..c7995c1aa 100644 --- a/knip.json +++ b/knip.json @@ -15,7 +15,6 @@ "packages/ai-openai/live-tests/**", "packages/ai-openai/src/**/*.test.ts", "packages/ai-openai/src/audio/audio-provider-options.ts", - "packages/ai-openai/src/audio/transcribe-provider-options.ts", "packages/ai-openai/src/image/image-provider-options.ts", "packages/ai-devtools/src/production.ts", "codemods/**/__testfixtures__/**" diff --git a/packages/ai-client/src/generation-types.ts b/packages/ai-client/src/generation-types.ts index 725cdf91b..b11e8ca2a 100644 --- a/packages/ai-client/src/generation-types.ts +++ b/packages/ai-client/src/generation-types.ts @@ -1,4 +1,5 @@ import type { MediaPrompt, StreamChunk } from '@tanstack/ai/client' +import type { TranscriptionResponseFormat } from '@tanstack/ai' import type { ConnectConnectionAdapter } from './connection-adapters' import type { AIDevtoolsClientMetadata } from './devtools' import type { @@ -289,7 +290,7 @@ export interface TranscriptionGenerateInput { /** An optional prompt to guide the transcription */ prompt?: string /** The format of the transcription output */ - responseFormat?: 'json' | 'text' | 'srt' | 'verbose_json' | 'vtt' + responseFormat?: TranscriptionResponseFormat /** Model-specific options */ modelOptions?: Record } diff --git a/packages/ai-openai/src/adapters/transcription.ts b/packages/ai-openai/src/adapters/transcription.ts index 249cbf8da..b5fcd23e9 100644 --- a/packages/ai-openai/src/adapters/transcription.ts +++ b/packages/ai-openai/src/adapters/transcription.ts @@ -11,9 +11,43 @@ import type { } from '@tanstack/ai' import type OpenAI_SDK from 'openai' import type { OpenAITranscriptionModel } from '../model-meta' -import type { OpenAITranscriptionProviderOptions } from '../audio/transcription-provider-options' +import type { + OpenAITranscriptionProviderOptions, + OpenAITranscriptionResponseFormat, +} from '../audio/transcription-provider-options' import type { OpenAIClientConfig } from '../utils/client' +const DIARIZE_MODELS = ['gpt-4o-transcribe-diarize'] as const +const DIARIZE_RESPONSE_FORMATS = ['json', 'text', 'diarized_json'] as const + +type DiarizeModel = (typeof DIARIZE_MODELS)[number] +type OpenAITranscriptionResponseMode = 'diarized' | 'verbose' | 'plain' + +interface OpenAITranscriptionRequestPlan { + request: OpenAI_SDK.Audio.TranscriptionCreateParamsNonStreaming + responseMode: OpenAITranscriptionResponseMode +} + +function isDiarizeModel(model: string): model is DiarizeModel { + return DIARIZE_MODELS.includes(model as DiarizeModel) +} + +// OpenAI diarized segments carry string ids like `seg_0`, but the shared +// TranscriptionSegment.id is numeric: parse the numeric suffix (or a plain +// numeric string) and fall back to the array index otherwise. The empty-string +// guard matters because Number('') is 0, which would collide with `seg_0`. +function mapDiarizedSegmentId(id: string, index: number): number { + const match = /^seg_(\d+)$/.exec(id) + if (match) return Number(match[1]) + + if (id.trim() !== '') { + const numericId = Number(id) + if (!Number.isNaN(numericId)) return numericId + } + + return index +} + /** * Build TokenUsage from transcription response. * Whisper-1 uses duration-based billing, GPT-4o models use token-based billing. @@ -31,10 +65,21 @@ function buildTranscriptionUsage( // billing data to report, so return undefined rather than fabricating a // duration-based result for a token-billed model. if (model.startsWith('gpt-4o')) { - if (!usage || usage.type !== 'tokens') { + if (!usage) { return undefined } + // gpt-4o-transcribe-diarize responses may report duration-based usage; + // surface it rather than discarding billing data the API returned. + if (usage.type === 'duration') { + return { + promptTokens: 0, + completionTokens: 0, + totalTokens: 0, + durationSeconds: usage.seconds, + } + } + const result: TokenUsage = { promptTokens: usage.input_tokens || 0, completionTokens: usage.output_tokens || 0, @@ -86,16 +131,17 @@ export interface OpenAITranscriptionConfig extends OpenAIClientConfig {} * OpenAI Transcription (Speech-to-Text) Adapter * * Tree-shakeable adapter for OpenAI audio transcription functionality. - * Supports whisper-1, gpt-4o-transcribe, gpt-4o-mini-transcribe, and gpt-4o-transcribe-diarize models. + * Supports whisper-1, gpt-4o-transcribe, gpt-4o-mini-transcribe, and gpt-4o-transcribe-diarize. * * Features: * - Multiple transcription models with different capabilities * - Language detection or specification - * - Multiple output formats: json, text, srt, verbose_json, vtt + * - Multiple output formats: json, text, srt, verbose_json, vtt, diarized_json * - Word and segment-level timestamps (with verbose_json — whisper-1 only; - * gpt-4o-* transcribe models accept only json/text and reject verbose_json - * with HTTP 400) - * - Speaker diarization (with gpt-4o-transcribe-diarize) + * gpt-4o-transcribe and gpt-4o-mini-transcribe accept only json/text and + * reject verbose_json with HTTP 400) + * - Speaker diarization (with gpt-4o-transcribe-diarize, which accepts json, + * text, and diarized_json) */ export class OpenAITranscriptionAdapter< TModel extends OpenAITranscriptionModel, @@ -112,42 +158,58 @@ export class OpenAITranscriptionAdapter< async transcribe( options: TranscriptionOptions, ): Promise { - const { model, audio, language, prompt, responseFormat, modelOptions } = - options - - const file = this.prepareAudioFile(audio) - - // With exactOptionalPropertyTypes, vendor SDK request shapes reject - // `T | undefined` in optional fields. Build the request incrementally and - // only set optional fields when they're actually defined. - const responseFormatValue = this.mapResponseFormat(responseFormat) - const request: OpenAI_SDK.Audio.TranscriptionCreateParams = { - model, - file, - ...(modelOptions ?? {}), - } - if (language !== undefined) { - request.language = language - } - if (prompt !== undefined) { - request.prompt = prompt - } - if (responseFormatValue !== undefined) { - request.response_format = responseFormatValue - } - - // Only Whisper supports verbose_json. The gpt-4o-* transcribe models - // accept only json/text and reject verbose_json with HTTP 400. - const useVerbose = - responseFormat === 'verbose_json' || - (!responseFormat && model === 'whisper-1') + const { model, language } = options try { + const { request, responseMode } = this.buildTranscriptionRequest(options) + options.logger.request( - `activity=transcription provider=${this.name} model=${model} verbose=${useVerbose}`, + `activity=transcription provider=${this.name} model=${model} verbose=${responseMode === 'verbose'} diarized=${responseMode === 'diarized'}`, { provider: this.name, model }, ) - if (useVerbose) { + if (responseMode === 'diarized') { + const response = (await this.client.audio.transcriptions.create( + request, + )) as OpenAI_SDK.Audio.TranscriptionDiarized + + // Guard the cast: a proxy/gateway or API change that returns a + // non-diarized shape would otherwise fail with a context-free + // TypeError deep in the mapping below. + if (!Array.isArray(response.segments)) { + throw new Error( + `OpenAI diarized transcription response did not include segments (model=${model}, response_format=diarized_json).`, + ) + } + + const segments = response.segments.map( + (segment, index): TranscriptionSegment => ({ + id: mapDiarizedSegmentId(segment.id, index), + start: segment.start, + end: segment.end, + text: segment.text, + speaker: segment.speaker, + }), + ) + + const usage = buildTranscriptionUsage( + model, + response.duration, + response, + ) + return { + id: generateId(this.name), + model, + text: response.text, + duration: response.duration, + // Always include segments (even empty) for diarized requests: the + // caller asked for speaker segments, so an empty list is meaningful + // and should not look like a non-diarized result. + segments, + ...(usage !== undefined && { usage }), + } + } + + if (responseMode === 'verbose') { const response = (await this.client.audio.transcriptions.create({ ...request, response_format: 'verbose_json', @@ -188,20 +250,20 @@ export class OpenAITranscriptionAdapter< ...(words !== undefined && { words }), ...(usage !== undefined && { usage }), } - } else { - const response = await this.client.audio.transcriptions.create(request) + } - const usage = - typeof response === 'string' - ? undefined - : buildTranscriptionUsage(model, undefined, response) - return { - id: generateId(this.name), - model, - text: typeof response === 'string' ? response : response.text, - ...(language !== undefined && { language }), - ...(usage !== undefined && { usage }), - } + const response = await this.client.audio.transcriptions.create(request) + + const usage = + typeof response === 'string' + ? undefined + : buildTranscriptionUsage(model, undefined, response) + return { + id: generateId(this.name), + model, + text: typeof response === 'string' ? response : response.text, + ...(language !== undefined && { language }), + ...(usage !== undefined && { usage }), } } catch (error: unknown) { options.logger.errors(`${this.name}.transcribe fatal`, { @@ -212,6 +274,105 @@ export class OpenAITranscriptionAdapter< } } + private buildTranscriptionRequest( + options: TranscriptionOptions, + ): OpenAITranscriptionRequestPlan { + const { model, audio, language, prompt, responseFormat, modelOptions } = + options + const file = this.prepareAudioFile(audio) + const isDiarizeTranscriptionModel = isDiarizeModel(model) + const topLevelResponseFormat = responseFormat + const effectiveResponseFormat = + topLevelResponseFormat ?? modelOptions?.response_format + + if ( + topLevelResponseFormat !== undefined && + modelOptions?.response_format !== undefined && + topLevelResponseFormat !== modelOptions.response_format + ) { + throw new Error( + `Conflicting response formats: responseFormat="${topLevelResponseFormat}" and modelOptions.response_format="${modelOptions.response_format}". Provide only one.`, + ) + } + + this.validateDiarizationOptions({ + model, + prompt, + responseFormat: topLevelResponseFormat, + modelOptions, + }) + + const responseMode = this.resolveResponseMode({ + model, + isDiarizeTranscriptionModel, + effectiveResponseFormat, + }) + const responseFormatValue = + responseMode === 'diarized' + ? 'diarized_json' + : this.mapResponseFormat(effectiveResponseFormat) + + // With exactOptionalPropertyTypes, vendor SDK request shapes reject + // `T | undefined` in optional fields. Build the request incrementally and + // only set optional fields when they're actually defined. + // Spread modelOptions first so it can never override the validated + // `model`/`file` fields (server routes often pass modelOptions through + // from untyped client input). + const request: OpenAI_SDK.Audio.TranscriptionCreateParamsNonStreaming = { + ...modelOptions, + model, + file, + } + // `stream` is not a supported provider option for this adapter; an + // untyped passthrough setting it would flip the SDK into streaming mode + // and break response parsing. + delete request.stream + if (language !== undefined) { + request.language = language + } + if (prompt !== undefined) { + request.prompt = prompt + } + if ( + isDiarizeTranscriptionModel && + modelOptions?.chunking_strategy === undefined + ) { + request.chunking_strategy = 'auto' + } + request.response_format = responseFormatValue + + return { request, responseMode } + } + + private resolveResponseMode({ + model, + isDiarizeTranscriptionModel, + effectiveResponseFormat, + }: { + model: string + isDiarizeTranscriptionModel: boolean + effectiveResponseFormat?: OpenAITranscriptionResponseFormat + }): OpenAITranscriptionResponseMode { + if ( + effectiveResponseFormat === 'diarized_json' || + (isDiarizeTranscriptionModel && effectiveResponseFormat === undefined) + ) { + return 'diarized' + } + + // Only Whisper supports verbose_json. gpt-4o-transcribe and + // gpt-4o-mini-transcribe accept only json/text and reject verbose_json + // with HTTP 400 (the diarize model is handled above). + if ( + effectiveResponseFormat === 'verbose_json' || + (effectiveResponseFormat === undefined && model === 'whisper-1') + ) { + return 'verbose' + } + + return 'plain' + } + protected prepareAudioFile(audio: string | File | Blob | ArrayBuffer): File { if (typeof File !== 'undefined' && audio instanceof File) { return audio @@ -257,9 +418,116 @@ export class OpenAITranscriptionAdapter< } } + private validateDiarizationOptions({ + model, + prompt, + responseFormat, + modelOptions, + }: Pick< + TranscriptionOptions, + 'model' | 'prompt' | 'modelOptions' + > & { + responseFormat?: OpenAITranscriptionResponseFormat + }): void { + const isDiarizeTranscriptionModel = isDiarizeModel(model) + const modelOptionsResponseFormat = modelOptions?.response_format + + // `chunking_strategy` is deliberately NOT rejected here: per the OpenAI + // API it is a general transcription parameter for all models (only + // *required* for gpt-4o-transcribe-diarize inputs longer than 30s). + if ( + !isDiarizeTranscriptionModel && + (responseFormat === 'diarized_json' || + modelOptionsResponseFormat === 'diarized_json' || + modelOptions?.known_speaker_names !== undefined || + modelOptions?.known_speaker_references !== undefined) + ) { + throw new Error( + `OpenAI speaker diarization options (response_format: 'diarized_json', known_speaker_names, known_speaker_references) are only supported with OpenAI diarization transcription models; model is "${model}".`, + ) + } + + if (!isDiarizeTranscriptionModel) return + + const requestedResponseFormats = [ + this.mapResponseFormat(responseFormat), + ...(modelOptionsResponseFormat !== undefined + ? [this.mapResponseFormat(modelOptionsResponseFormat)] + : []), + ] + const unsupportedResponseFormat = requestedResponseFormats.find( + (format) => + !DIARIZE_RESPONSE_FORMATS.includes( + format as (typeof DIARIZE_RESPONSE_FORMATS)[number], + ), + ) + if (unsupportedResponseFormat !== undefined) { + throw new Error( + `OpenAI diarization transcription models only support json, text, and diarized_json response formats; received "${unsupportedResponseFormat}".`, + ) + } + + if (prompt !== undefined || modelOptions?.prompt !== undefined) { + throw new Error( + 'OpenAI diarization transcription models do not support prompts.', + ) + } + + if (modelOptions?.include !== undefined) { + throw new Error( + 'OpenAI diarization transcription models do not support the include option.', + ) + } + + if (modelOptions?.timestamp_granularities !== undefined) { + throw new Error( + 'OpenAI diarization transcription models do not support timestamp_granularities.', + ) + } + + if ( + (modelOptions?.known_speaker_names === undefined) !== + (modelOptions?.known_speaker_references === undefined) + ) { + throw new Error( + 'OpenAI diarization known_speaker_names and known_speaker_references must both be provided together.', + ) + } + + if (modelOptions?.known_speaker_names !== undefined) { + const knownSpeakerCount = modelOptions.known_speaker_names.length + if (knownSpeakerCount > 4) { + throw new Error( + 'OpenAI diarization transcription models support at most 4 known speaker names.', + ) + } + } + + if (modelOptions?.known_speaker_references !== undefined) { + const knownSpeakerReferenceCount = + modelOptions.known_speaker_references.length + if (knownSpeakerReferenceCount > 4) { + throw new Error( + 'OpenAI diarization transcription models support at most 4 known speaker references.', + ) + } + } + + if ( + modelOptions?.known_speaker_names !== undefined && + modelOptions.known_speaker_references !== undefined && + modelOptions.known_speaker_names.length !== + modelOptions.known_speaker_references.length + ) { + throw new Error( + `OpenAI diarization known_speaker_names and known_speaker_references must have matching lengths; received ${modelOptions.known_speaker_names.length} names and ${modelOptions.known_speaker_references.length} references.`, + ) + } + } + protected mapResponseFormat( - format?: 'json' | 'text' | 'srt' | 'verbose_json' | 'vtt', - ): OpenAI_SDK.Audio.TranscriptionCreateParams['response_format'] { + format?: OpenAITranscriptionResponseFormat, + ): OpenAITranscriptionResponseFormat { if (!format) return 'json' return format } diff --git a/packages/ai-openai/src/audio/transcribe-provider-options.ts b/packages/ai-openai/src/audio/transcribe-provider-options.ts deleted file mode 100644 index 063e719ff..000000000 --- a/packages/ai-openai/src/audio/transcribe-provider-options.ts +++ /dev/null @@ -1,128 +0,0 @@ -export interface TranscribeProviderOptions { - /** - * The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. - * https://platform.openai.com/docs/api-reference/audio/createTranscription#audio_createtranscription-file - */ - file: File - /** - * The model to use for transcription. - * https://platform.openai.com/docs/api-reference/audio/createTranscription#audio_createtranscription-model - */ - model: string - - chunking_strategy: - | 'auto' - | { - type: 'server_vad' - /** - * Amount of audio to include before the VAD detected speech (in milliseconds). - * @default 300 - */ - prefix_padding_ms?: number - /** - * Duration of silence to detect speech stop (in milliseconds). With shorter values the model will respond more quickly, but may jump in on short pauses from the user. - * @default 200 - */ - silence_duration_ms: number - /** - * Sensitivity threshold (0.0 to 1.0) for voice activity detection. A higher threshold will require louder audio to activate the model, and thus might perform better in noisy environments. - * @default 0.5 - */ - threshold?: number - } - /** - * Additional information to include in the transcription response. logprobs will return the log probabilities of the tokens in the response to understand the model's confidence in the transcription. logprobs only works with response_format set to json and only with the models gpt-4o-transcribe and gpt-4o-mini-transcribe. This field is not supported when using gpt-4o-transcribe-diarize. - */ - include?: Array - /** - * Optional list of speaker names that correspond to the audio samples provided in known_speaker_references[]. Each entry should be a short identifier (for example customer or agent). Up to 4 speakers are supported. - */ - known_speaker_names: Array - /** - * Optional list of audio samples (as data URLs) that contain known speaker references matching known_speaker_names[]. Each sample must be between 2 and 10 seconds, and can use any of the same input audio formats supported by file. - */ - known_speaker_references?: Array - /** - * The language of the input audio. Supplying the input language in ISO-639-1 (e.g. en) format will improve accuracy and latency. - */ - language?: string - /** - * An optional prompt to guide the transcription model's style or to help with uncommon words or phrases. - */ - prompt?: string - /** - * The format of the output, in one of these options: json, text, srt, verbose_json, vtt, or diarized_json. For gpt-4o-transcribe and gpt-4o-mini-transcribe, the only supported format is json. For gpt-4o-transcribe-diarize, the supported formats are json, text, and diarized_json, with diarized_json required to receive speaker annotations. - */ - response_format?: - | 'json' - | 'text' - | 'srt' - | 'verbose_json' - | 'vtt' - | 'diarized_json' - - /** - * If set to true, the model response data will be streamed to the client as it is generated using server-sent events - * Note: Streaming is not supported for the whisper-1 model and will be ignored. - */ - stream?: boolean - /** - * The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit. - */ - temperature?: number - /** - * The timestamp granularities to populate for this transcription. response_format must be set verbose_json to use timestamp granularities. Either or both of these options are supported: word, or segment. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency. This option is not available for gpt-4o-transcribe-diarize. - */ - timestamp_granularities?: Array<'word' | 'segment'> -} - -export const validateTemperature = (options: TranscribeProviderOptions) => { - if (options.temperature) { - if (options.temperature < 0 || options.temperature > 1) { - throw new Error('Temperature must be between 0 and 1.') - } - } -} - -export const validateStream = (options: TranscribeProviderOptions) => { - const unsupportedModels = ['whisper-1'] - if (options.stream) { - if (unsupportedModels.includes(options.model)) { - throw new Error(`The model ${options.model} does not support streaming.`) - } - } -} - -export const validatePrompt = (options: TranscribeProviderOptions) => { - const unsupportedModels = ['gpt-4o-transcribe-diarize'] - if (options.prompt) { - if (unsupportedModels.includes(options.model)) { - throw new Error(`The model ${options.model} does not support prompts.`) - } - } -} - -export const validateKnownSpeakerNames = ( - options: TranscribeProviderOptions, -) => { - if (options.known_speaker_names.length > 4) { - throw new Error('A maximum of 4 known speaker names are supported.') - } -} - -export const validateInclude = (options: TranscribeProviderOptions) => { - const unsupportedModels = ['gpt-4o-transcribe-diarize'] - if (options.include) { - if (unsupportedModels.includes(options.model)) { - throw new Error( - `The model ${options.model} does not support the include field.`, - ) - } - } - - if (options.include && options.response_format !== 'json') { - throw new Error( - 'The include field is only supported when response_format is set to json.', - ) - } -} diff --git a/packages/ai-openai/src/audio/transcription-provider-options.ts b/packages/ai-openai/src/audio/transcription-provider-options.ts index 17f619cb0..c951640c5 100644 --- a/packages/ai-openai/src/audio/transcription-provider-options.ts +++ b/packages/ai-openai/src/audio/transcription-provider-options.ts @@ -1,4 +1,9 @@ import type OpenAI from 'openai' +import type { TranscriptionResponseFormat } from '@tanstack/ai' + +export type OpenAITranscriptionResponseFormat = + | TranscriptionResponseFormat + | 'diarized_json' /** * Provider-specific options for OpenAI Transcription @@ -30,12 +35,41 @@ export interface OpenAITranscriptionProviderOptions { * Either or both of these options are supported: word, or segment. */ timestamp_granularities?: Array<'word' | 'segment'> + /** + * Raw OpenAI response_format option. Prefer the top-level responseFormat + * argument for common transcription formats when using + * generateTranscription(). Use `diarized_json` here for OpenAI diarization + * output. Setting both this and the top-level responseFormat to different + * values throws. + */ + response_format?: OpenAITranscriptionResponseFormat + /** + * Raw OpenAI prompt option. Prefer the top-level prompt argument when using + * generateTranscription(). + */ + prompt?: string /** * Optional list of speaker names that correspond to the audio samples provided in known_speaker_references[]. Each entry should be a short identifier (for example customer or agent). Up to 4 speakers are supported. + * Must be provided together with known_speaker_references, with matching lengths. + * Only supported with gpt-4o-transcribe-diarize. */ known_speaker_names?: Array /** * Optional list of audio samples (as data URLs) that contain known speaker references matching known_speaker_names[]. Each sample must be between 2 and 10 seconds, and can use any of the same input audio formats supported by file. + * Must be provided together with known_speaker_names, with matching lengths. + * Only supported with gpt-4o-transcribe-diarize. */ known_speaker_references?: Array + /** + * Controls how the audio is cut into chunks. If unset, the audio is + * transcribed as a single block. Required by OpenAI when + * `gpt-4o-transcribe-diarize` input is longer than 30 seconds (this adapter + * defaults it to `"auto"` for that model). Use `"auto"` for the + * service-managed VAD strategy, or pass a `server_vad` config to tune + * segmentation. + */ + chunking_strategy?: + | 'auto' + | OpenAI.Audio.TranscriptionCreateParams.VadConfig + | null } diff --git a/packages/ai-openai/tests/transcription-adapter.test.ts b/packages/ai-openai/tests/transcription-adapter.test.ts new file mode 100644 index 000000000..df022a196 --- /dev/null +++ b/packages/ai-openai/tests/transcription-adapter.test.ts @@ -0,0 +1,784 @@ +import { describe, expect, it, vi } from 'vitest' +import { resolveDebugOption } from '@tanstack/ai/adapter-internals' +import { + OpenAITranscriptionAdapter, + createOpenaiTranscription, +} from '../src/adapters/transcription' +import type OpenAI from 'openai' +import type { OpenAITranscriptionModel } from '../src/model-meta' + +const testLogger = resolveDebugOption(false) + +class TestOpenAITranscriptionAdapter< + TModel extends OpenAITranscriptionModel, +> extends OpenAITranscriptionAdapter { + spyOnTranscriptionsCreate() { + return vi.spyOn(this.client.audio.transcriptions, 'create') + } +} + +describe('OpenAI transcription adapter', () => { + it('creates a diarization-capable adapter', () => { + const adapter = createOpenaiTranscription( + 'gpt-4o-transcribe-diarize', + 'test-api-key', + ) + + expect(adapter).toBeInstanceOf(OpenAITranscriptionAdapter) + expect(adapter.name).toBe('openai') + }) + + it('defaults the diarization model to diarized_json with automatic chunking', async () => { + const mockResponse: OpenAI.Audio.TranscriptionDiarized = { + text: 'Agent: Hello\nCustomer: Hi', + duration: 2.2, + task: 'transcribe', + segments: [ + { + id: 'seg_0', + type: 'transcript.text.segment', + start: 0, + end: 1.4, + text: 'Hello', + speaker: 'agent', + }, + { + id: 'seg_1', + type: 'transcript.text.segment', + start: 1.5, + end: 2.2, + text: 'Hi', + speaker: 'customer', + }, + ], + } + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + const mockCreate = adapter + .spyOnTranscriptionsCreate() + .mockResolvedValueOnce(mockResponse) + + const result = await adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'meeting.wav', { type: 'audio/wav' }), + logger: testLogger, + }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'gpt-4o-transcribe-diarize', + response_format: 'diarized_json', + chunking_strategy: 'auto', + }), + ) + expect(result.text).toBe('Agent: Hello\nCustomer: Hi') + expect(result.segments).toEqual([ + { + id: 0, + start: 0, + end: 1.4, + text: 'Hello', + speaker: 'agent', + }, + { + id: 1, + start: 1.5, + end: 2.2, + text: 'Hi', + speaker: 'customer', + }, + ]) + }) + + it('passes explicit diarization chunking and known speaker references', async () => { + const mockResponse: OpenAI.Audio.TranscriptionDiarized = { + text: 'Speaker text', + duration: 1, + task: 'transcribe', + segments: [ + { + id: 'speaker-intro', + type: 'transcript.text.segment', + start: 0, + end: 1, + text: 'Speaker text', + speaker: 'agent', + }, + ], + } + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + const mockCreate = adapter + .spyOnTranscriptionsCreate() + .mockResolvedValueOnce(mockResponse) + + const result = await adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'meeting.wav', { type: 'audio/wav' }), + modelOptions: { + response_format: 'diarized_json', + chunking_strategy: { + type: 'server_vad', + threshold: 0.5, + prefix_padding_ms: 300, + silence_duration_ms: 500, + }, + known_speaker_names: ['agent'], + known_speaker_references: ['data:audio/wav;base64,AAA='], + }, + logger: testLogger, + }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + response_format: 'diarized_json', + chunking_strategy: { + type: 'server_vad', + threshold: 0.5, + prefix_padding_ms: 300, + silence_duration_ms: 500, + }, + known_speaker_names: ['agent'], + known_speaker_references: ['data:audio/wav;base64,AAA='], + }), + ) + expect(result.segments?.[0]?.id).toBe(0) + }) + + it('uses snake_case modelOptions response_format for diarized output', async () => { + const mockResponse: OpenAI.Audio.TranscriptionDiarized = { + text: 'Agent: Hello', + duration: 1, + task: 'transcribe', + segments: [ + { + id: 'seg_0', + type: 'transcript.text.segment', + start: 0, + end: 1, + text: 'Hello', + speaker: 'agent', + }, + ], + } + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + const mockCreate = adapter + .spyOnTranscriptionsCreate() + .mockResolvedValueOnce(mockResponse) + + const result = await adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'meeting.wav', { type: 'audio/wav' }), + modelOptions: { + response_format: 'diarized_json', + chunking_strategy: null, + }, + logger: testLogger, + }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + response_format: 'diarized_json', + chunking_strategy: null, + }), + ) + expect(result.segments?.[0]?.speaker).toBe('agent') + }) + + it('respects explicit null chunking for short diarization inputs', async () => { + const mockResponse: OpenAI.Audio.TranscriptionDiarized = { + text: 'Hello', + duration: 1, + task: 'transcribe', + segments: [], + } + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + const mockCreate = adapter + .spyOnTranscriptionsCreate() + .mockResolvedValueOnce(mockResponse) + + const result = await adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'short.wav', { type: 'audio/wav' }), + modelOptions: { + chunking_strategy: null, + }, + logger: testLogger, + }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + chunking_strategy: null, + }), + ) + // Diarized requests always report segments, even when empty — an empty + // list must not look like a non-diarized result. + expect(result.segments).toEqual([]) + }) + + it('allows json or text response formats for the diarization model', async () => { + const mockResponse: OpenAI.Audio.Transcription = { + text: 'Hello', + } + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + const mockCreate = adapter + .spyOnTranscriptionsCreate() + .mockResolvedValueOnce(mockResponse) + + const result = await adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'short.wav', { type: 'audio/wav' }), + responseFormat: 'json', + logger: testLogger, + }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + response_format: 'json', + chunking_strategy: 'auto', + }), + ) + expect(result).toMatchObject({ + model: 'gpt-4o-transcribe-diarize', + text: 'Hello', + }) + }) + + it('rejects unsupported response formats for the diarization model', async () => { + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + + for (const responseFormat of ['srt', 'vtt', 'verbose_json'] as const) { + await expect( + adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + responseFormat, + logger: testLogger, + }), + ).rejects.toThrow( + 'diarization transcription models only support json, text, and diarized_json', + ) + } + + await expect( + adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + response_format: 'verbose_json', + }, + logger: testLogger, + }), + ).rejects.toThrow( + 'diarization transcription models only support json, text, and diarized_json', + ) + }) + + it('rejects diarization-only options with non-diarization models', async () => { + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'whisper-1', + ) + + await expect( + adapter.transcribe({ + model: 'whisper-1', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + responseFormat: 'diarized_json' as never, + logger: testLogger, + }), + ).rejects.toThrow('speaker diarization options') + + await expect( + adapter.transcribe({ + model: 'whisper-1', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + response_format: 'diarized_json', + }, + logger: testLogger, + }), + ).rejects.toThrow('speaker diarization options') + + await expect( + adapter.transcribe({ + model: 'whisper-1', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + known_speaker_names: ['agent'], + known_speaker_references: ['data:audio/wav;base64,AAA='], + }, + logger: testLogger, + }), + ).rejects.toThrow('speaker diarization options') + }) + + it('allows chunking_strategy with non-diarization models', async () => { + const mockResponse: OpenAI.Audio.Transcription = { + text: 'Hello', + } + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe', + ) + const mockCreate = adapter + .spyOnTranscriptionsCreate() + .mockResolvedValueOnce(mockResponse) + + await adapter.transcribe({ + model: 'gpt-4o-transcribe', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + chunking_strategy: 'auto', + }, + logger: testLogger, + }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'gpt-4o-transcribe', + chunking_strategy: 'auto', + }), + ) + }) + + it('rejects unsupported diarization prompt and timestamp options', async () => { + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + + await expect( + adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + prompt: 'Use product vocabulary', + logger: testLogger, + }), + ).rejects.toThrow('do not support prompts') + + await expect( + adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + prompt: 'Use product vocabulary', + }, + logger: testLogger, + }), + ).rejects.toThrow('do not support prompts') + + await expect( + adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + timestamp_granularities: ['word'], + }, + logger: testLogger, + }), + ).rejects.toThrow('timestamp_granularities') + }) + + it('rejects unsupported diarization include and too many known speakers', async () => { + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + + await expect( + adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + include: ['logprobs'], + }, + logger: testLogger, + }), + ).rejects.toThrow('include') + + await expect( + adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + known_speaker_names: ['a', 'b', 'c', 'd', 'e'], + known_speaker_references: [ + 'data:audio/wav;base64,AAA=', + 'data:audio/wav;base64,BBB=', + 'data:audio/wav;base64,CCC=', + 'data:audio/wav;base64,DDD=', + 'data:audio/wav;base64,EEE=', + ], + }, + logger: testLogger, + }), + ).rejects.toThrow('at most 4') + + await expect( + adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + known_speaker_names: ['agent'], + }, + logger: testLogger, + }), + ).rejects.toThrow('must both be provided together') + + await expect( + adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + known_speaker_references: ['data:audio/wav;base64,AAA='], + }, + logger: testLogger, + }), + ).rejects.toThrow('must both be provided together') + + await expect( + adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + known_speaker_names: ['agent'], + known_speaker_references: [ + 'data:audio/wav;base64,AAA=', + 'data:audio/wav;base64,BBB=', + ], + }, + logger: testLogger, + }), + ).rejects.toThrow('matching lengths') + }) + + it('accepts exactly 4 known speakers', async () => { + const mockResponse: OpenAI.Audio.TranscriptionDiarized = { + text: 'Hello', + duration: 1, + task: 'transcribe', + segments: [], + } + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + const mockCreate = adapter + .spyOnTranscriptionsCreate() + .mockResolvedValueOnce(mockResponse) + + const names = ['a', 'b', 'c', 'd'] + const references = [ + 'data:audio/wav;base64,AAA=', + 'data:audio/wav;base64,BBB=', + 'data:audio/wav;base64,CCC=', + 'data:audio/wav;base64,DDD=', + ] + await adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + known_speaker_names: names, + known_speaker_references: references, + }, + logger: testLogger, + }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + known_speaker_names: names, + known_speaker_references: references, + }), + ) + }) + + it('rejects conflicting top-level and modelOptions response formats', async () => { + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + + await expect( + adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + responseFormat: 'json', + modelOptions: { + response_format: 'diarized_json', + }, + logger: testLogger, + }), + ).rejects.toThrow('Conflicting response formats') + }) + + it('parses numeric diarized segment ids and guards blank ids', async () => { + const mockResponse: OpenAI.Audio.TranscriptionDiarized = { + text: 'One Two', + duration: 2, + task: 'transcribe', + segments: [ + { + id: 'seg_7', + type: 'transcript.text.segment', + start: 0, + end: 1, + text: 'One', + speaker: 'a', + }, + { + id: '', + type: 'transcript.text.segment', + start: 1, + end: 2, + text: 'Two', + speaker: 'b', + }, + ], + } + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + adapter.spyOnTranscriptionsCreate().mockResolvedValueOnce(mockResponse) + + const result = await adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + logger: testLogger, + }) + + // seg_7 parses to 7; the blank id falls back to the array index (1), not + // Number('') === 0, which would collide with a real seg_0. + expect(result.segments?.map((s) => s.id)).toEqual([7, 1]) + }) + + it('maps token usage and duration for diarized responses', async () => { + const mockResponse: OpenAI.Audio.TranscriptionDiarized = { + text: 'Hello', + duration: 2.5, + task: 'transcribe', + segments: [], + usage: { + type: 'tokens', + input_tokens: 10, + output_tokens: 5, + total_tokens: 15, + input_token_details: { audio_tokens: 8, text_tokens: 2 }, + }, + } + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + adapter.spyOnTranscriptionsCreate().mockResolvedValueOnce(mockResponse) + + const result = await adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + logger: testLogger, + }) + + expect(result.duration).toBe(2.5) + expect(result.usage).toEqual({ + promptTokens: 10, + completionTokens: 5, + totalTokens: 15, + promptTokensDetails: { audioTokens: 8, textTokens: 2 }, + completionTokensDetails: { textTokens: 5 }, + }) + }) + + it('maps duration-billed usage for diarized responses', async () => { + const mockResponse: OpenAI.Audio.TranscriptionDiarized = { + text: 'Hello', + duration: 2.5, + task: 'transcribe', + segments: [], + usage: { type: 'duration', seconds: 2.5 }, + } + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + adapter.spyOnTranscriptionsCreate().mockResolvedValueOnce(mockResponse) + + const result = await adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + logger: testLogger, + }) + + expect(result.usage).toEqual({ + promptTokens: 0, + completionTokens: 0, + totalTokens: 0, + durationSeconds: 2.5, + }) + }) + + it('honors modelOptions response_format verbose_json for whisper', async () => { + const mockResponse: OpenAI.Audio.Transcriptions.TranscriptionVerbose = { + text: 'Hello world', + duration: 3, + language: 'en', + segments: [ + { + id: 0, + avg_logprob: 0, + compression_ratio: 1, + end: 3, + no_speech_prob: 0, + seek: 0, + start: 0, + temperature: 0, + text: 'Hello world', + tokens: [1, 2], + }, + ], + } + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'whisper-1', + ) + const mockCreate = adapter + .spyOnTranscriptionsCreate() + .mockResolvedValueOnce(mockResponse) + + const result = await adapter.transcribe({ + model: 'whisper-1', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + response_format: 'verbose_json', + }, + logger: testLogger, + }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + response_format: 'verbose_json', + }), + ) + expect(result.segments).toEqual([ + { id: 0, start: 0, end: 3, text: 'Hello world', confidence: 1 }, + ]) + }) + + it('honors modelOptions response_format text for whisper', async () => { + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'whisper-1', + ) + const mockCreate = adapter + .spyOnTranscriptionsCreate() + .mockResolvedValueOnce( + 'Hello world' as unknown as OpenAI.Audio.Transcription, + ) + + const result = await adapter.transcribe({ + model: 'whisper-1', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + modelOptions: { + response_format: 'text', + }, + logger: testLogger, + }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + response_format: 'text', + }), + ) + expect(result.text).toBe('Hello world') + expect(result.segments).toBeUndefined() + }) + + it('returns plain text for the diarization model with text format', async () => { + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + adapter + .spyOnTranscriptionsCreate() + .mockResolvedValueOnce('Hello' as unknown as OpenAI.Audio.Transcription) + + const result = await adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + responseFormat: 'text', + logger: testLogger, + }) + + expect(result.text).toBe('Hello') + }) + + it('does not let modelOptions override model, file, or stream', async () => { + const mockResponse: OpenAI.Audio.TranscriptionDiarized = { + text: 'Hello', + duration: 1, + task: 'transcribe', + segments: [], + } + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + const mockCreate = adapter + .spyOnTranscriptionsCreate() + .mockResolvedValueOnce(mockResponse) + + const file = new File([], 'audio.wav', { type: 'audio/wav' }) + // Simulates untyped modelOptions passed through from a server route. + const hostileModelOptions = { + model: 'whisper-1', + file: new File([], 'evil.wav'), + stream: true, + } as never + + await adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: file, + modelOptions: hostileModelOptions, + logger: testLogger, + }) + + const sentRequest = mockCreate.mock.calls[0]?.[0] + expect(sentRequest?.model).toBe('gpt-4o-transcribe-diarize') + expect(sentRequest?.file).toBe(file) + expect(sentRequest).not.toHaveProperty('stream') + }) + + it('throws a descriptive error when a diarized response has no segments', async () => { + const adapter = new TestOpenAITranscriptionAdapter( + { apiKey: 'test-api-key' }, + 'gpt-4o-transcribe-diarize', + ) + adapter.spyOnTranscriptionsCreate().mockResolvedValueOnce({ + text: 'Hello', + } as unknown as OpenAI.Audio.TranscriptionDiarized) + + await expect( + adapter.transcribe({ + model: 'gpt-4o-transcribe-diarize', + audio: new File([], 'audio.wav', { type: 'audio/wav' }), + logger: testLogger, + }), + ).rejects.toThrow('did not include segments') + }) +}) diff --git a/packages/ai/skills/ai-core/media-generation/SKILL.md b/packages/ai/skills/ai-core/media-generation/SKILL.md index 0c63f347e..b55638111 100644 --- a/packages/ai/skills/ai-core/media-generation/SKILL.md +++ b/packages/ai/skills/ai-core/media-generation/SKILL.md @@ -357,7 +357,7 @@ const { generate, result, isLoading } = useGenerateSpeech({ ### 4. Audio Transcription Adapter: `openaiTranscription` (whisper-1, gpt-4o-transcribe, -gpt-4o-mini-transcribe). +gpt-4o-mini-transcribe, gpt-4o-transcribe-diarize). > **Capturing audio in the browser:** Use `useAudioRecorder` from `@tanstack/ai-react` to record directly in the browser, then pass the recording as the `audio` input to `generate()`, or use `recording.part` as a prompt part in chat/generation calls. No transcoding or extra dependencies required — the recorder returns the native browser format (`audio/webm` or `audio/mp4`). For transcription, wrap it as a `data:` URL so the provider gets the real content type; passing raw `recording.base64` makes the adapter assume `audio/mpeg` and mislabel the webm/mp4 bytes. > @@ -382,16 +382,21 @@ const result = await generateTranscription({ language: 'en', responseFormat: 'verbose_json', modelOptions: { - include: ['segment', 'word'], + timestamp_granularities: ['word', 'segment'], }, }) // result.text -- full transcribed text // result.language -- detected/specified language // result.duration -- audio duration in seconds -// result.segments -- timestamped segments with optional word-level timestamps +// result.segments -- timestamped segments (word-level timestamps are in result.words) ``` +For speaker diarization, use `openaiTranscription('gpt-4o-transcribe-diarize')`. +When no response format is given it defaults the request to `response_format: 'diarized_json'` +and `chunking_strategy: 'auto'` (a top-level `responseFormat` of `'json'`/`'text'` opts out of +speaker segments); do not pass `prompt`, `include`, or `timestamp_granularities` with this model. + Client hook: ```tsx diff --git a/packages/ai/src/activities/generateTranscription/index.ts b/packages/ai/src/activities/generateTranscription/index.ts index 2bd377b4f..31b4b3a0c 100644 --- a/packages/ai/src/activities/generateTranscription/index.ts +++ b/packages/ai/src/activities/generateTranscription/index.ts @@ -19,7 +19,11 @@ import type { InternalLogger } from '../../logger/internal-logger' import type { DebugOption } from '../../logger/types' import type { GenerationMiddleware } from '../middleware' import type { TranscriptionAdapter } from './adapter' -import type { StreamChunk, TranscriptionResult } from '../../types' +import type { + StreamChunk, + TranscriptionResponseFormat, + TranscriptionResult, +} from '../../types' // =========================== // Activity Kind @@ -67,7 +71,7 @@ export interface TranscriptionActivityOptions< /** An optional prompt to guide the transcription */ prompt?: string /** The format of the transcription output */ - responseFormat?: 'json' | 'text' | 'srt' | 'verbose_json' | 'vtt' + responseFormat?: TranscriptionResponseFormat /** Provider-specific options for transcription */ modelOptions?: TranscriptionProviderOptions /** diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts index 128c6779f..e497090e2 100644 --- a/packages/ai/src/types.ts +++ b/packages/ai/src/types.ts @@ -1930,6 +1930,13 @@ export interface TTSResult { * Options for audio transcription. * These are the common options supported across providers. */ +export type TranscriptionResponseFormat = + | 'json' + | 'text' + | 'srt' + | 'verbose_json' + | 'vtt' + export interface TranscriptionOptions< TProviderOptions extends object = object, > { @@ -1942,7 +1949,7 @@ export interface TranscriptionOptions< /** An optional prompt to guide the transcription */ prompt?: string /** The format of the transcription output */ - responseFormat?: 'json' | 'text' | 'srt' | 'verbose_json' | 'vtt' + responseFormat?: TranscriptionResponseFormat /** Model-specific options for transcription */ modelOptions?: TProviderOptions /** diff --git a/testing/e2e/fixtures/transcription/basic.json b/testing/e2e/fixtures/transcription/basic.json index 2a1616e26..a936b4bff 100644 --- a/testing/e2e/fixtures/transcription/basic.json +++ b/testing/e2e/fixtures/transcription/basic.json @@ -1,7 +1,7 @@ { "fixtures": [ { - "match": { "userMessage": "audio.mpeg" }, + "match": { "model": "whisper-1", "userMessage": "audio.mpeg" }, "response": { "transcription": { "text": "I would like to buy a Fender Stratocaster please" diff --git a/testing/e2e/fixtures/transcription/diarization.json b/testing/e2e/fixtures/transcription/diarization.json new file mode 100644 index 000000000..bb2d8041b --- /dev/null +++ b/testing/e2e/fixtures/transcription/diarization.json @@ -0,0 +1,32 @@ +{ + "fixtures": [ + { + "match": { "model": "gpt-4o-transcribe-diarize" }, + "response": { + "transcription": { + "text": "agent: Welcome to the store.\ncustomer: I need a Fender Stratocaster.", + "language": "english", + "duration": 3.2, + "segments": [ + { + "id": "seg_0", + "type": "transcript.text.segment", + "start": 0, + "end": 1.4, + "text": "Welcome to the store.", + "speaker": "agent" + }, + { + "id": "seg_1", + "type": "transcript.text.segment", + "start": 1.5, + "end": 3.2, + "text": "I need a Fender Stratocaster.", + "speaker": "customer" + } + ] + } + } + } + ] +} diff --git a/testing/e2e/src/components/TranscriptionUI.tsx b/testing/e2e/src/components/TranscriptionUI.tsx index 47a76bac9..536118f43 100644 --- a/testing/e2e/src/components/TranscriptionUI.tsx +++ b/testing/e2e/src/components/TranscriptionUI.tsx @@ -6,10 +6,16 @@ import { import { generateTranscriptionFn } from '@/lib/server-functions' import type { TranscriptionResult } from '@tanstack/ai' import type { TranscriptionGenerateInput } from '@tanstack/ai-client' -import type { Mode, Provider } from '@/lib/types' +import type { Feature, Mode, Provider } from '@/lib/types' + +type TranscriptionFeature = Extract< + Feature, + 'transcription' | 'transcription-diarization' +> interface TranscriptionUIProps { provider: Provider + feature: TranscriptionFeature mode: Mode testId?: string aimockPort?: number @@ -21,12 +27,29 @@ const TEST_AUDIO_BASE64 = 'data:audio/mpeg;base64,SGVsbG8=' export function TranscriptionUI({ provider, + feature, mode, testId, aimockPort, }: TranscriptionUIProps) { + const isDiarization = feature === 'transcription-diarization' + const transcriptionInput: TranscriptionGenerateInput = { + audio: TEST_AUDIO_BASE64, + language: 'en', + ...(isDiarization + ? { + modelOptions: { + response_format: 'diarized_json', + chunking_strategy: 'auto', + known_speaker_names: ['agent', 'customer'], + known_speaker_references: [TEST_AUDIO_BASE64, TEST_AUDIO_BASE64], + }, + } + : {}), + } + const connectionOptions = () => { - const body = { provider, testId, aimockPort } + const body = { provider, feature, testId, aimockPort } if (mode === 'sse') { return { connection: fetchServerSentEvents('/api/transcription'), body } @@ -40,7 +63,10 @@ export function TranscriptionUI({ data: { audio: input.audio as string, language: input.language, + responseFormat: input.responseFormat, + modelOptions: input.modelOptions, provider, + feature, aimockPort, testId, }, @@ -56,7 +82,7 @@ export function TranscriptionUI({
)} {result && ( -

- {result.text} -

+
+

+ {result.text} +

+ {result.segments && result.segments.length > 0 && ( +
+ {result.segments.map((segment, index) => ( +
+ {segment.speaker && ( + + {segment.speaker} + + )} + {segment.text} +
+ ))} +
+ )} +
)} ) diff --git a/testing/e2e/src/lib/feature-support.ts b/testing/e2e/src/lib/feature-support.ts index 01738f1db..7f985635a 100644 --- a/testing/e2e/src/lib/feature-support.ts +++ b/testing/e2e/src/lib/feature-support.ts @@ -233,6 +233,7 @@ export const matrix: Record> = { 'sound-effects': new Set(['elevenlabs']), tts: new Set(['openai', 'grok', 'elevenlabs']), transcription: new Set(['openai', 'grok', 'groq', 'elevenlabs']), + 'transcription-diarization': new Set(['openai']), // Gemini Veo runs through a custom aimock mount (see geminiVeoMount in // global-setup.ts) — aimock 1.29 doesn't model the long-running // `:predictLongRunning` + operations-polling pair natively. diff --git a/testing/e2e/src/lib/features.ts b/testing/e2e/src/lib/features.ts index 972eba0ae..fa2b78724 100644 --- a/testing/e2e/src/lib/features.ts +++ b/testing/e2e/src/lib/features.ts @@ -124,6 +124,10 @@ export const featureConfigs: Record = { tools: [], modelOptions: {}, }, + 'transcription-diarization': { + tools: [], + modelOptions: {}, + }, 'video-gen': { tools: [], modelOptions: {}, diff --git a/testing/e2e/src/lib/media-providers.ts b/testing/e2e/src/lib/media-providers.ts index ad5c01815..90446f9e3 100644 --- a/testing/e2e/src/lib/media-providers.ts +++ b/testing/e2e/src/lib/media-providers.ts @@ -20,11 +20,17 @@ import { createElevenLabsSpeech, createElevenLabsTranscription, } from '@tanstack/ai-elevenlabs' +import type { TranscriptionResponseFormat } from '@tanstack/ai' import type { Feature, Provider } from '@/lib/types' const LLMOCK_DEFAULT_BASE = process.env.LLMOCK_URL || 'http://127.0.0.1:4010' const DUMMY_KEY = 'sk-e2e-test-dummy-key' +type TranscriptionAdapterOptions = { + responseFormat?: TranscriptionResponseFormat + modelOptions?: Record +} + function llmockBase(aimockPort?: number): string { if (aimockPort) return `http://127.0.0.1:${aimockPort}` return LLMOCK_DEFAULT_BASE @@ -38,6 +44,17 @@ function testHeaders(testId?: string): Record | undefined { return testId ? { 'X-Test-Id': testId } : undefined } +function getOpenaiTranscriptionModel(options: TranscriptionAdapterOptions) { + const modelOptions = options.modelOptions + const isDiarizationRequest = + modelOptions?.response_format === 'diarized_json' || + modelOptions?.chunking_strategy !== undefined || + modelOptions?.known_speaker_names !== undefined || + modelOptions?.known_speaker_references !== undefined + + return isDiarizationRequest ? 'gpt-4o-transcribe-diarize' : 'whisper-1' +} + export function createImageAdapter( provider: Provider, aimockPort?: number, @@ -97,11 +114,13 @@ export function createTranscriptionAdapter( provider: Provider, aimockPort?: number, testId?: string, + options: TranscriptionAdapterOptions = {}, ) { const headers = testHeaders(testId) + const openaiTranscriptionModel = getOpenaiTranscriptionModel(options) const factories: Record any> = { openai: () => - createOpenaiTranscription('whisper-1', DUMMY_KEY, { + createOpenaiTranscription(openaiTranscriptionModel, DUMMY_KEY, { baseURL: openaiUrl(aimockPort), defaultHeaders: headers, }), diff --git a/testing/e2e/src/lib/server-functions.ts b/testing/e2e/src/lib/server-functions.ts index 20faeb7b4..f63843168 100644 --- a/testing/e2e/src/lib/server-functions.ts +++ b/testing/e2e/src/lib/server-functions.ts @@ -7,7 +7,7 @@ import { generateVideo, getVideoJobStatus, } from '@tanstack/ai' -import type { MediaPrompt } from '@tanstack/ai' +import type { MediaPrompt, TranscriptionResponseFormat } from '@tanstack/ai' import type { Feature, Provider } from '@/lib/types' import { createAudioAdapter, @@ -82,7 +82,10 @@ export const generateTranscriptionFn = createServerFn({ method: 'POST' }) (data: { audio: string language?: string + responseFormat?: TranscriptionResponseFormat + modelOptions?: Record provider: Provider + feature?: Feature aimockPort?: number testId?: string }) => { @@ -97,11 +100,17 @@ export const generateTranscriptionFn = createServerFn({ method: 'POST' }) data.provider, data.aimockPort, data.testId, + { + responseFormat: data.responseFormat, + modelOptions: data.modelOptions, + }, ) return generateTranscription({ adapter, audio: data.audio, language: data.language, + responseFormat: data.responseFormat, + modelOptions: data.modelOptions, }) }) diff --git a/testing/e2e/src/lib/types.ts b/testing/e2e/src/lib/types.ts index e982da7e3..3abd41107 100644 --- a/testing/e2e/src/lib/types.ts +++ b/testing/e2e/src/lib/types.ts @@ -39,6 +39,7 @@ export type Feature = | 'sound-effects' | 'tts' | 'transcription' + | 'transcription-diarization' | 'video-gen' | 'image-to-video' | 'stateful-interactions' @@ -83,6 +84,7 @@ export const ALL_FEATURES: Feature[] = [ 'sound-effects', 'tts', 'transcription', + 'transcription-diarization', 'video-gen', 'image-to-video', 'stateful-interactions', diff --git a/testing/e2e/src/routes/$provider/$feature.tsx b/testing/e2e/src/routes/$provider/$feature.tsx index b1fe5b40f..ed1dfa9ea 100644 --- a/testing/e2e/src/routes/$provider/$feature.tsx +++ b/testing/e2e/src/routes/$provider/$feature.tsx @@ -45,6 +45,7 @@ const MEDIA_FEATURES = new Set([ 'image-to-image', 'tts', 'transcription', + 'transcription-diarization', 'video-gen', 'image-to-video', 'audio-gen', @@ -154,9 +155,11 @@ function MediaFeature({ /> ) case 'transcription': + case 'transcription-diarization': return ( provider: Provider testId?: string aimockPort?: number } - const adapter = createTranscriptionAdapter(provider, aimockPort, testId) + const adapter = createTranscriptionAdapter( + provider, + aimockPort, + testId, + { responseFormat, modelOptions }, + ) try { const stream = generateTranscription({ adapter, audio, language, + responseFormat, + modelOptions, stream: true, }) return toHttpResponse(stream, { abortController }) diff --git a/testing/e2e/src/routes/api.transcription.ts b/testing/e2e/src/routes/api.transcription.ts index 070b29db7..f18fd6867 100644 --- a/testing/e2e/src/routes/api.transcription.ts +++ b/testing/e2e/src/routes/api.transcription.ts @@ -1,7 +1,8 @@ import { createFileRoute } from '@tanstack/react-router' import { generateTranscription, toServerSentEventsResponse } from '@tanstack/ai' -import { createTranscriptionAdapter } from '@/lib/media-providers' +import type { TranscriptionResponseFormat } from '@tanstack/ai' import type { Provider } from '@/lib/types' +import { createTranscriptionAdapter } from '@/lib/media-providers' export const Route = createFileRoute('/api/transcription')({ server: { @@ -11,21 +12,38 @@ export const Route = createFileRoute('/api/transcription')({ const abortController = new AbortController() const body = await request.json() const data = body.forwardedProps ?? body.data ?? body - const { audio, language, provider, testId, aimockPort } = data as { + const { + audio, + language, + responseFormat, + modelOptions, + provider, + testId, + aimockPort, + } = data as { audio: string language?: string + responseFormat?: TranscriptionResponseFormat + modelOptions?: Record provider: Provider testId?: string aimockPort?: number } - const adapter = createTranscriptionAdapter(provider, aimockPort, testId) + const adapter = createTranscriptionAdapter( + provider, + aimockPort, + testId, + { responseFormat, modelOptions }, + ) try { const stream = generateTranscription({ adapter, audio, language, + responseFormat, + modelOptions, stream: true, }) return toServerSentEventsResponse(stream, { abortController }) diff --git a/testing/e2e/tests/transcription.spec.ts b/testing/e2e/tests/transcription.spec.ts index 85822b633..faf1cd0ed 100644 --- a/testing/e2e/tests/transcription.spec.ts +++ b/testing/e2e/tests/transcription.spec.ts @@ -53,3 +53,43 @@ for (const provider of providersFor('transcription')) { }) }) } + +for (const provider of providersFor('transcription-diarization')) { + test.describe(`${provider} -- transcription-diarization`, () => { + for (const mode of ['sse', 'http-stream', 'fetcher'] as const) { + test(`${mode} -- transcribes diarized audio`, async ({ + page, + testId, + aimockPort, + }) => { + await page.goto( + featureUrl( + provider, + 'transcription-diarization', + testId, + aimockPort, + mode, + ), + ) + await clickGenerate(page) + await waitForGenerationComplete(page) + + await expect(page.getByTestId('transcription-text')).toContainText( + 'Fender Stratocaster', + ) + await expect(page.getByTestId('transcription-segments')).toContainText( + 'Welcome to the store', + ) + await expect(page.getByTestId('transcription-segments')).toContainText( + 'I need a Fender Stratocaster', + ) + await expect(page.getByTestId('transcription-speaker-0')).toHaveText( + 'agent', + ) + await expect(page.getByTestId('transcription-speaker-1')).toHaveText( + 'customer', + ) + }) + } + }) +}