diff --git a/docs/features/stt.md b/docs/features/stt.md index bb8c85ea..9705adae 100644 --- a/docs/features/stt.md +++ b/docs/features/stt.md @@ -69,6 +69,12 @@ Any OpenAI-compatible transcription API works: - Self-hosted Whisper servers - Local STT servers with OpenAI-compatible API +### Performance + +When external voice input is enabled, the browser audio pipeline is warmed up ahead of time so the first and subsequent recordings start faster. The audio context and worklet processor are prepared without requesting microphone access, and are retained between recordings; only the microphone track is stopped after each use. Resources are released entirely when external STT is disabled or the voice input UI unmounts. + +This optimization applies only to the external API provider. It does not affect the initial permission prompt — the browser still asks for microphone access on the first recording. + ## Using Voice Input ### Tap-to-Start / Tap-to-Stop diff --git a/frontend/public/audio-worklet-processor.js b/frontend/public/audio-worklet-processor.js index 0a0e584d..848dd5b2 100644 --- a/frontend/public/audio-worklet-processor.js +++ b/frontend/public/audio-worklet-processor.js @@ -59,12 +59,16 @@ class RecorderProcessor extends AudioWorkletProcessor { } _flushBuffer() { - const int16 = new Int16Array(this._buffer.length) - for (let i = 0; i < this._buffer.length; i++) { + const length = this._buffer.length + const int16 = new Int16Array(length) + let sumSquares = 0 + for (let i = 0; i < length; i++) { const sample = Math.max(-1, Math.min(1, this._buffer[i])) + sumSquares += sample * sample int16[i] = sample < 0 ? sample * 32768 : sample * 32767 } - this.port.postMessage(int16, [int16.buffer]) + const rms = length > 0 ? Math.sqrt(sumSquares / length) : 0 + this.port.postMessage({ samples: int16, rms }, [int16.buffer]) this._buffer = [] } } diff --git a/frontend/src/hooks/useSTT.test.tsx b/frontend/src/hooks/useSTT.test.tsx new file mode 100644 index 00000000..559f7c30 --- /dev/null +++ b/frontend/src/hooks/useSTT.test.tsx @@ -0,0 +1,182 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { renderHook, act, waitFor } from '@testing-library/react' +import { useSTT } from './useSTT' + +type MockRecorder = { + start: ReturnType + stop: ReturnType + abort: ReturnType + dispose: ReturnType + prepare: ReturnType + getState: ReturnType + setOnStateChange: ReturnType + setOnError: ReturnType + setOnDataAvailable: ReturnType + setOnNoSpeech: ReturnType +} + +const mocks = vi.hoisted(() => ({ + useSettings: vi.fn(), + AudioRecorder: vi.fn(), + getWebSpeechRecognizer: vi.fn(), + isWebRecognitionSupported: vi.fn(), +})) + +vi.mock('@/hooks/useSettings', () => ({ + useSettings: mocks.useSettings, +})) + +vi.mock('@/lib/audioRecorder', () => ({ + AudioRecorder: mocks.AudioRecorder, +})) + +vi.mock('@/lib/webSpeechRecognizer', () => ({ + getWebSpeechRecognizer: mocks.getWebSpeechRecognizer, + isWebRecognitionSupported: mocks.isWebRecognitionSupported, +})) + +const externalSTTPreferences = { + preferences: { + stt: { + enabled: true, + provider: 'external' as const, + endpoint: 'https://api.openai.com', + apiKey: 'test-key', + model: 'whisper-1', + language: 'en-US', + }, + }, +} + +describe('useSTT external provider lifecycle', () => { + let mockRecorder: MockRecorder + + beforeEach(() => { + vi.clearAllMocks() + + mockRecorder = { + start: vi.fn().mockResolvedValue(undefined), + stop: vi.fn(), + abort: vi.fn(), + dispose: vi.fn(), + prepare: vi.fn().mockResolvedValue(undefined), + getState: vi.fn().mockReturnValue('recording'), + setOnStateChange: vi.fn(), + setOnError: vi.fn(), + setOnDataAvailable: vi.fn(), + setOnNoSpeech: vi.fn(), + } + + mocks.AudioRecorder.mockImplementation(() => mockRecorder) + mocks.useSettings.mockReturnValue(externalSTTPreferences) + mocks.getWebSpeechRecognizer.mockReturnValue({ + start: vi.fn(), + stop: vi.fn(), + abort: vi.fn(), + clearCallbacks: vi.fn(), + onResult: vi.fn(), + onInterimResult: vi.fn(), + onError: vi.fn(), + onEnd: vi.fn(), + onStart: vi.fn(), + }) + mocks.isWebRecognitionSupported.mockReturnValue(true) + }) + + it('does not start external recording until startRecording is called', async () => { + const { result } = renderHook(() => useSTT()) + + await waitFor(() => { + expect(mocks.AudioRecorder).toHaveBeenCalledTimes(1) + }) + + expect(mockRecorder.start).not.toHaveBeenCalled() + expect(mockRecorder.prepare).toHaveBeenCalledTimes(1) + expect(mockRecorder.setOnStateChange).toHaveBeenCalledTimes(1) + expect(mockRecorder.setOnError).toHaveBeenCalledTimes(1) + expect(mockRecorder.setOnDataAvailable).toHaveBeenCalledTimes(1) + expect(mockRecorder.setOnNoSpeech).toHaveBeenCalledTimes(1) + + await act(async () => { + await result.current.startRecording() + }) + + expect(mockRecorder.start).toHaveBeenCalledTimes(1) + }) + + it('clears processing without an error when no speech is detected', async () => { + const { result } = renderHook(() => useSTT()) + + await waitFor(() => { + expect(mockRecorder.setOnNoSpeech).toHaveBeenCalledTimes(1) + }) + + const onNoSpeech = mockRecorder.setOnNoSpeech.mock.calls[0][0] as () => void + + act(() => { + onNoSpeech() + }) + + expect(result.current.isProcessing).toBe(false) + expect(result.current.isRecording).toBe(false) + expect(result.current.isError).toBe(false) + expect(result.current.error).toBeNull() + }) + + it('does not get stuck processing when stopping a silent recording', async () => { + const { result } = renderHook(() => useSTT()) + + await waitFor(() => { + expect(mockRecorder.setOnNoSpeech).toHaveBeenCalledTimes(1) + }) + + const onNoSpeech = mockRecorder.setOnNoSpeech.mock.calls[0][0] as () => void + mockRecorder.stop.mockImplementation(() => { + onNoSpeech() + }) + + await act(async () => { + await result.current.startRecording() + }) + + act(() => { + result.current.stopRecording() + }) + + expect(result.current.isProcessing).toBe(false) + expect(result.current.isRecording).toBe(false) + expect(result.current.isError).toBe(false) + }) + + it('ignores stopRecording when the recorder is not recording', async () => { + const { result } = renderHook(() => useSTT()) + + await waitFor(() => { + expect(mocks.AudioRecorder).toHaveBeenCalledTimes(1) + }) + + mockRecorder.getState.mockReturnValue('stopped') + + act(() => { + result.current.stopRecording() + }) + + expect(mockRecorder.stop).not.toHaveBeenCalled() + expect(result.current.isProcessing).toBe(false) + }) + + it('disposes external recorder resources on unmount', async () => { + const { unmount } = renderHook(() => useSTT()) + + await waitFor(() => { + expect(mocks.AudioRecorder).toHaveBeenCalledTimes(1) + }) + + const recorder = mockRecorder + + unmount() + + expect(recorder.dispose).toHaveBeenCalledTimes(1) + expect(recorder.abort).not.toHaveBeenCalled() + }) +}) diff --git a/frontend/src/hooks/useSTT.ts b/frontend/src/hooks/useSTT.ts index 384cdf1b..c820e4db 100644 --- a/frontend/src/hooks/useSTT.ts +++ b/frontend/src/hooks/useSTT.ts @@ -139,6 +139,13 @@ export function useSTT(userId = 'default') { }, 3000) }) + recorder.setOnNoSpeech(() => { + setIsProcessing(false) + setIsRecording(false) + setInterimTranscript('') + setState('idle') + }) + recorder.setOnDataAvailable(async (blob) => { if (lastProcessedBlobRef.current === blob) { return @@ -185,26 +192,36 @@ export function useSTT(userId = 'default') { }) }, []) - useEffect(() => { - if (!isEnabled || !isExternalProvider) { - return - } - + const ensureAudioRecorder = useCallback((): AudioRecorder => { if (!audioRecorder.current) { audioRecorder.current = new AudioRecorder() } - if (!recorderConfiguredRef.current) { setupAudioRecorder(audioRecorder.current) recorderConfiguredRef.current = true } + return audioRecorder.current + }, [setupAudioRecorder]) + + const disposeAudioRecorder = useCallback(() => { + if (audioRecorder.current) { + audioRecorder.current.dispose() + audioRecorder.current = null + } + recorderConfiguredRef.current = false + }, []) + + useEffect(() => { + if (!isEnabled || !isExternalProvider) { + return + } + + void ensureAudioRecorder().prepare().catch(() => undefined) return () => { - if (audioRecorder.current) { - audioRecorder.current.abort() - } + disposeAudioRecorder() } - }, [isEnabled, isExternalProvider, setupAudioRecorder]) + }, [isEnabled, isExternalProvider, ensureAudioRecorder, disposeAudioRecorder]) const clearStartupTimeout = useCallback(() => { if (startupTimeoutRef.current) { @@ -214,8 +231,8 @@ export function useSTT(userId = 'default') { }, []) const abortAndResetOnTimeout = useCallback(() => { - if (isExternalProvider && audioRecorder.current) { - audioRecorder.current.abort() + if (isExternalProvider) { + disposeAudioRecorder() } else { recognizer.current.abort() } @@ -224,7 +241,38 @@ export function useSTT(userId = 'default') { setState('idle') setIsError(true) setError('Microphone start timed out') - }, [isExternalProvider]) + }, [isExternalProvider, disposeAudioRecorder]) + + const runStartupWithTimeout = useCallback( + async (startup: () => Promise, startOpId: number): Promise => { + try { + const startupPromise = startup() + const timeoutPromise = new Promise((_, reject) => { + startupTimeoutRef.current = setTimeout(() => { + if (startOpIdRef.current !== startOpId) return + reject(new Error('Microphone start timed out')) + }, STT_START_TIMEOUT_MS) + }) + + await Promise.race([startupPromise, timeoutPromise]) + clearStartupTimeout() + + return startOpIdRef.current === startOpId + } catch (err) { + clearStartupTimeout() + if (startOpIdRef.current !== startOpId) return false + setIsProcessing(false) + if (err instanceof Error && err.message === 'Microphone start timed out') { + abortAndResetOnTimeout() + return false + } + setIsError(true) + setError(err instanceof Error ? err.message : 'Failed to start recording') + return false + } + }, + [clearStartupTimeout, abortAndResetOnTimeout], + ) const startRecording = useCallback(async (): Promise => { if (!isSupported) { @@ -249,41 +297,14 @@ export function useSTT(userId = 'default') { clearStartupTimeout() if (isExternalProvider) { - if (!audioRecorder.current) { - audioRecorder.current = new AudioRecorder() - setupAudioRecorder(audioRecorder.current) - } - - try { - setIsProcessing(true) - - const startupPromise = audioRecorder.current.start() - const timeoutPromise = new Promise((_, reject) => { - startupTimeoutRef.current = setTimeout(() => { - if (startOpIdRef.current !== startOpId) return - reject(new Error('Microphone start timed out')) - }, STT_START_TIMEOUT_MS) - }) - - await Promise.race([startupPromise, timeoutPromise]) - clearStartupTimeout() - - if (startOpIdRef.current !== startOpId) return false + const recorder = ensureAudioRecorder() + setIsProcessing(true) + const started = await runStartupWithTimeout(() => recorder.start(), startOpId) + if (started) { setIsProcessing(false) - return true - } catch (err) { - clearStartupTimeout() - if (startOpIdRef.current !== startOpId) return false - setIsProcessing(false) - if (err instanceof Error && err.message === 'Microphone start timed out') { - abortAndResetOnTimeout() - return false - } - setIsError(true) - setError(err instanceof Error ? err.message : 'Failed to start recording') - return false } + return started } else { const options: SpeechRecognitionOptions = { language: config.language, @@ -291,45 +312,21 @@ export function useSTT(userId = 'default') { maxAlternatives: 1, } - try { - setIsProcessing(true) - - const startupPromise = recognizer.current.start(options) - const timeoutPromise = new Promise((_, reject) => { - startupTimeoutRef.current = setTimeout(() => { - if (startOpIdRef.current !== startOpId) return - reject(new Error('Microphone start timed out')) - }, STT_START_TIMEOUT_MS) - }) - - await Promise.race([startupPromise, timeoutPromise]) - clearStartupTimeout() - - if (startOpIdRef.current !== startOpId) return false - - return true - } catch (err) { - clearStartupTimeout() - if (startOpIdRef.current !== startOpId) return false - setIsProcessing(false) - if (err instanceof Error && err.message === 'Microphone start timed out') { - abortAndResetOnTimeout() - return false - } - setIsError(true) - setError(err instanceof Error ? err.message : 'Failed to start recording') - return false - } + setIsProcessing(true) + return runStartupWithTimeout(() => recognizer.current.start(options), startOpId) } - }, [isSupported, isEnabled, isExternalProvider, config.language, setupAudioRecorder, clearStartupTimeout, abortAndResetOnTimeout]) + }, [isSupported, isEnabled, isExternalProvider, config.language, clearStartupTimeout, ensureAudioRecorder, runStartupWithTimeout]) const stopRecording = useCallback(() => { if (isExternalProvider && audioRecorder.current) { - audioRecorder.current.stop() + if (audioRecorder.current.getState() !== 'recording') { + return + } setIsProcessing(true) + audioRecorder.current.stop() } else { - recognizer.current.stop() setIsProcessing(true) + recognizer.current.stop() } }, [isExternalProvider]) diff --git a/frontend/src/lib/audioRecorder.test.ts b/frontend/src/lib/audioRecorder.test.ts index d0b3773b..34c47e27 100644 --- a/frontend/src/lib/audioRecorder.test.ts +++ b/frontend/src/lib/audioRecorder.test.ts @@ -1,4 +1,4 @@ -import { describe, it, expect } from 'vitest' +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' import { AudioRecorder, downsampleAndConvert, encodeWavFromInt16 } from './audioRecorder' const blobToArrayBuffer = (blob: Blob): Promise => @@ -158,3 +158,312 @@ describe('AudioRecorder.isSupported', () => { }).not.toThrow() }) }) + +describe('AudioRecorder.prepare', () => { + let originalAudioContext: typeof window.AudioContext + let originalAudioWorkletNode: unknown + let originalGetUserMedia: (typeof navigator.mediaDevices)['getUserMedia'] | undefined + let mockAddModule: ReturnType + let mockClose: ReturnType + let mockTrack: { stop: ReturnType; kind: string } + let mockGetUserMedia: ReturnType + let MockAudioContext: ReturnType + + beforeEach(() => { + originalAudioContext = window.AudioContext + originalAudioWorkletNode = (window as any).AudioWorkletNode + originalGetUserMedia = navigator.mediaDevices?.getUserMedia + + mockAddModule = vi.fn().mockResolvedValue(undefined) + mockClose = vi.fn().mockResolvedValue(undefined) + mockTrack = { stop: vi.fn(), kind: 'audio' } + const mockSource = { connect: vi.fn(), disconnect: vi.fn() } + + MockAudioContext = vi.fn().mockImplementation(() => ({ + state: 'running', + sampleRate: 16000, + audioWorklet: { addModule: mockAddModule }, + createMediaStreamSource: vi.fn().mockReturnValue(mockSource), + createScriptProcessor: vi.fn().mockReturnValue({ + connect: vi.fn(), + disconnect: vi.fn(), + onaudioprocess: null, + }), + resume: vi.fn().mockResolvedValue(undefined), + close: mockClose, + })) + window.AudioContext = MockAudioContext as unknown as typeof window.AudioContext + + ;(window as any).AudioWorkletNode = vi.fn().mockImplementation(() => ({ + port: { onmessage: null, postMessage: vi.fn() }, + disconnect: vi.fn(), + })) + + mockGetUserMedia = vi.fn().mockResolvedValue({ + getTracks: () => [mockTrack], + getAudioTracks: () => [mockTrack], + }) + Object.defineProperty(navigator, 'mediaDevices', { + value: { getUserMedia: mockGetUserMedia }, + writable: true, + configurable: true, + }) + }) + + afterEach(() => { + window.AudioContext = originalAudioContext + ;(window as any).AudioWorkletNode = originalAudioWorkletNode + if (originalGetUserMedia) { + Object.defineProperty(navigator, 'mediaDevices', { + value: { getUserMedia: originalGetUserMedia }, + writable: true, + configurable: true, + }) + } + }) + + it('prepares the audio context and worklet without requesting microphone access', async () => { + const recorder = new AudioRecorder() + await recorder.prepare() + + expect(mockGetUserMedia).not.toHaveBeenCalled() + expect(MockAudioContext).toHaveBeenCalledTimes(1) + expect(mockAddModule).toHaveBeenCalledOnce() + expect(mockAddModule).toHaveBeenCalledWith('/audio-worklet-processor.js') + }) + + it('reuses the same AudioContext and worklet when prepare() precedes start()', async () => { + const recorder = new AudioRecorder() + await recorder.prepare() + + mockAddModule.mockClear() + + await recorder.start() + + expect(mockAddModule).not.toHaveBeenCalled() + expect(MockAudioContext).toHaveBeenCalledTimes(1) + expect(mockGetUserMedia).toHaveBeenCalledTimes(1) + + recorder.stop() + }) + + it('reuses the prepared audio context and loaded worklet across recordings', async () => { + const recorder = new AudioRecorder() + + await recorder.start() + recorder.stop() + + await recorder.start() + recorder.stop() + + recorder.dispose() + + expect(mockGetUserMedia).toHaveBeenCalledTimes(2) + expect(MockAudioContext).toHaveBeenCalledTimes(1) + expect(mockAddModule).toHaveBeenCalledTimes(1) + expect(mockTrack.stop).toHaveBeenCalledTimes(2) + expect(mockClose).toHaveBeenCalledTimes(1) + }) +}) + +describe('AudioRecorder lifecycle cancellation', () => { + let originalAudioContext: typeof window.AudioContext + let originalAudioWorkletNode: unknown + let originalGetUserMedia: (typeof navigator.mediaDevices)['getUserMedia'] | undefined + let mockTrack: { stop: ReturnType; kind: string } + + beforeEach(() => { + originalAudioContext = window.AudioContext + originalAudioWorkletNode = (window as any).AudioWorkletNode + originalGetUserMedia = navigator.mediaDevices?.getUserMedia + + mockTrack = { stop: vi.fn(), kind: 'audio' } + + const MockAudioContext = vi.fn().mockImplementation(() => ({ + state: 'running', + sampleRate: 16000, + audioWorklet: { addModule: vi.fn().mockResolvedValue(undefined) }, + createMediaStreamSource: vi.fn(), + createScriptProcessor: vi.fn(), + resume: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + })) + + window.AudioContext = MockAudioContext as unknown as typeof window.AudioContext + ;(window as any).AudioWorkletNode = vi.fn().mockImplementation(() => ({ + port: { onmessage: null, postMessage: vi.fn() }, + disconnect: vi.fn(), + })) + }) + + afterEach(() => { + window.AudioContext = originalAudioContext + ;(window as any).AudioWorkletNode = originalAudioWorkletNode + if (originalGetUserMedia) { + Object.defineProperty(navigator, 'mediaDevices', { + value: { getUserMedia: originalGetUserMedia }, + writable: true, + configurable: true, + }) + } + }) + + it('cleans up and does not enter recording state when dispose is called during async startup', async () => { + let resolveGetUserMedia: (stream: MediaStream) => void + const deferredGetUserMedia = new Promise((resolve) => { + resolveGetUserMedia = resolve + }) + + const mockGetUserMedia = vi.fn().mockReturnValue(deferredGetUserMedia) + Object.defineProperty(navigator, 'mediaDevices', { + value: { getUserMedia: mockGetUserMedia }, + writable: true, + configurable: true, + }) + + const recorder = new AudioRecorder() + + const startPromise = recorder.start() + + recorder.dispose() + + const stream = { getTracks: () => [mockTrack], getAudioTracks: () => [mockTrack] } as unknown as MediaStream + resolveGetUserMedia!(stream) + + await startPromise + + expect(recorder.getState()).toBe('idle') + expect(mockTrack.stop).toHaveBeenCalled() + expect(window.AudioContext).not.toHaveBeenCalled() + }) +}) + +describe('AudioRecorder voice activity detection', () => { + const SAMPLE_RATE = 16000 + const msToSamples = (ms: number): number => Math.round((ms / 1000) * SAMPLE_RATE) + + let originalAudioContext: typeof window.AudioContext + let originalAudioWorkletNode: unknown + let originalGetUserMedia: (typeof navigator.mediaDevices)['getUserMedia'] | undefined + let mockWorkletNode: { port: { onmessage: ((e: MessageEvent) => void) | null; postMessage: ReturnType }; disconnect: ReturnType } + + type Frame = { samples: Int16Array; rms: number } + const feed = (rms: number, ms: number): void => { + const frame: Frame = { samples: new Int16Array(msToSamples(ms)), rms } + mockWorkletNode.port.onmessage?.({ data: frame } as MessageEvent) + } + + beforeEach(() => { + originalAudioContext = window.AudioContext + originalAudioWorkletNode = (window as any).AudioWorkletNode + originalGetUserMedia = navigator.mediaDevices?.getUserMedia + + const mockSource = { connect: vi.fn(), disconnect: vi.fn() } + mockWorkletNode = { + port: { onmessage: null, postMessage: vi.fn() }, + disconnect: vi.fn(), + } + + const MockAudioContext = vi.fn().mockImplementation(() => ({ + state: 'running', + sampleRate: SAMPLE_RATE, + audioWorklet: { addModule: vi.fn().mockResolvedValue(undefined) }, + createMediaStreamSource: vi.fn().mockReturnValue(mockSource), + createScriptProcessor: vi.fn(), + resume: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + })) + window.AudioContext = MockAudioContext as unknown as typeof window.AudioContext + + ;(window as any).AudioWorkletNode = vi.fn().mockImplementation(() => mockWorkletNode) + + const mockTrack = { stop: vi.fn(), kind: 'audio' } + Object.defineProperty(navigator, 'mediaDevices', { + value: { + getUserMedia: vi.fn().mockResolvedValue({ + getTracks: () => [mockTrack], + getAudioTracks: () => [mockTrack], + }), + }, + writable: true, + configurable: true, + }) + }) + + afterEach(() => { + window.AudioContext = originalAudioContext + ;(window as any).AudioWorkletNode = originalAudioWorkletNode + if (originalGetUserMedia) { + Object.defineProperty(navigator, 'mediaDevices', { + value: { getUserMedia: originalGetUserMedia }, + writable: true, + configurable: true, + }) + } + }) + + it('does not emit audio and signals no-speech when the recording is silent', async () => { + const onDataAvailable = vi.fn() + const onNoSpeech = vi.fn() + const recorder = new AudioRecorder() + recorder.setOnDataAvailable(onDataAvailable) + recorder.setOnNoSpeech(onNoSpeech) + + await recorder.start() + for (let i = 0; i < 5; i++) { + feed(0.0005, 100) + } + recorder.stop() + + expect(onNoSpeech).toHaveBeenCalledTimes(1) + expect(onDataAvailable).not.toHaveBeenCalled() + expect(recorder.getState()).toBe('stopped') + }) + + it('signals no-speech when stopped before any audio frame is captured', async () => { + const onDataAvailable = vi.fn() + const onNoSpeech = vi.fn() + const recorder = new AudioRecorder() + recorder.setOnDataAvailable(onDataAvailable) + recorder.setOnNoSpeech(onNoSpeech) + + await recorder.start() + recorder.stop() + + expect(onNoSpeech).toHaveBeenCalledTimes(1) + expect(onDataAvailable).not.toHaveBeenCalled() + expect(recorder.getState()).toBe('stopped') + }) + + it('emits audio when speech is detected', async () => { + const onDataAvailable = vi.fn() + const onNoSpeech = vi.fn() + const recorder = new AudioRecorder() + recorder.setOnDataAvailable(onDataAvailable) + recorder.setOnNoSpeech(onNoSpeech) + + await recorder.start() + for (let i = 0; i < 3; i++) { + feed(0.2, 100) + } + recorder.stop() + + expect(onDataAvailable).toHaveBeenCalledTimes(1) + expect(onNoSpeech).not.toHaveBeenCalled() + }) + + it('auto-stops after trailing silence once speech has been detected', async () => { + const onDataAvailable = vi.fn() + const recorder = new AudioRecorder({ vad: { minSpeechMs: 50, silenceTimeoutMs: 200 } }) + recorder.setOnDataAvailable(onDataAvailable) + + await recorder.start() + feed(0.2, 100) + feed(0.0005, 100) + feed(0.0005, 100) + + expect(recorder.getState()).toBe('stopped') + expect(onDataAvailable).toHaveBeenCalledTimes(1) + expect(mockWorkletNode.port.onmessage).toBeNull() + }) +}) diff --git a/frontend/src/lib/audioRecorder.ts b/frontend/src/lib/audioRecorder.ts index 526eacd6..67788d90 100644 --- a/frontend/src/lib/audioRecorder.ts +++ b/frontend/src/lib/audioRecorder.ts @@ -1,8 +1,11 @@ +import { VoiceActivityDetector, type VadOptions } from './voiceActivityDetector' + export type AudioRecorderState = 'idle' | 'recording' | 'stopped' | 'error' export interface AudioRecorderOptions { sampleRate?: number channelCount?: number + vad?: Partial> } const DEFAULT_OPTIONS: AudioRecorderOptions = { @@ -49,6 +52,18 @@ export function downsampleAndConvert(input: Float32Array, inputRate: number, tar return output } +export function computeRms(samples: Int16Array): number { + if (samples.length === 0) { + return 0 + } + let sumSquares = 0 + for (let i = 0; i < samples.length; i++) { + const normalized = samples[i] / 32768 + sumSquares += normalized * normalized + } + return Math.sqrt(sumSquares / samples.length) +} + export function encodeWavFromInt16(samples: Int16Array, sampleRate: number, channels: number): Blob { const dataLength = samples.length * 2 const bufferSize = 44 + dataLength @@ -91,10 +106,12 @@ export class AudioRecorder { private state: AudioRecorderState = 'idle' private options: AudioRecorderOptions private isAborted: boolean = false + private vad: VoiceActivityDetector | null = null private onStateChange?: (state: AudioRecorderState) => void private onError?: (error: string) => void private onDataAvailable?: (blob: Blob) => void + private onNoSpeech?: () => void constructor(options: AudioRecorderOptions = {}) { this.options = { ...DEFAULT_OPTIONS, ...options } @@ -110,6 +127,32 @@ export class AudioRecorder { ) } + async prepare(): Promise { + if (!AudioRecorder.isSupported()) { + throw new Error('Audio recording is not supported in this browser') + } + + const ctx = this.getReusableAudioContext() + + if (ctx.audioWorklet) { + await ensureWorkletLoaded(ctx) + } + + if (ctx.state === 'suspended') { + await ctx.resume() + } + } + + private getReusableAudioContext(): AudioContext { + if (this.audioContext && this.audioContext.state !== 'closed') { + return this.audioContext + } + this.audioContext = new AudioContext({ + sampleRate: this.options.sampleRate, + }) + return this.audioContext + } + getState(): AudioRecorderState { return this.state } @@ -126,6 +169,10 @@ export class AudioRecorder { this.onDataAvailable = callback } + setOnNoSpeech(callback: () => void): void { + this.onNoSpeech = callback + } + private setState(newState: AudioRecorderState): void { this.state = newState this.onStateChange?.(newState) @@ -142,6 +189,10 @@ export class AudioRecorder { this.isAborted = false this.chunks = [] this.totalSamples = 0 + this.vad = new VoiceActivityDetector({ + sampleRate: this.options.sampleRate ?? 16000, + ...this.options.vad, + }) this.mediaStream = await navigator.mediaDevices.getUserMedia({ audio: { @@ -151,38 +202,43 @@ export class AudioRecorder { }, }) - this.audioContext = new AudioContext({ - sampleRate: this.options.sampleRate, - }) + if (this.isAborted) { + this.cleanupRecording(true) + return + } + + await this.prepare() - this.source = this.audioContext.createMediaStreamSource(this.mediaStream) - - if (this.audioContext.audioWorklet) { - try { - await ensureWorkletLoaded(this.audioContext) - this.workletNode = new AudioWorkletNode(this.audioContext, 'recorder-processor', { - processorOptions: { targetSampleRate: this.options.sampleRate }, - }) - this.workletNode.port.onmessage = (e: MessageEvent) => { - this.chunks.push(e.data) - this.totalSamples += e.data.length - } - this.source.connect(this.workletNode) - } catch (error) { - this.audioContext.close() - this.audioContext = null - throw new Error('Failed to load audio worklet processor', { cause: error }) + if (this.isAborted) { + this.cleanupRecording(true) + return + } + + const ctx = this.audioContext! + this.source = ctx.createMediaStreamSource(this.mediaStream) + + if (ctx.audioWorklet) { + this.workletNode = new AudioWorkletNode(ctx, 'recorder-processor', { + processorOptions: { targetSampleRate: this.options.sampleRate }, + }) + this.workletNode.port.onmessage = (e: MessageEvent<{ samples: Int16Array; rms: number }>) => { + const { samples, rms } = e.data + this.chunks.push(samples) + this.totalSamples += samples.length + this.handleVadFrame(rms, samples.length) } - } else if (this.audioContext) { + this.source.connect(this.workletNode) + } else { const bufferSize = 4096 - this.processor = this.audioContext.createScriptProcessor(bufferSize, 1, 1) + this.processor = ctx.createScriptProcessor(bufferSize, 1, 1) const targetRate = this.options.sampleRate ?? 16000 - const inputRate = this.audioContext.sampleRate + const inputRate = ctx.sampleRate this.processor.onaudioprocess = (e) => { const inputData = e.inputBuffer.getChannelData(0) const int16Chunk = downsampleAndConvert(inputData, inputRate, targetRate) this.chunks.push(int16Chunk) this.totalSamples += int16Chunk.length + this.handleVadFrame(computeRms(int16Chunk), int16Chunk.length) } this.source.connect(this.processor) } @@ -190,7 +246,7 @@ export class AudioRecorder { this.setState('recording') } catch (error) { this.setState('error') - this.cleanup() + this.cleanupRecording(true) if (error instanceof DOMException) { if (error.name === 'NotAllowedError') { @@ -213,19 +269,45 @@ export class AudioRecorder { this.processRecording() } this.resetRecordingState() - this.cleanup() + this.cleanupRecording(false) this.setState('stopped') } abort(): void { + this.teardown(false) + } + + dispose(): void { + this.teardown(true) + } + + private teardown(closeAudioContext: boolean): void { this.isAborted = true this.resetRecordingState() - this.cleanup() + this.cleanupRecording(closeAudioContext) this.setState('idle') } + private handleVadFrame(rms: number, frameSamples: number): void { + if (!this.vad) { + return + } + const { shouldAutoStop } = this.vad.process(rms, frameSamples) + if (shouldAutoStop) { + this.stop() + } + } + private processRecording(): void { - if (this.isAborted || this.chunks.length === 0 || this.totalSamples === 0) { + if (this.isAborted) { + return + } + + const hasAudio = this.chunks.length > 0 && this.totalSamples > 0 + const hasSpeech = !this.vad || this.vad.hasSpeech + + if (!hasAudio || !hasSpeech) { + this.onNoSpeech?.() return } @@ -244,7 +326,7 @@ export class AudioRecorder { } } - private cleanup(): void { + private cleanupRecording(closeAudioContext: boolean): void { if (this.workletNode) { this.workletNode.port.onmessage = null this.workletNode.port.postMessage('stop') @@ -263,15 +345,15 @@ export class AudioRecorder { this.source = null } - if (this.audioContext && this.audioContext.state !== 'closed') { - this.audioContext.close() - this.audioContext = null - } - if (this.mediaStream) { this.mediaStream.getTracks().forEach(track => track.stop()) this.mediaStream = null } + + if (closeAudioContext && this.audioContext && this.audioContext.state !== 'closed') { + this.audioContext.close() + this.audioContext = null + } } private resetRecordingState(): void { diff --git a/frontend/src/lib/voiceActivityDetector.test.ts b/frontend/src/lib/voiceActivityDetector.test.ts new file mode 100644 index 00000000..1327788b --- /dev/null +++ b/frontend/src/lib/voiceActivityDetector.test.ts @@ -0,0 +1,94 @@ +import { describe, it, expect } from 'vitest' +import { VoiceActivityDetector } from './voiceActivityDetector' + +const SAMPLE_RATE = 16000 +const msToSamples = (ms: number): number => Math.round((ms / 1000) * SAMPLE_RATE) + +describe('VoiceActivityDetector', () => { + it('classifies loud frames as speech and quiet frames as silence', () => { + const vad = new VoiceActivityDetector({ sampleRate: SAMPLE_RATE }) + + expect(vad.process(0.2, msToSamples(100)).isSpeech).toBe(true) + expect(vad.process(0.0005, msToSamples(100)).isSpeech).toBe(false) + }) + + it('reports hasSpeech only after cumulative speech exceeds minSpeechMs', () => { + const vad = new VoiceActivityDetector({ sampleRate: SAMPLE_RATE, minSpeechMs: 150 }) + + vad.process(0.2, msToSamples(100)) + expect(vad.hasSpeech).toBe(false) + + vad.process(0.2, msToSamples(100)) + expect(vad.hasSpeech).toBe(true) + }) + + it('does not flag pure silence as speech', () => { + const vad = new VoiceActivityDetector({ sampleRate: SAMPLE_RATE }) + + for (let i = 0; i < 10; i++) { + vad.process(0.0008, msToSamples(100)) + } + + expect(vad.hasSpeech).toBe(false) + }) + + it('auto-stops after trailing silence once speech has started', () => { + const vad = new VoiceActivityDetector({ + sampleRate: SAMPLE_RATE, + minSpeechMs: 50, + silenceTimeoutMs: 200, + }) + + expect(vad.process(0.2, msToSamples(100)).shouldAutoStop).toBe(false) + expect(vad.process(0.0005, msToSamples(100)).shouldAutoStop).toBe(false) + expect(vad.process(0.0005, msToSamples(100)).shouldAutoStop).toBe(true) + }) + + it('does not auto-stop during leading silence before any speech', () => { + const vad = new VoiceActivityDetector({ + sampleRate: SAMPLE_RATE, + silenceTimeoutMs: 200, + }) + + for (let i = 0; i < 10; i++) { + expect(vad.process(0.0005, msToSamples(100)).shouldAutoStop).toBe(false) + } + }) + + it('never auto-stops when silenceTimeoutMs is 0', () => { + const vad = new VoiceActivityDetector({ + sampleRate: SAMPLE_RATE, + minSpeechMs: 50, + silenceTimeoutMs: 0, + }) + + vad.process(0.2, msToSamples(100)) + for (let i = 0; i < 20; i++) { + expect(vad.process(0.0005, msToSamples(100)).shouldAutoStop).toBe(false) + } + }) + + it('resets accumulated speech and silence state', () => { + const vad = new VoiceActivityDetector({ sampleRate: SAMPLE_RATE, minSpeechMs: 50 }) + + vad.process(0.2, msToSamples(100)) + expect(vad.hasSpeech).toBe(true) + + vad.reset() + expect(vad.hasSpeech).toBe(false) + expect(vad.process(0.0005, msToSamples(100)).shouldAutoStop).toBe(false) + }) + + it('resets trailing silence when speech resumes', () => { + const vad = new VoiceActivityDetector({ + sampleRate: SAMPLE_RATE, + minSpeechMs: 50, + silenceTimeoutMs: 200, + }) + + vad.process(0.2, msToSamples(100)) + vad.process(0.0005, msToSamples(100)) + vad.process(0.2, msToSamples(100)) + expect(vad.process(0.0005, msToSamples(100)).shouldAutoStop).toBe(false) + }) +}) diff --git a/frontend/src/lib/voiceActivityDetector.ts b/frontend/src/lib/voiceActivityDetector.ts new file mode 100644 index 00000000..3197a048 --- /dev/null +++ b/frontend/src/lib/voiceActivityDetector.ts @@ -0,0 +1,73 @@ +export interface VadOptions { + sampleRate: number + silenceFloor: number + speechMultiplier: number + silenceTimeoutMs: number + minSpeechMs: number + noiseFloorSmoothing: number +} + +export const DEFAULT_VAD_OPTIONS: Omit = { + silenceFloor: 0.008, + speechMultiplier: 2.5, + silenceTimeoutMs: 1500, + minSpeechMs: 150, + noiseFloorSmoothing: 0.95, +} + +export interface VadFrameResult { + isSpeech: boolean + shouldAutoStop: boolean +} + +export class VoiceActivityDetector { + private readonly options: VadOptions + private noiseFloor: number + private speechSamples = 0 + private trailingSilenceSamples = 0 + private speechStarted = false + + constructor(options: Partial & Pick) { + this.options = { ...DEFAULT_VAD_OPTIONS, ...options } + this.noiseFloor = this.options.silenceFloor + } + + process(rms: number, frameSamples: number): VadFrameResult { + const { silenceFloor, speechMultiplier, noiseFloorSmoothing, silenceTimeoutMs } = this.options + const threshold = Math.max(silenceFloor, this.noiseFloor * speechMultiplier) + const isSpeech = rms >= threshold + + if (isSpeech) { + this.speechStarted = true + this.speechSamples += frameSamples + this.trailingSilenceSamples = 0 + } else { + this.noiseFloor = noiseFloorSmoothing * this.noiseFloor + (1 - noiseFloorSmoothing) * rms + if (this.speechStarted) { + this.trailingSilenceSamples += frameSamples + } + } + + const shouldAutoStop = + silenceTimeoutMs > 0 && + this.speechStarted && + this.samplesToMs(this.trailingSilenceSamples) >= silenceTimeoutMs + + return { isSpeech, shouldAutoStop } + } + + get hasSpeech(): boolean { + return this.samplesToMs(this.speechSamples) >= this.options.minSpeechMs + } + + reset(): void { + this.noiseFloor = this.options.silenceFloor + this.speechSamples = 0 + this.trailingSilenceSamples = 0 + this.speechStarted = false + } + + private samplesToMs(samples: number): number { + return (samples / this.options.sampleRate) * 1000 + } +}