From 158b6ae840e818b9ecfc4ca619cde9bbf1a42c8a Mon Sep 17 00:00:00 2001 From: Boris Dibon Date: Wed, 3 Jun 2026 13:29:36 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20webSocketObservable=20for=20W?= =?UTF-8?q?ebSocket=20lifecycle=20events?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Cursor --- .../src/browser/addEventListener.ts | 4 +- .../src/browser/webSocketObservable.spec.ts | 356 ++++++++++++++++++ .../src/browser/webSocketObservable.ts | 200 ++++++++++ packages/browser-core/src/index.ts | 9 + packages/browser-core/test/forEach.spec.ts | 2 + 5 files changed, 570 insertions(+), 1 deletion(-) create mode 100644 packages/browser-core/src/browser/webSocketObservable.spec.ts create mode 100644 packages/browser-core/src/browser/webSocketObservable.ts diff --git a/packages/browser-core/src/browser/addEventListener.ts b/packages/browser-core/src/browser/addEventListener.ts index 1fca98d267..c2e01df885 100644 --- a/packages/browser-core/src/browser/addEventListener.ts +++ b/packages/browser-core/src/browser/addEventListener.ts @@ -79,7 +79,9 @@ type EventMapFor = T extends Window ? WorkerEventMap : T extends CookieStore ? CookieStoreEventMap - : Record + : T extends WebSocket + ? WebSocketEventMap + : Record /** * Add an event listener to an event target object (Window, Element, mock object...). This provides diff --git a/packages/browser-core/src/browser/webSocketObservable.spec.ts b/packages/browser-core/src/browser/webSocketObservable.spec.ts new file mode 100644 index 0000000000..deae0833c2 --- /dev/null +++ b/packages/browser-core/src/browser/webSocketObservable.spec.ts @@ -0,0 +1,356 @@ +import { registerCleanupTask } from '../../test' +import type { Subscription } from '../tools/observable' +import type { WebSocketContext } from './webSocketObservable' +import { initWebSocketObservable, resetWebSocketObservable } from './webSocketObservable' + +// A minimal stand-in for the native `WebSocket` constructor. We do not connect to a real server in +// unit tests; instead we expose helpers to simulate the browser dispatching events on the instance. +class FakeWebSocket extends EventTarget { + static readonly CONNECTING = 0 + static readonly OPEN = 1 + static readonly CLOSING = 2 + static readonly CLOSED = 3 + + url: string + protocol = '' + bufferedAmount = 0 + readyState: number = FakeWebSocket.CONNECTING + onmessage: ((event: MessageEvent) => void) | null = null + onopen: ((event: Event) => void) | null = null + onclose: ((event: CloseEvent) => void) | null = null + + constructor(url: string | URL, protocols?: string | string[]) { + super() + this.url = String(url) + if (typeof protocols === 'string') { + this.protocol = protocols + } + } + + send(_data: string | ArrayBufferLike | Blob | ArrayBufferView): void { + // no-op; tests will set `bufferedAmount` before calling send to verify it is sampled. + } + + close(_code?: number, _reason?: string): void { + this.readyState = FakeWebSocket.CLOSED + } + + simulateOpen() { + this.readyState = FakeWebSocket.OPEN + const event = new Event('open') + this.dispatchEvent(event) + this.onopen?.(event) + } + + simulateMessage(data: unknown) { + const event = new MessageEvent('message', { data }) + this.dispatchEvent(event) + this.onmessage?.(event) + } + + simulateClose(code: number, reason: string, wasClean: boolean) { + this.readyState = FakeWebSocket.CLOSED + // CloseEvent is not always constructable in test environments; use a plain Event with assigned fields. + const event = Object.assign(new Event('close'), { code, reason, wasClean }) as CloseEvent + this.dispatchEvent(event) + this.onclose?.(event) + } +} + +type FakeWebSocketConstructor = typeof FakeWebSocket + +const windowAsWebSocketHost = window as unknown as { WebSocket: FakeWebSocketConstructor } + +describe('webSocketObservable', () => { + let originalWebSocket: FakeWebSocketConstructor + let contexts: WebSocketContext[] + let subscription: Subscription | undefined + + beforeEach(() => { + originalWebSocket = windowAsWebSocketHost.WebSocket + windowAsWebSocketHost.WebSocket = FakeWebSocket + contexts = [] + + registerCleanupTask(() => { + subscription?.unsubscribe() + subscription = undefined + resetWebSocketObservable() + windowAsWebSocketHost.WebSocket = originalWebSocket + }) + }) + + function startTracking() { + subscription = initWebSocketObservable({ allowUntrustedEvents: true }).subscribe((context) => { + contexts.push(context) + }) + } + + function getContexts(state: T) { + return contexts.filter((context): context is Extract => context.state === state) + } + + describe('when tracking is started', () => { + beforeEach(() => { + startTracking() + }) + + describe('connecting context', () => { + it('emits a "connecting" context when a WebSocket is constructed', () => { + const url = 'wss://example.com/socket' + const ws = new windowAsWebSocketHost.WebSocket(url) + + const connectingContexts = getContexts('connecting') + expect(connectingContexts.length).toBe(1) + expect(connectingContexts[0].url).toBe(url) + expect(connectingContexts[0].instance).toBe(ws as unknown as WebSocket) + expect(connectingContexts[0].startClocks.timeStamp).toEqual(jasmine.any(Number)) + }) + + it('coerces URL objects to strings in the "connecting" context', () => { + const url = 'wss://example.com/socket' + new windowAsWebSocketHost.WebSocket(new URL(url)) + + expect(getContexts('connecting')[0].url).toBe(url) + }) + + it('does not include protocols in the "connecting" context when omitted', () => { + new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + + expect(getContexts('connecting')[0].protocols).toBeUndefined() + }) + + it('includes string protocols in the "connecting" context', () => { + const url = 'wss://example.com/socket' + const protocols = 'chat.v1' + new windowAsWebSocketHost.WebSocket(url, protocols) + + expect(getContexts('connecting')[0].protocols).toBe(protocols) + }) + + it('includes array protocols in the "connecting" context', () => { + const url = 'wss://example.com/socket' + const protocols = ['chat.v1', 'json'] + new windowAsWebSocketHost.WebSocket(url, protocols) + + expect(getContexts('connecting')[0].protocols).toEqual(protocols) + }) + }) + + describe('preservation of native behavior', () => { + it('does not clobber a customer-set onmessage handler', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + const customerHandler = jasmine.createSpy() + ws.onmessage = customerHandler + + ws.simulateMessage('hello') + + expect(customerHandler).toHaveBeenCalledTimes(1) + expect(getContexts('message-in').length).toBe(1) + }) + + it('does not clobber a customer-set onopen handler', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + const customerHandler = jasmine.createSpy() + ws.onopen = customerHandler + + ws.simulateOpen() + + expect(customerHandler).toHaveBeenCalledTimes(1) + expect(getContexts('open').length).toBe(1) + }) + + it('does not clobber a customer-set onclose handler', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + const customerHandler = jasmine.createSpy() + ws.onclose = customerHandler + + ws.simulateClose(1000, 'bye', true) + + expect(customerHandler).toHaveBeenCalledTimes(1) + expect(getContexts('closed').length).toBe(1) + }) + }) + + describe('open context', () => { + it('emits an "open" context when the WebSocket opens', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + const negotiatedProtocol = 'chat.v1' + ws.protocol = negotiatedProtocol + ws.simulateOpen() + + const openContexts = getContexts('open') + expect(openContexts.length).toBe(1) + expect(openContexts[0].protocol).toBe(negotiatedProtocol) + expect(openContexts[0].instance).toBe(ws as unknown as WebSocket) + expect(openContexts[0].openClocks.timeStamp).toEqual(jasmine.any(Number)) + }) + + it('emits an "open" context with empty protocol when no sub-protocol negotiated', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + ws.simulateOpen() + + const openContexts = getContexts('open') + expect(openContexts.length).toBe(1) + expect(openContexts[0].protocol).toBe('') + }) + }) + + describe('message-in context', () => { + it('emits "message-in" with byte-length size for string payloads', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + ws.simulateOpen() + const payload = 'hello world' + ws.simulateMessage(payload) + + const messageInContexts = getContexts('message-in') + expect(messageInContexts.length).toBe(1) + expect(messageInContexts[0].size).toBe(payload.length) + }) + + it('emits "message-in" with UTF-8 byte length for multi-byte strings', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + ws.simulateOpen() + // 'é' is 2 bytes in UTF-8 and 'あ' is 3 bytes; total is 5 bytes for 2 chars + const payload = 'éあ' + ws.simulateMessage(payload) + + expect(getContexts('message-in')[0].size).toBe(new TextEncoder().encode(payload).byteLength) + }) + + it('emits "message-in" with byteLength for ArrayBuffer payloads', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + ws.simulateOpen() + const byteLength = 16 + ws.simulateMessage(new ArrayBuffer(byteLength)) + + expect(getContexts('message-in')[0].size).toBe(byteLength) + }) + + it('emits "message-in" with byteLength for ArrayBufferView payloads', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + ws.simulateOpen() + const viewByteLength = 12 + ws.simulateMessage(new Uint8Array(new ArrayBuffer(32), 4, viewByteLength)) + + expect(getContexts('message-in')[0].size).toBe(viewByteLength) + }) + + it('emits "message-in" with size for Blob payloads', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + ws.simulateOpen() + const blob = new Blob(['hello']) + ws.simulateMessage(blob) + + expect(getContexts('message-in')[0].size).toBe(blob.size) + }) + }) + + describe('message-out context', () => { + it('emits "message-out" with size and bufferedAmountPreSend for string payloads', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + const bufferedAmountPreSend = 42 + ws.bufferedAmount = bufferedAmountPreSend + const payload = 'hello' + ws.send(payload) + + const messageOutContexts = getContexts('message-out') + expect(messageOutContexts.length).toBe(1) + expect(messageOutContexts[0].size).toBe(payload.length) + expect(messageOutContexts[0].bufferedAmountPreSend).toBe(bufferedAmountPreSend) + expect(messageOutContexts[0].at.timeStamp).toEqual(jasmine.any(Number)) + }) + + it('emits "message-out" with byteLength for ArrayBuffer payloads', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + const byteLength = 8 + ws.send(new ArrayBuffer(byteLength)) + + expect(getContexts('message-out')[0].size).toBe(byteLength) + }) + + it('emits "message-out" with byteLength for ArrayBufferView payloads', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + const viewByteLength = 10 + ws.send(new Uint8Array(new ArrayBuffer(20), 2, viewByteLength)) + + expect(getContexts('message-out')[0].size).toBe(viewByteLength) + }) + + it('emits "message-out" with size for Blob payloads', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + const blob = new Blob(['hello world']) + ws.send(blob) + + expect(getContexts('message-out')[0].size).toBe(blob.size) + }) + }) + + describe('closed context', () => { + it('emits a "closed" context with code, reason, and wasClean', () => { + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + const closeCode = 1000 + const closeReason = 'bye' + const wasClean = true + ws.simulateClose(closeCode, closeReason, wasClean) + + const closeContexts = getContexts('closed') + expect(closeContexts.length).toBe(1) + expect(closeContexts[0].code).toBe(closeCode) + expect(closeContexts[0].reason).toBe(closeReason) + expect(closeContexts[0].wasClean).toBe(wasClean) + expect(closeContexts[0].at.timeStamp).toEqual(jasmine.any(Number)) + }) + }) + + describe('subscription lifecycle', () => { + it('restores the native WebSocket constructor when all subscribers unsubscribe', () => { + subscription?.unsubscribe() + subscription = undefined + + expect(windowAsWebSocketHost.WebSocket).toBe(FakeWebSocket) + }) + + it('does not emit any further events after all subscribers unsubscribe', () => { + subscription?.unsubscribe() + subscription = undefined + + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + ws.simulateOpen() + ws.send('hello') + + expect(contexts.length).toBe(0) + }) + }) + }) + + describe('with conflicting allowUntrustedEvents policies across callers', () => { + it('does not emit open or message-in for untrusted events when the customer disallows them', () => { + initWebSocketObservable({ allowUntrustedEvents: true }) + subscription = initWebSocketObservable({ allowUntrustedEvents: false }).subscribe((context) => { + contexts.push(context) + }) + + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + ws.simulateOpen() + ws.simulateMessage('hello') + + expect(getContexts('connecting').length).toBe(1) + expect(getContexts('open').length).toBe(0) + expect(getContexts('message-in').length).toBe(0) + }) + + it('emits open and message-in for untrusted events when every caller allows them', () => { + initWebSocketObservable({ allowUntrustedEvents: true }) + subscription = initWebSocketObservable({ allowUntrustedEvents: true }).subscribe((context) => { + contexts.push(context) + }) + + const ws = new windowAsWebSocketHost.WebSocket('wss://example.com/socket') + ws.simulateOpen() + ws.simulateMessage('hello') + + expect(getContexts('open').length).toBe(1) + expect(getContexts('message-in').length).toBe(1) + }) + }) +}) diff --git a/packages/browser-core/src/browser/webSocketObservable.ts b/packages/browser-core/src/browser/webSocketObservable.ts new file mode 100644 index 0000000000..e42db0b1fe --- /dev/null +++ b/packages/browser-core/src/browser/webSocketObservable.ts @@ -0,0 +1,200 @@ +import type { ClocksState } from '@datadog/js-core/time' +import { clocksNow } from '@datadog/js-core/time' +import type { GlobalObject } from '../tools/globalObject' +import { globalObject } from '../tools/globalObject' +import { instrumentConstructor, instrumentMethod } from '../tools/instrumentMethod' +import { Observable } from '../tools/observable' +import { addEventListener } from './addEventListener' + +interface WebSocketObservableConfiguration { + allowUntrustedEvents?: boolean | undefined +} + +type GlobalWithWebSocket = GlobalObject & { WebSocket: typeof WebSocket } + +function isGlobalWithWebSocket(global: GlobalObject): global is GlobalWithWebSocket { + return typeof (global as { WebSocket?: unknown }).WebSocket === 'function' +} + +export interface WebSocketConnectingContext { + state: 'connecting' + instance: WebSocket + url: string + protocols?: string | string[] + startClocks: ClocksState +} + +export interface WebSocketOpenContext { + state: 'open' + instance: WebSocket + openClocks: ClocksState + protocol: string +} + +export interface WebSocketMessageInContext { + state: 'message-in' + instance: WebSocket + size: number + at: ClocksState +} + +export interface WebSocketMessageOutContext { + state: 'message-out' + instance: WebSocket + size: number + bufferedAmountPreSend: number + at: ClocksState +} + +export interface WebSocketClosedContext { + state: 'closed' + instance: WebSocket + code: number + reason: string + wasClean: boolean + at: ClocksState +} + +export type WebSocketContext = + | WebSocketConnectingContext + | WebSocketOpenContext + | WebSocketMessageInContext + | WebSocketMessageOutContext + | WebSocketClosedContext + +let webSocketObservable: Observable | undefined + +// The singleton WebSocket observable applies the latest caller's allowUntrustedEvents policy so +// that the customer's configuration overrides an early call (e.g. from bufferedData) that opts +// in before the customer config is parsed. +let allowUntrustedEvents: boolean | undefined + +export function initWebSocketObservable( + configuration: WebSocketObservableConfiguration = {} +): Observable { + if (configuration.allowUntrustedEvents !== undefined) { + allowUntrustedEvents = configuration.allowUntrustedEvents + } + + if (!webSocketObservable) { + webSocketObservable = createWebSocketObservable() + } + + return webSocketObservable +} + +function createWebSocketObservable() { + return new Observable((observable) => { + if (!isGlobalWithWebSocket(globalObject)) { + return undefined + } + + const stopListeners: Array<() => void> = [] + + const { stop: stopInstrumentingConstructor } = instrumentConstructor( + globalObject, + 'WebSocket', + ({ parameters, onPostCall }) => { + const url = String(parameters[0]) + const protocols = parameters[1] + const startClocks = clocksNow() + onPostCall((instance) => { + observable.notify({ + state: 'connecting', + instance, + url, + ...(protocols !== undefined ? { protocols } : {}), + startClocks, + }) + attachInstanceListeners(instance, observable, stopListeners) + }) + } + ) + + const { stop: stopInstrumentingSend } = instrumentMethod( + globalObject.WebSocket.prototype, + 'send', + ({ target: instance, parameters: [data], onPostCall }) => { + const size = computePayloadSize(data) + const bufferedAmountPreSend = instance.bufferedAmount + onPostCall(() => { + observable.notify({ + state: 'message-out', + instance, + size, + bufferedAmountPreSend, + at: clocksNow(), + }) + }) + } + ) + + return () => { + stopInstrumentingConstructor() + stopInstrumentingSend() + stopListeners.forEach((stop) => stop()) + stopListeners.length = 0 + } + }) +} + +function attachInstanceListeners( + instance: WebSocket, + observable: Observable, + stopListeners: Array<() => void> +) { + const { stop: stopOpen } = addEventListener({ allowUntrustedEvents }, instance, 'open', () => { + observable.notify({ + state: 'open', + instance, + openClocks: clocksNow(), + protocol: instance.protocol || '', + }) + }) + const { stop: stopMessage } = addEventListener({ allowUntrustedEvents }, instance, 'message', (event) => { + observable.notify({ + state: 'message-in', + instance, + size: computePayloadSize(event.data), + at: clocksNow(), + }) + }) + const { stop: stopClose } = addEventListener({ allowUntrustedEvents }, instance, 'close', (event) => { + observable.notify({ + state: 'closed', + instance, + code: event.code, + reason: event.reason, + wasClean: event.wasClean, + at: clocksNow(), + }) + }) + + stopListeners.push(stopOpen, stopMessage, stopClose) +} + +function computePayloadSize(data: unknown): number { + if (typeof data === 'string') { + return new TextEncoder().encode(data).byteLength + } + if (data instanceof ArrayBuffer) { + return data.byteLength + } + if (ArrayBuffer.isView(data)) { + return data.byteLength + } + if (typeof Blob !== 'undefined' && data instanceof Blob) { + return data.size + } + return 0 +} + +/** + * Reset the WebSocket observable global state. Test-only. + * + * @internal + */ +export function resetWebSocketObservable() { + webSocketObservable = undefined + allowUntrustedEvents = undefined +} diff --git a/packages/browser-core/src/index.ts b/packages/browser-core/src/index.ts index 7c2a230eee..9e556685a8 100644 --- a/packages/browser-core/src/index.ts +++ b/packages/browser-core/src/index.ts @@ -126,6 +126,15 @@ export type { XhrCompleteContext, XhrStartContext, XhrContext } from './browser/ export { initXhrObservable } from './browser/xhrObservable' export type { FetchResolveContext, FetchStartContext, FetchContext } from './browser/fetchObservable' export { initFetchObservable, ResponseBodyAction } from './browser/fetchObservable' +export type { + WebSocketContext, + WebSocketConnectingContext, + WebSocketOpenContext, + WebSocketMessageInContext, + WebSocketMessageOutContext, + WebSocketClosedContext as WebSocketCloseContext, +} from './browser/webSocketObservable' +export { initWebSocketObservable } from './browser/webSocketObservable' export { fetch } from './browser/fetch' export type { PageMayExitEvent } from './browser/pageMayExitObservable' export { createPageMayExitObservable, PageExitReason, isPageExitReason } from './browser/pageMayExitObservable' diff --git a/packages/browser-core/test/forEach.spec.ts b/packages/browser-core/test/forEach.spec.ts index f95041eb14..1c393d7960 100644 --- a/packages/browser-core/test/forEach.spec.ts +++ b/packages/browser-core/test/forEach.spec.ts @@ -4,6 +4,7 @@ import { resetValueHistoryGlobals } from '../src/tools/valueHistory' import { resetFetchObservable } from '../src/browser/fetchObservable' import { resetConsoleObservable } from '../src/domain/console/consoleObservable' import { resetXhrObservable } from '../src/browser/xhrObservable' +import { resetWebSocketObservable } from '../src/browser/webSocketObservable' import { resetGetCurrentSite } from '../src/browser/cookie' import { resetReplayStats } from '../../browser-rum/src/domain/replayStats' import { resetInteractionCountPolyfill } from '../../browser-rum-core/src/domain/view/viewMetrics/interactionCountPolyfill' @@ -34,6 +35,7 @@ afterEach(() => { resetFetchObservable() resetConsoleObservable() resetXhrObservable() + resetWebSocketObservable() resetGetCurrentSite() resetReplayStats() resetMonitor()