From f15c86321e45a25cc4423434f0d9e64f0023b06a Mon Sep 17 00:00:00 2001 From: Evgeny Shurakov Date: Tue, 30 Jun 2026 14:50:48 +0200 Subject: [PATCH 1/4] CloudAgent - Support CLI-configured models in remote sessions --- .../src/app/(app)/agent-chat/model-picker.tsx | 303 +------ .../agents/mobile-session-manager.test.ts | 127 +++ .../agents/mobile-session-manager.ts | 14 +- .../agents/model-picker-content.tsx | 201 +++++ .../src/components/agents/model-selector.tsx | 276 +++++- .../agents/session-detail-content.tsx | 162 +++- .../agents/use-session-config-sync.test.ts | 46 + .../agents/use-session-config-sync.ts | 100 ++- .../lib/hooks/use-session-model-options.ts | 393 ++++++++ apps/mobile/src/lib/model-picker-rows.test.ts | 69 +- apps/mobile/src/lib/model-picker-rows.ts | 64 +- apps/mobile/src/lib/picker-bridge.test.ts | 170 ++++ apps/mobile/src/lib/picker-bridge.ts | 66 +- .../src/lib/use-session-model-options.test.ts | 290 ++++++ .../cloud-agent-next/CloudAgentProvider.tsx | 21 +- .../cloud-agent-next/CloudChatPage.tsx | 160 +++- .../cloud-agent-next/FeedbackDialog.test.ts | 13 + .../cloud-agent-next/FeedbackDialog.tsx | 12 +- .../hooks/useOrganizationModels.ts | 14 +- .../hooks/useSessionModels.test.ts | 369 ++++++++ .../hooks/useSessionModels.ts | 428 +++++++++ .../model-context-lengths.test.ts | 58 +- .../cloud-agent-next/model-context-lengths.ts | 76 +- .../src/components/shared/ModelCombobox.tsx | 278 +++--- .../shared/model-combobox-options.test.ts | 52 ++ .../shared/model-combobox-options.ts | 66 ++ .../cli-live-transport.test.ts | 841 +++++++++++++++++- .../lib/cloud-agent-sdk/cli-live-transport.ts | 259 +++++- .../cloud-agent-transport.test.ts | 41 +- .../cloud-agent-sdk/cloud-agent-transport.ts | 28 +- apps/web/src/lib/cloud-agent-sdk/index.ts | 26 +- .../lib/cloud-agent-sdk/normalizer.test.ts | 57 ++ .../web/src/lib/cloud-agent-sdk/normalizer.ts | 26 +- .../remote-model-catalog.test.ts | 393 ++++++++ .../cloud-agent-sdk/remote-model-catalog.ts | 69 ++ apps/web/src/lib/cloud-agent-sdk/schemas.ts | 240 +++++ .../cloud-agent-sdk/session-manager.test.ts | 583 +++++++++++- .../lib/cloud-agent-sdk/session-manager.ts | 220 ++++- .../cloud-agent-sdk/session-transport.test.ts | 142 ++- .../src/lib/cloud-agent-sdk/session.test.ts | 71 +- apps/web/src/lib/cloud-agent-sdk/session.ts | 28 +- apps/web/src/lib/cloud-agent-sdk/transport.ts | 34 +- apps/web/src/lib/cloud-agent-sdk/types.ts | 7 +- .../user-web-connection.test.ts | 138 +++ .../cloud-agent-sdk/user-web-connection.ts | 44 +- .../web/src/lib/session-ingest-client.test.ts | 39 +- apps/web/src/lib/session-ingest-client.ts | 14 +- .../routers/cli-sessions-v2-router.test.ts | 61 ++ .../web/src/routers/cli-sessions-v2-router.ts | 8 +- .../src/dos/UserConnectionDO.test.ts | 518 ++++++++++- .../src/dos/UserConnectionDO.ts | 251 +++++- 51 files changed, 7151 insertions(+), 815 deletions(-) create mode 100644 apps/mobile/src/components/agents/model-picker-content.tsx create mode 100644 apps/mobile/src/components/agents/use-session-config-sync.test.ts create mode 100644 apps/mobile/src/lib/hooks/use-session-model-options.ts create mode 100644 apps/mobile/src/lib/picker-bridge.test.ts create mode 100644 apps/mobile/src/lib/use-session-model-options.test.ts create mode 100644 apps/web/src/components/cloud-agent-next/FeedbackDialog.test.ts create mode 100644 apps/web/src/components/cloud-agent-next/hooks/useSessionModels.test.ts create mode 100644 apps/web/src/components/cloud-agent-next/hooks/useSessionModels.ts create mode 100644 apps/web/src/components/shared/model-combobox-options.test.ts create mode 100644 apps/web/src/components/shared/model-combobox-options.ts create mode 100644 apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.test.ts create mode 100644 apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.ts diff --git a/apps/mobile/src/app/(app)/agent-chat/model-picker.tsx b/apps/mobile/src/app/(app)/agent-chat/model-picker.tsx index 3ac89e5ff1..392a6fd043 100644 --- a/apps/mobile/src/app/(app)/agent-chat/model-picker.tsx +++ b/apps/mobile/src/app/(app)/agent-chat/model-picker.tsx @@ -1,304 +1,5 @@ -import * as Haptics from 'expo-haptics'; -import { useFocusEffect, useRouter } from 'expo-router'; -import { BookOpenCheck, Check, Search } from 'lucide-react-native'; -import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; -import { FlatList, Pressable, ScrollView, TextInput, View } from 'react-native'; -import { useSafeAreaInsets } from 'react-native-safe-area-context'; - -import { Text } from '@/components/ui/text'; -import { - BYOK_MODEL_LABEL, - FREE_MODEL_DATA_LABEL, - FREE_MODEL_FREE_LABEL, - hasUserByokAvailable, - isFreeModelOption, - mayTrainOnYourPrompts, -} from '@/lib/free-model-data-disclosure'; -import { useThemeColors } from '@/lib/hooks/use-theme-colors'; -import { type ModelOption, thinkingEffortLabel } from '@/lib/hooks/use-available-models'; -import { buildModelPickerRows, type ModelPickerRow } from '@/lib/model-picker-rows'; -import { clearModelPickerBridge, getModelPickerBridge } from '@/lib/picker-bridge'; - -function getVariantForModel(model: ModelOption, currentVariant: string) { - if (currentVariant && model.variants.includes(currentVariant)) { - return currentVariant; - } - return model.variants[0] ?? ''; -} +import { ModelPickerContent } from '@/components/agents/model-picker-content'; export default function ModelPickerScreen() { - const router = useRouter(); - const colors = useThemeColors(); - const { bottom } = useSafeAreaInsets(); - const [search, setSearch] = useState(''); - const [bridge, setBridge] = useState(() => getModelPickerBridge()); - - const [selectedModel, setSelectedModel] = useState(bridge?.currentValue ?? ''); - const [selectedVariant, setSelectedVariant] = useState(bridge?.currentVariant ?? ''); - - const bridgeRef = useRef(bridge); - const selectedModelRef = useRef(selectedModel); - const selectedVariantRef = useRef(selectedVariant); - const closePickerTimerRef = useRef | null>(null); - - const closePicker = useCallback(() => { - router.back(); - }, [router]); - - useFocusEffect( - useCallback(() => { - const nextBridge = getModelPickerBridge(); - const nextModel = nextBridge?.currentValue ?? ''; - const nextVariant = nextBridge?.currentVariant ?? ''; - - bridgeRef.current = nextBridge; - selectedModelRef.current = nextModel; - selectedVariantRef.current = nextVariant; - setBridge(nextBridge); - setSelectedModel(nextModel); - setSelectedVariant(nextVariant); - setSearch(''); - - return () => { - if (closePickerTimerRef.current) { - clearTimeout(closePickerTimerRef.current); - closePickerTimerRef.current = null; - } - - const activeBridge = bridgeRef.current; - if (activeBridge) { - activeBridge.onSelect(selectedModelRef.current, selectedVariantRef.current); - clearModelPickerBridge(); - bridgeRef.current = null; - } - }; - }, []) - ); - - const currentModelOption = useMemo( - () => bridge?.options.find(m => m.id === selectedModel), - [bridge, selectedModel] - ); - - useEffect(() => { - if (!currentModelOption) { - return; - } - - const nextVariant = getVariantForModel(currentModelOption, selectedVariantRef.current); - if (nextVariant === selectedVariantRef.current) { - return; - } - - selectedVariantRef.current = nextVariant; - setSelectedVariant(nextVariant); - }, [currentModelOption]); - - const rows = useMemo( - () => buildModelPickerRows({ models: bridge?.options ?? [], search }), - [bridge, search] - ); - - const handleSelectVariant = useCallback( - (variant: string) => { - void Haptics.selectionAsync(); - selectedVariantRef.current = variant; - setSelectedVariant(variant); - - if (closePickerTimerRef.current) { - clearTimeout(closePickerTimerRef.current); - } - closePickerTimerRef.current = setTimeout(() => { - closePickerTimerRef.current = null; - closePicker(); - }, 175); - }, - [closePicker] - ); - - const handleSelectModel = useCallback( - (id: string) => { - void Haptics.selectionAsync(); - const model = bridge?.options.find(m => m.id === id); - if (!model) { - return; - } - - const nextVariant = getVariantForModel(model, selectedVariantRef.current); - selectedModelRef.current = id; - selectedVariantRef.current = nextVariant; - setSelectedModel(id); - setSelectedVariant(nextVariant); - - if (model.variants.length <= 1) { - closePicker(); - } - }, - [bridge, closePicker] - ); - - if (!bridge) { - return ( - - No models available - - ); - } - - return ( - item.key} - keyboardShouldPersistTaps="handled" - keyboardDismissMode="on-drag" - contentContainerStyle={{ paddingBottom: bottom }} - ListHeaderComponent={ - - - Select Model - - Done - - - - - - - - } - ListEmptyComponent={ - - - {search.trim() ? 'No models match your search' : 'No models available'} - - - } - renderItem={({ item }) => { - if (item.type === 'header') { - return ( - - - {item.title} - - - ); - } - - const modelOption = item.model; - const selected = modelOption.id === selectedModel; - const free = isFreeModelOption(modelOption); - const byok = hasUserByokAvailable(modelOption); - const collectsData = mayTrainOnYourPrompts(modelOption); - const hasVariants = modelOption.variants.length > 1; - const accessibilityLabel = [ - modelOption.name, - byok ? BYOK_MODEL_LABEL : undefined, - free && !byok ? FREE_MODEL_FREE_LABEL : undefined, - collectsData ? FREE_MODEL_DATA_LABEL : undefined, - selected ? 'selected' : undefined, - ] - .filter(Boolean) - .join(', '); - - return ( - - { - handleSelectModel(modelOption.id); - }} - accessibilityRole="button" - accessibilityLabel={accessibilityLabel} - > - - {modelOption.name} - {modelOption.id} - {free || byok || collectsData ? ( - - {free && !byok ? ( - - - {FREE_MODEL_FREE_LABEL} - - - ) : null} - {byok ? ( - - - {BYOK_MODEL_LABEL} - - - ) : null} - {collectsData ? ( - - ) : null} - - ) : null} - - {selected && } - - - {selected && hasVariants ? ( - - - Thinking effort - - - {modelOption.variants.map(variant => { - const isActive = variant === selectedVariant; - return ( - { - handleSelectVariant(variant); - }} - accessibilityRole="button" - accessibilityLabel={`${thinkingEffortLabel(variant)} thinking effort${isActive ? ', selected' : ''}`} - > - - {thinkingEffortLabel(variant)} - - - ); - })} - - - ) : null} - - ); - }} - /> - ); + return ; } diff --git a/apps/mobile/src/components/agents/mobile-session-manager.test.ts b/apps/mobile/src/components/agents/mobile-session-manager.test.ts index a7185c8d3b..f965db1444 100644 --- a/apps/mobile/src/components/agents/mobile-session-manager.test.ts +++ b/apps/mobile/src/components/agents/mobile-session-manager.test.ts @@ -6,6 +6,10 @@ const mocks = vi.hoisted(() => ({ createSessionManager: vi.fn(config => ({ config })), createNativeUserWebConnectionLifecycleHooks: vi.fn(() => ({ marker: 'native-lifecycle-hooks' })), getWithRuntimeStateQuery: vi.fn(), + getSessionQuery: vi.fn(), + getSessionMessagesQuery: vi.fn(), + sendMessageMutate: vi.fn(), + prepareSessionMutate: vi.fn(), })); function noCleanup(): void { @@ -64,8 +68,23 @@ vi.mock('@/lib/config', () => ({ vi.mock('@/lib/trpc', () => ({ trpcClient: { cliSessionsV2: { + get: { query: mocks.getSessionQuery }, + getSessionMessages: { query: mocks.getSessionMessagesQuery }, getWithRuntimeState: { query: mocks.getWithRuntimeStateQuery }, }, + activeSessions: { + list: { query: vi.fn() }, + }, + cloudAgentNext: { + sendMessage: { mutate: mocks.sendMessageMutate }, + prepareSession: { mutate: mocks.prepareSessionMutate }, + }, + organizations: { + cloudAgentNext: { + sendMessage: { mutate: mocks.sendMessageMutate }, + prepareSession: { mutate: mocks.prepareSessionMutate }, + }, + }, }, })); @@ -75,11 +94,22 @@ type CapturedSessionManagerConfig = { getAuthToken?: () => Promise; lifecycleHooks?: unknown; fetchSession: (kiloSessionId: string) => Promise<{ associatedPr: unknown }>; + fetchSnapshot: (kiloSessionId: string) => Promise<{ info: unknown; messages: unknown[] }>; + prepare: (input: { + prompt: string; + mode: string; + model: string; + initialPayload?: unknown; + }) => Promise; }; describe('createMobileAgentSessionManager', () => { beforeEach(() => { vi.clearAllMocks(); + mocks.prepareSessionMutate.mockResolvedValue({ + cloudAgentSessionId: 'agent_123', + kiloSessionId: 'ses_123', + }); }); it('injects the app-scoped user web connection without raw viewer transport options', async () => { @@ -111,6 +141,103 @@ describe('createMobileAgentSessionManager', () => { expect(config.lifecycleHooks).toEqual({ marker: 'native-lifecycle-hooks' }); }); + it('converts an initial Kilo model ref to the Cloud Agent prepare payload', async () => { + const { createMobileAgentSessionManager } = + await import('@/components/agents/mobile-session-manager'); + + createMobileAgentSessionManager({ + store: createStore(), + userWebConnection, + }); + + const config = mocks.createSessionManager.mock.calls[0]?.[0] as CapturedSessionManagerConfig; + await config.prepare({ + prompt: 'Initial prompt', + mode: 'code', + model: 'fallback-model', + initialPayload: { + type: 'prompt', + prompt: 'Initial prompt', + mode: 'code', + model: { providerID: 'kilo', modelID: 'anthropic/claude-sonnet-4' }, + variant: 'high', + }, + }); + + expect(mocks.prepareSessionMutate).toHaveBeenCalledWith( + { + prompt: 'Initial prompt', + mode: 'code', + model: 'fallback-model', + initialPayload: { + type: 'prompt', + prompt: 'Initial prompt', + mode: 'code', + model: 'anthropic/claude-sonnet-4', + variant: 'high', + }, + }, + { context: { skipBatch: true } } + ); + }); + + it('rejects a non-Kilo initial model ref before Cloud Agent prepare', async () => { + const { createMobileAgentSessionManager } = + await import('@/components/agents/mobile-session-manager'); + + createMobileAgentSessionManager({ + store: createStore(), + userWebConnection, + }); + + const config = mocks.createSessionManager.mock.calls[0]?.[0] as CapturedSessionManagerConfig; + await expect( + config.prepare({ + prompt: 'Initial prompt', + mode: 'code', + model: 'fallback-model', + initialPayload: { + type: 'prompt', + prompt: 'Initial prompt', + mode: 'code', + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + }, + }) + ).rejects.toThrow('Cloud Agent only supports Kilo models'); + expect(mocks.prepareSessionMutate).not.toHaveBeenCalled(); + }); + + it('preserves snapshot model metadata from the session export', async () => { + const { createMobileAgentSessionManager } = + await import('@/components/agents/mobile-session-manager'); + mocks.getSessionQuery.mockResolvedValue({ + session_id: 'ses_123', + parent_session_id: null, + }); + mocks.getSessionMessagesQuery.mockResolvedValue({ + info: { + id: 'ses_123', + model: { providerID: 'anthropic', id: 'claude-sonnet-4', variant: 'high' }, + }, + messages: [], + }); + + createMobileAgentSessionManager({ + store: createStore(), + userWebConnection, + }); + + const config = mocks.createSessionManager.mock.calls[0]?.[0] as CapturedSessionManagerConfig; + await expect(config.fetchSnapshot('ses_123')).resolves.toEqual({ + info: { + id: 'ses_123', + parentID: undefined, + model: { providerID: 'anthropic', id: 'claude-sonnet-4', variant: 'high' }, + }, + messages: [], + }); + }); + it('propagates associatedPr from fetched session data', async () => { const { createMobileAgentSessionManager } = await import('@/components/agents/mobile-session-manager'); diff --git a/apps/mobile/src/components/agents/mobile-session-manager.ts b/apps/mobile/src/components/agents/mobile-session-manager.ts index 26fc3d0230..11a284aff2 100644 --- a/apps/mobile/src/components/agents/mobile-session-manager.ts +++ b/apps/mobile/src/components/agents/mobile-session-manager.ts @@ -38,12 +38,15 @@ function normalizeTransportPayload(payload: TransportSendPayload): SendMessagePa if (!payload.model) { throw new Error('Model is required'); } + if (payload.model.providerID !== 'kilo') { + throw new Error('Cloud Agent only supports Kilo models'); + } return { type: 'prompt', prompt: payload.prompt, mode: normalizeAgentMode(payload.mode), - model: payload.model, + model: payload.model.modelID, variant: payload.variant, }; } @@ -125,10 +128,12 @@ export function createMobileAgentSessionManager({ trpcClient.cliSessionsV2.get.query({ session_id: id }), trpcClient.cliSessionsV2.getSessionMessages.query({ session_id: id }), ]); + const snapshotInfo = messagesResult.info as Partial; return { info: { - id: sessionData.session_id, - parentID: sessionData.parent_session_id ?? undefined, + id: snapshotInfo.id ?? sessionData.session_id, + parentID: snapshotInfo.parentID ?? sessionData.parent_session_id ?? undefined, + ...(snapshotInfo.model ? { model: snapshotInfo.model } : {}), }, messages: messagesResult.messages as SessionSnapshot['messages'], }; @@ -136,10 +141,9 @@ export function createMobileAgentSessionManager({ api: { send: async input => { await withCloudAgentDiagnostics('send', organizationId, async () => { - const payload = normalizeTransportPayload(input.payload); const baseInput = { cloudAgentSessionId: input.sessionId as string, - payload, + payload: input.payload, autoCommit: true, messageId: input.messageId, }; diff --git a/apps/mobile/src/components/agents/model-picker-content.tsx b/apps/mobile/src/components/agents/model-picker-content.tsx new file mode 100644 index 0000000000..e24849538e --- /dev/null +++ b/apps/mobile/src/components/agents/model-picker-content.tsx @@ -0,0 +1,201 @@ +import * as Haptics from 'expo-haptics'; +import { useFocusEffect, useRouter } from 'expo-router'; +import { Search } from 'lucide-react-native'; +import { useCallback, useMemo, useRef, useState } from 'react'; +import { FlatList, Pressable, TextInput, View } from 'react-native'; + +import { ModelPickerOptionRow } from '@/components/agents/model-selector'; +import { ScreenHeader } from '@/components/screen-header'; +import { Text } from '@/components/ui/text'; +import { type SessionModelOption } from '@/lib/hooks/use-session-model-options'; +import { useThemeColors } from '@/lib/hooks/use-theme-colors'; +import { buildModelPickerRows, type ModelPickerRow } from '@/lib/model-picker-rows'; +import { + clearModelPickerBridge, + commitModelPickerSelection, + getModelPickerBridge, + resolveModelPickerSelection, +} from '@/lib/picker-bridge'; + +export function ModelPickerContent() { + const router = useRouter(); + const colors = useThemeColors(); + const [search, setSearch] = useState(''); + const [bridge, setBridge] = useState(() => getModelPickerBridge()); + const [selectedModel, setSelectedModel] = useState(bridge?.currentValue ?? ''); + const [selectedVariant, setSelectedVariant] = useState(bridge?.currentVariant ?? ''); + const bridgeRef = useRef(bridge); + const selectedModelRef = useRef(selectedModel); + const selectedVariantRef = useRef(selectedVariant); + const selectionChangedRef = useRef(false); + const closePickerTimerRef = useRef | null>(null); + + const closePicker = useCallback(() => { + router.back(); + }, [router]); + + useFocusEffect( + useCallback(() => { + const nextBridge = getModelPickerBridge(); + const nextModel = nextBridge?.currentValue ?? ''; + const nextVariant = nextBridge?.currentVariant ?? ''; + + bridgeRef.current = nextBridge; + selectedModelRef.current = nextModel; + selectedVariantRef.current = nextVariant; + selectionChangedRef.current = false; + setBridge(nextBridge); + setSelectedModel(nextModel); + setSelectedVariant(nextVariant); + setSearch(''); + + return () => { + if (closePickerTimerRef.current) { + clearTimeout(closePickerTimerRef.current); + closePickerTimerRef.current = null; + } + + const activeBridge = bridgeRef.current; + if (activeBridge && selectionChangedRef.current) { + commitModelPickerSelection( + activeBridge, + selectedModelRef.current, + selectedVariantRef.current + ); + } + clearModelPickerBridge(); + bridgeRef.current = null; + }; + }, []) + ); + + const rows = useMemo( + () => buildModelPickerRows({ models: bridge?.options ?? [], search }), + [bridge, search] + ); + + const handleSelectVariant = useCallback( + (variant: string) => { + void Haptics.selectionAsync(); + selectionChangedRef.current = true; + selectedVariantRef.current = variant; + setSelectedVariant(variant); + + if (closePickerTimerRef.current) { + clearTimeout(closePickerTimerRef.current); + } + closePickerTimerRef.current = setTimeout(() => { + closePickerTimerRef.current = null; + closePicker(); + }, 175); + }, + [closePicker] + ); + + const handleSelectModel = useCallback( + (option: SessionModelOption) => { + if (option.unavailable || !bridge) { + return; + } + void Haptics.selectionAsync(); + const selection = resolveModelPickerSelection(bridge, option.id, selectedVariantRef.current); + if (!selection) { + return; + } + + selectionChangedRef.current = true; + selectedModelRef.current = option.id; + selectedVariantRef.current = selection.variant; + setSelectedModel(option.id); + setSelectedVariant(selection.variant); + if (option.variants.length <= 1) { + closePicker(); + } + }, + [bridge, closePicker] + ); + + if (!bridge) { + return ( + + + + No models available + + + ); + } + + return ( + + + Done + + } + /> + item.key} + keyboardShouldPersistTaps="handled" + keyboardDismissMode="on-drag" + contentInsetAdjustmentBehavior="automatic" + ListHeaderComponent={ + + + + + + + } + ListEmptyComponent={ + + + {search.trim() ? 'No models match your search' : 'No models available'} + + + } + renderItem={({ item }) => { + if (item.type === 'header') { + return ( + + + {item.title} + + + ); + } + + return ( + + ); + }} + /> + + ); +} diff --git a/apps/mobile/src/components/agents/model-selector.tsx b/apps/mobile/src/components/agents/model-selector.tsx index 377b85faec..4cb1992184 100644 --- a/apps/mobile/src/components/agents/model-selector.tsx +++ b/apps/mobile/src/components/agents/model-selector.tsx @@ -1,27 +1,99 @@ -import { Pressable, View } from 'react-native'; +/* eslint-disable max-lines -- The selector and picker row share model disclosure behavior. */ import { type Href, useRouter } from 'expo-router'; -import { BookOpenCheck, Brain, ChevronDown } from 'lucide-react-native'; +import { + AlertTriangle, + BookOpenCheck, + Brain, + Check, + ChevronDown, + Info, + RefreshCw, +} from 'lucide-react-native'; +import { createContext, type ReactNode, useContext, useMemo } from 'react'; +import { Pressable, ScrollView, View } from 'react-native'; import { Text } from '@/components/ui/text'; import { BYOK_MODEL_LABEL, + FREE_MODEL_DATA_LABEL, + FREE_MODEL_FREE_LABEL, getFreeModelDataAccessibilityLabel, hasUserByokAvailable, + isFreeModelOption, mayTrainOnYourPrompts, } from '@/lib/free-model-data-disclosure'; import { type ModelOption, thinkingEffortLabel } from '@/lib/hooks/use-available-models'; +import { + type SessionModelNotice, + type SessionModelOption, +} from '@/lib/hooks/use-session-model-options'; import { useThemeColors } from '@/lib/hooks/use-theme-colors'; -import { setModelPickerBridge } from '@/lib/picker-bridge'; +import { + type ModelPickerSelection, + type ModelPickerSelectionScope, + setModelPickerBridge, +} from '@/lib/picker-bridge'; import { cn } from '@/lib/utils'; type ModelSelectorProps = { value: string; variant: string; - options: ModelOption[]; - onSelect: (modelId: string, variant: string) => void; + options: (ModelOption | SessionModelOption)[]; + onSelect: (modelId: string, variant: string, pickerSelection?: ModelPickerSelection) => void; disabled?: boolean; }; +type ModelPickerSelectionScopeContextValue = { + selectionScope: ModelPickerSelectionScope; + isSelectionCurrent: (scope: ModelPickerSelectionScope) => boolean; +}; + +const UNFENCED_SELECTION_CONTEXT: ModelPickerSelectionScopeContextValue = { + selectionScope: { + sessionId: 'unscoped', + ownerConnectionId: null, + protocol: 'unknown', + catalogGenerationIdentity: null, + }, + isSelectionCurrent: () => true, +}; + +const ModelPickerSelectionScopeContext = createContext(UNFENCED_SELECTION_CONTEXT); + +export function ModelPickerSelectionScopeProvider({ + children, + selectionScope, + isSelectionCurrent, +}: Readonly) { + const contextValue = useMemo( + () => ({ selectionScope, isSelectionCurrent }), + [isSelectionCurrent, selectionScope] + ); + + return ( + + {children} + + ); +} + +function toSessionModelOption(option: ModelOption | SessionModelOption): SessionModelOption { + if ( + 'displayId' in option && + typeof option.displayId === 'string' && + 'showGatewayMetadata' in option && + typeof option.showGatewayMetadata === 'boolean' + ) { + return { + ...option, + displayId: option.displayId, + showGatewayMetadata: option.showGatewayMetadata, + }; + } + + return { ...option, displayId: option.id, showGatewayMetadata: true }; +} + function compactThinkingEffortLabel(variant: string) { if (variant === 'xhigh') { return 'XH'; @@ -41,12 +113,17 @@ export function ModelSelector({ }: Readonly) { const router = useRouter(); const colors = useThemeColors(); - const effectivelyDisabled = disabled || options.length === 0; - - const selectedModel = options.find(m => m.id === value); - const label = selectedModel?.name ?? (value || 'Model'); - const byok = hasUserByokAvailable(selectedModel); - const collectsData = mayTrainOnYourPrompts(selectedModel); + const selectionContext = useContext(ModelPickerSelectionScopeContext); + const pickerOptions = options.map(option => toSessionModelOption(option)); + const effectivelyDisabled = disabled || pickerOptions.every(option => option.unavailable); + const selectedModel = pickerOptions.find(option => option.id === value); + const providerAware = pickerOptions.some( + option => option.modelRef !== undefined || !option.showGatewayMetadata + ); + const showGatewayMetadata = selectedModel?.showGatewayMetadata ?? false; + const label = selectedModel?.name ?? (!providerAware && value ? value : 'Model'); + const byok = showGatewayMetadata && hasUserByokAvailable(selectedModel); + const collectsData = showGatewayMetadata && mayTrainOnYourPrompts(selectedModel); const hasVariants = selectedModel ? selectedModel.variants.length > 1 : false; const variantLabel = variant ? thinkingEffortLabel(variant) : ''; const compactVariantLabel = variant ? compactThinkingEffortLabel(variant) : ''; @@ -60,10 +137,14 @@ export function ModelSelector({ return; } setModelPickerBridge({ - options, + options: pickerOptions, currentValue: value, currentVariant: variant, - onSelect, + selectionScope: selectionContext.selectionScope, + isSelectionCurrent: selectionContext.isSelectionCurrent, + onSelect: selection => { + onSelect(selection.option.id, selection.variant, selection); + }, }); router.push('/(app)/agent-chat/model-picker' as Href); } @@ -107,3 +188,172 @@ export function ModelSelector({ ); } + +export function ModelPickerOptionRow({ + option, + selected, + selectedVariant, + onSelectModel, + onSelectVariant, +}: Readonly<{ + option: SessionModelOption; + selected: boolean; + selectedVariant: string; + onSelectModel: (option: SessionModelOption) => void; + onSelectVariant: (variant: string) => void; +}>) { + const colors = useThemeColors(); + const free = option.showGatewayMetadata && isFreeModelOption(option); + const byok = option.showGatewayMetadata && hasUserByokAvailable(option); + const collectsData = option.showGatewayMetadata && mayTrainOnYourPrompts(option); + const accessibilityLabel = [ + option.provider?.name, + option.name, + option.displayId, + byok ? BYOK_MODEL_LABEL : undefined, + free && !byok ? FREE_MODEL_FREE_LABEL : undefined, + collectsData ? FREE_MODEL_DATA_LABEL : undefined, + option.unavailable ? 'unavailable' : undefined, + selected ? 'selected' : undefined, + ] + .filter(Boolean) + .join(', '); + + return ( + + { + onSelectModel(option); + }} + disabled={option.unavailable} + accessibilityRole="button" + accessibilityLabel={accessibilityLabel} + > + + {option.name} + {option.modelRef ? ( + + Provider {option.modelRef.providerID} + + ) : null} + {option.displayId ? ( + + {option.modelRef ? `Model ${option.displayId}` : option.displayId} + + ) : null} + {option.unavailable ? ( + Unavailable + ) : null} + {free || byok || collectsData ? ( + + {free && !byok ? ( + + + {FREE_MODEL_FREE_LABEL} + + + ) : null} + {byok ? ( + + + {BYOK_MODEL_LABEL} + + + ) : null} + {collectsData ? ( + + ) : null} + + ) : null} + + {selected ? : null} + + {selected && option.variants.length > 1 ? ( + + + Thinking effort + + + {option.variants.map(thinkingVariant => { + const active = thinkingVariant === selectedVariant; + return ( + { + onSelectVariant(thinkingVariant); + }} + accessibilityRole="button" + accessibilityLabel={`${thinkingEffortLabel(thinkingVariant)} thinking effort${active ? ', selected' : ''}`} + > + + {thinkingEffortLabel(thinkingVariant)} + + + ); + })} + + + ) : null} + + ); +} + +export function SessionModelNotices({ + notices, + onRetry, +}: Readonly<{ notices: SessionModelNotice[]; onRetry: () => void }>) { + const colors = useThemeColors(); + if (notices.length === 0) { + return null; + } + + return ( + + {notices.map(notice => { + const Icon = + notice.id === 'legacy' || notice.id === 'local-provider' ? Info : AlertTriangle; + return ( + + + + {notice.message} + + {notice.retry ? ( + + + Retry + + ) : null} + + ); + })} + + ); +} diff --git a/apps/mobile/src/components/agents/session-detail-content.tsx b/apps/mobile/src/components/agents/session-detail-content.tsx index 57d49e7066..d0365f704f 100644 --- a/apps/mobile/src/components/agents/session-detail-content.tsx +++ b/apps/mobile/src/components/agents/session-detail-content.tsx @@ -1,6 +1,7 @@ +/* eslint-disable max-lines -- Session orchestration and its render paths are kept together. */ import { type CloudStatus, type KiloSessionId, type StoredMessage } from 'cloud-agent-sdk'; import { useAtomValue } from 'jotai'; -import { useCallback, useEffect, useMemo } from 'react'; +import { useCallback, useEffect, useMemo, useRef } from 'react'; import { ActivityIndicator, FlatList, KeyboardAvoidingView, Platform, View } from 'react-native'; import { useSafeAreaInsets } from 'react-native-safe-area-context'; import { toast } from 'sonner-native'; @@ -8,6 +9,10 @@ import { toast } from 'sonner-native'; import { ChatComposer } from '@/components/agents/chat-composer'; import { ConnectivityBanner } from '@/components/agents/connectivity-banner'; import { MessageBubble } from '@/components/agents/message-bubble'; +import { + ModelPickerSelectionScopeProvider, + SessionModelNotices, +} from '@/components/agents/model-selector'; import { PermissionCard } from '@/components/agents/permission-card'; import { QuestionCard } from '@/components/agents/question-card'; import { getSessionKeyboardContainerKind } from '@/components/agents/session-keyboard-container-state'; @@ -26,6 +31,16 @@ import { ScreenHeader } from '@/components/screen-header'; import { Text } from '@/components/ui/text'; import { useAppLifecycle } from '@/lib/hooks/use-app-lifecycle'; import { useAvailableModels } from '@/lib/hooks/use-available-models'; +import { + createRemoteModelOverride, + revalidateLegacyGatewayOverride, + useSessionModelOptions, +} from '@/lib/hooks/use-session-model-options'; +import { + areModelPickerSelectionScopesEqual, + type ModelPickerSelection, + type ModelPickerSelectionScope, +} from '@/lib/picker-bridge'; type SessionDetailContentProps = { sessionId: KiloSessionId; @@ -61,6 +76,10 @@ export function SessionDetailContent({ sessionId }: Readonly( + () => ({ + sessionId, + ownerConnectionId: remoteModelState.ownerConnectionId, + protocol: remoteModelState.protocol, + catalogGenerationIdentity, + }), + [ + catalogGenerationIdentity, + remoteModelState.ownerConnectionId, + remoteModelState.protocol, + sessionId, + ] + ); + const liveModelPickerSelectionScopeRef = useRef(modelPickerSelectionScope); + liveModelPickerSelectionScopeRef.current = modelPickerSelectionScope; + const isModelPickerSelectionCurrent = useCallback( + (selectionScope: ModelPickerSelectionScope) => + areModelPickerSelectionScopesEqual(liveModelPickerSelectionScopeRef.current, selectionScope), + [] + ); const { currentMode, @@ -84,7 +137,14 @@ export function SessionDetailContent({ sessionId }: Readonly { + if ( + activeSessionType !== 'remote' || + remoteModelState.protocol !== 'legacy' || + fetchedData?.kiloSessionId !== sessionId || + gatewayModelsLoading + ) { + return; + } + + const revalidatedOverride = revalidateLegacyGatewayOverride(remoteModelOverride, gatewayModels); + if (revalidatedOverride !== remoteModelOverride) { + manager.setRemoteModelOverride(revalidatedOverride); + } + }, [ + activeSessionType, + fetchedData?.kiloSessionId, + gatewayModels, + gatewayModelsLoading, + manager, + remoteModelOverride, + remoteModelState.protocol, + sessionId, + ]); + const lastAssistantIndex = useMemo(() => { for (let i = messages.length - 1; i >= 0; i -= 1) { if (messages[i]?.info.role === 'assistant') { @@ -130,6 +215,35 @@ export function SessionDetailContent({ sessionId }: Readonly { + if (activeSessionType === 'remote') { + if (pickerSelection?.option.action === 'use-session-model') { + manager.setRemoteModelOverride(null); + return; + } + + const selectedRef = pickerSelection?.option.modelRef; + const option = selectedRef + ? modelOptions.find( + candidate => + candidate.overrideSource === pickerSelection.option.overrideSource && + candidate.modelRef?.providerID === selectedRef.providerID && + candidate.modelRef.modelID === selectedRef.modelID + ) + : modelOptions.find(candidate => candidate.id === value); + if (option) { + manager.setRemoteModelOverride(createRemoteModelOverride(option, variant)); + } + return; + } + + setCurrentModel(value); + setCurrentVariant(variant); + }, + [activeSessionType, manager, modelOptions, setCurrentModel, setCurrentVariant] + ); + const shouldShowLoading = isLoading || (fetchedData === null && !statusIndicator && !error) || @@ -249,22 +363,32 @@ export function SessionDetailContent({ sessionId }: Readonly ) : ( - { - setCurrentModel(modelId); - setCurrentVariant(newVariant); - }} - /> + <> + { + manager.retryRemoteModels(); + }} + /> + + + + ))} ); diff --git a/apps/mobile/src/components/agents/use-session-config-sync.test.ts b/apps/mobile/src/components/agents/use-session-config-sync.test.ts new file mode 100644 index 0000000000..0dd5ce3ad9 --- /dev/null +++ b/apps/mobile/src/components/agents/use-session-config-sync.test.ts @@ -0,0 +1,46 @@ +import { describe, expect, it, vi } from 'vitest'; + +import { resolveSessionConfigSelection } from './use-session-config-sync'; + +vi.mock('@/components/agents/mode-options', () => ({ + normalizeAgentMode: (mode: string | null | undefined) => mode ?? 'code', +})); + +const gatewayModels = [ + { + id: 'gateway/first', + name: 'First Gateway Model', + displayId: 'gateway/first', + variants: ['high'], + isPreferred: true, + showGatewayMetadata: true, + }, +]; + +describe('resolveSessionConfigSelection', () => { + it('does not auto-select the first Gateway model for a remote session without an override', () => { + expect( + resolveSessionConfigSelection({ + activeSessionType: 'remote', + fetchedData: {}, + sessionConfig: { model: 'gateway/from-assistant', variant: 'high' }, + modelOptions: gatewayModels, + selectedModel: '', + selectedVariant: '', + }) + ).toEqual({ model: '', variant: '' }); + }); + + it('preserves the existing first-model default for Cloud Agent sessions', () => { + expect( + resolveSessionConfigSelection({ + activeSessionType: 'cloud-agent', + fetchedData: {}, + sessionConfig: null, + modelOptions: gatewayModels, + selectedModel: '', + selectedVariant: '', + }) + ).toEqual({ model: 'gateway/first', variant: 'high' }); + }); +}); diff --git a/apps/mobile/src/components/agents/use-session-config-sync.ts b/apps/mobile/src/components/agents/use-session-config-sync.ts index cc8618babd..bf569d3492 100644 --- a/apps/mobile/src/components/agents/use-session-config-sync.ts +++ b/apps/mobile/src/components/agents/use-session-config-sync.ts @@ -1,8 +1,9 @@ +import { type ResolvedSession } from 'cloud-agent-sdk'; import { useEffect, useState } from 'react'; import { normalizeAgentMode } from '@/components/agents/mode-options'; import { type AgentMode } from '@/components/agents/mode-selector'; -import { type ModelOption } from '@/lib/hooks/use-available-models'; +import { type SessionModelOption } from '@/lib/hooks/use-session-model-options'; type SessionConfigSnapshot = { mode?: string | null; @@ -10,12 +11,17 @@ type SessionConfigSnapshot = { variant?: string | null; }; -type UseSessionConfigSyncOptions = { +type ResolveSessionConfigSelectionOptions = { + activeSessionType: ResolvedSession['type'] | null; fetchedData: SessionConfigSnapshot | null; sessionConfig: SessionConfigSnapshot | null | undefined; - modelOptions: ModelOption[]; + modelOptions: SessionModelOption[]; + selectedModel: string; + selectedVariant: string; }; +type UseSessionConfigSyncOptions = ResolveSessionConfigSelectionOptions; + type UseSessionConfigSyncResult = { currentMode: AgentMode; currentModel: string; @@ -25,55 +31,77 @@ type UseSessionConfigSyncResult = { setCurrentVariant: (variant: string) => void; }; -// Keeps the composer's mode/model/variant in sync with the session's -// fetched data and the SDK session config (which is updated from assistant -// messages during snapshot replay). For sessions without a configured model -// (e.g. remote CLI sessions), auto-selects the first available model. +export function resolveSessionConfigSelection({ + activeSessionType, + fetchedData, + sessionConfig, + modelOptions, + selectedModel, + selectedVariant, +}: ResolveSessionConfigSelectionOptions): { model: string; variant: string } { + if (activeSessionType === 'remote') { + return { model: selectedModel, variant: selectedVariant }; + } + + const configuredModel = sessionConfig?.model ?? fetchedData?.model ?? ''; + if (configuredModel) { + return { + model: configuredModel, + variant: sessionConfig?.variant ?? fetchedData?.variant ?? '', + }; + } + + if (activeSessionType !== 'cloud-agent' || fetchedData === null) { + return { model: '', variant: '' }; + } + + const firstModel = modelOptions[0]; + return firstModel + ? { model: firstModel.id, variant: firstModel.variants[0] ?? '' } + : { model: '', variant: '' }; +} + export function useSessionConfigSync({ + activeSessionType, fetchedData, sessionConfig, modelOptions, + selectedModel, + selectedVariant, }: UseSessionConfigSyncOptions): UseSessionConfigSyncResult { + const initialSelection = resolveSessionConfigSelection({ + activeSessionType, + fetchedData, + sessionConfig, + modelOptions, + selectedModel, + selectedVariant, + }); const [currentMode, setCurrentMode] = useState(() => normalizeAgentMode(fetchedData?.mode) ); - const [currentModel, setCurrentModel] = useState(fetchedData?.model ?? ''); - const [currentVariant, setCurrentVariant] = useState(fetchedData?.variant ?? ''); + const [currentModel, setCurrentModel] = useState(initialSelection.model); + const [currentVariant, setCurrentVariant] = useState(initialSelection.variant); useEffect(() => { const mode = sessionConfig?.mode ?? fetchedData?.mode; if (mode) { setCurrentMode(normalizeAgentMode(mode)); } - - const model = sessionConfig?.model ?? fetchedData?.model; - if (model) { - setCurrentModel(model); - } - - const variant = sessionConfig?.variant ?? fetchedData?.variant; - if (variant) { - setCurrentVariant(variant); - } - }, [ - sessionConfig?.mode, - sessionConfig?.model, - sessionConfig?.variant, - fetchedData?.mode, - fetchedData?.model, - fetchedData?.variant, - ]); + }, [sessionConfig?.mode, fetchedData?.mode]); useEffect(() => { - if (currentModel || modelOptions.length === 0 || fetchedData === null) { - return; - } - const firstModel = modelOptions[0]; - if (firstModel) { - setCurrentModel(firstModel.id); - setCurrentVariant(firstModel.variants[0] ?? ''); - } - }, [currentModel, modelOptions, fetchedData]); + const selection = resolveSessionConfigSelection({ + activeSessionType, + fetchedData, + sessionConfig, + modelOptions, + selectedModel, + selectedVariant, + }); + setCurrentModel(selection.model); + setCurrentVariant(selection.variant); + }, [activeSessionType, sessionConfig, fetchedData, modelOptions, selectedModel, selectedVariant]); return { currentMode, diff --git a/apps/mobile/src/lib/hooks/use-session-model-options.ts b/apps/mobile/src/lib/hooks/use-session-model-options.ts new file mode 100644 index 0000000000..5a5ec6760e --- /dev/null +++ b/apps/mobile/src/lib/hooks/use-session-model-options.ts @@ -0,0 +1,393 @@ +/* eslint-disable max-lines -- Model source transitions stay beside their option projections. */ +import { useMemo } from 'react'; +import { + type ModelRef, + type ModelSelection, + type RemoteModelOverride, + type RemoteModelState, + type ResolvedSession, +} from 'cloud-agent-sdk'; + +import { type ModelOption } from '@/lib/hooks/use-available-models'; + +type SessionModelSource = + | 'cloud-agent-gateway' + | 'remote-cli-catalog' + | 'remote-legacy-gateway' + | 'remote-unavailable'; + +export type SessionModelNotice = { + id: 'loading' | 'legacy' | 'error' | 'stale' | 'truncated' | 'unavailable' | 'local-provider'; + message: string; + retry: boolean; +}; + +export type SessionModelOption = { + id: string; + name: string; + displayId: string; + variants: string[]; + isPreferred: boolean; + isFree?: boolean; + mayTrainOnYourPrompts?: boolean; + hasUserByokAvailable?: boolean; + provider?: { id: string; name: string }; + modelRef?: ModelRef; + overrideSource?: RemoteModelOverride['source']; + showGatewayMetadata: boolean; + unavailable?: boolean; + action?: 'use-session-model'; +}; + +type BuildSessionModelOptionsInput = { + activeSessionType: ResolvedSession['type'] | null; + remoteModelState: RemoteModelState; + observedModel: ModelSelection | null; + remoteModelOverride: RemoteModelOverride | null; + gatewayModels: ModelOption[]; + gatewayModelsLoading: boolean; + organizationId?: string; +}; + +type SessionModelOptions = { + source: SessionModelSource; + options: SessionModelOption[]; + selectedValue: string; + selectedVariant: string; + pickerDisabled: boolean; + isLoading: boolean; + notices: SessionModelNotice[]; +}; + +export function useSessionModelOptions({ + activeSessionType, + gatewayModels, + gatewayModelsLoading, + observedModel, + organizationId, + remoteModelOverride, + remoteModelState, +}: BuildSessionModelOptionsInput): SessionModelOptions { + return useMemo( + () => + buildSessionModelOptions({ + activeSessionType, + gatewayModels, + gatewayModelsLoading, + observedModel, + organizationId, + remoteModelOverride, + remoteModelState, + }), + [ + activeSessionType, + gatewayModels, + gatewayModelsLoading, + observedModel, + organizationId, + remoteModelOverride, + remoteModelState, + ] + ); +} + +export function buildSessionModelOptions( + input: BuildSessionModelOptionsInput +): SessionModelOptions { + if (input.activeSessionType === 'remote') { + if (input.remoteModelState.protocol === 'v1' && input.remoteModelState.catalog) { + return buildCliCatalogOptions(input); + } + if (input.remoteModelState.protocol === 'legacy') { + return buildLegacyGatewayOptions(input); + } + return buildUnavailableRemoteOptions(input); + } + + return { + source: 'cloud-agent-gateway', + options: input.gatewayModels.map(createGatewayOption), + selectedValue: '', + selectedVariant: '', + pickerDisabled: false, + isLoading: input.gatewayModelsLoading, + notices: [], + }; +} + +function buildUnavailableRemoteOptions(input: BuildSessionModelOptionsInput): SessionModelOptions { + const currentSelection = getCurrentRemoteSelection(input); + const option = currentSelection + ? createUnavailableOption(currentSelection.model) + : ({ + id: 'remote-session-model', + name: 'Session model', + displayId: '', + variants: [], + isPreferred: false, + showGatewayMetadata: false, + unavailable: true, + } satisfies SessionModelOption); + const loading = input.remoteModelState.refresh === 'loading'; + const notices: SessionModelNotice[] = [ + loading + ? { + id: 'loading', + message: 'Checking this CLI for available models.', + retry: true, + } + : { + id: 'error', + message: "Models from this CLI couldn't be loaded. Sending still uses the session model.", + retry: true, + }, + ]; + if (currentSelection) { + notices.push({ + id: 'unavailable', + message: `${currentSelection.model.modelID} is the session model. It can't be changed until this CLI's models load.`, + retry: false, + }); + } + + return { + source: 'remote-unavailable', + options: [option], + selectedValue: option.id, + selectedVariant: currentSelection?.variant ?? '', + pickerDisabled: true, + isLoading: loading, + notices, + }; +} + +export function revalidateLegacyGatewayOverride( + override: RemoteModelOverride | null, + gatewayModels: ModelOption[] +): RemoteModelOverride | null { + if (override?.source !== 'legacy-gateway') { + return override; + } + + const selectedModel = gatewayModels.find(model => model.id === override.selection.model.modelID); + if (!selectedModel) { + return null; + } + if (!override.selection.variant || selectedModel.variants.includes(override.selection.variant)) { + return override; + } + + return { + source: 'legacy-gateway', + selection: { model: override.selection.model }, + }; +} + +function buildLegacyGatewayOptions(input: BuildSessionModelOptionsInput): SessionModelOptions { + const options: SessionModelOption[] = input.gatewayModels.map(model => ({ + ...createGatewayOption(model), + modelRef: { providerID: 'kilo', modelID: model.id }, + overrideSource: 'legacy-gateway' as const, + })); + const remoteModelOverride = revalidateLegacyGatewayOverride( + input.remoteModelOverride, + input.gatewayModels + ); + const currentSelection = getCurrentRemoteSelection(input, remoteModelOverride); + let selectedOption = currentSelection + ? options.find( + option => option.modelRef && modelRefsEqual(option.modelRef, currentSelection.model) + ) + : undefined; + const notices: SessionModelNotice[] = [ + { + id: 'legacy', + message: + 'This CLI uses Gateway model fallback. Upgrade Kilo CLI to use its configured providers and models.', + retry: false, + }, + ]; + + if (currentSelection && !selectedOption) { + selectedOption = createUnavailableOption(currentSelection.model); + options.unshift(selectedOption); + notices.push(createUnavailableNotice(currentSelection.model)); + } + if (remoteModelOverride) { + options.unshift(createUseSessionModelOption()); + } + + const selectedVariant = + currentSelection?.variant && selectedOption?.variants.includes(currentSelection.variant) + ? currentSelection.variant + : ''; + + return { + source: 'remote-legacy-gateway', + options, + selectedValue: selectedOption?.id ?? '', + selectedVariant, + pickerDisabled: input.gatewayModelsLoading, + isLoading: input.gatewayModelsLoading, + notices, + }; +} + +function buildCliCatalogOptions(input: BuildSessionModelOptionsInput): SessionModelOptions { + const catalog = input.remoteModelState.catalog; + if (!catalog) { + throw new Error('CLI catalog is required for v1 model options'); + } + + let opaqueIndex = 0; + const options = catalog.providers.flatMap(provider => + provider.models.map(model => { + const option: SessionModelOption = { + id: `remote-model-${opaqueIndex}`, + name: model.name ?? model.id, + displayId: model.id, + variants: model.variants, + isPreferred: false, + provider: { id: provider.id, name: provider.name ?? provider.id }, + modelRef: { providerID: provider.id, modelID: model.id }, + overrideSource: 'cli-catalog', + showGatewayMetadata: false, + }; + opaqueIndex += 1; + return option; + }) + ); + const currentSelection = getCurrentRemoteSelection(input); + let selectedOption = currentSelection + ? options.find( + option => option.modelRef && modelRefsEqual(option.modelRef, currentSelection.model) + ) + : undefined; + const notices: SessionModelNotice[] = []; + + if (currentSelection && !selectedOption) { + selectedOption = createUnavailableOption(currentSelection.model); + options.unshift(selectedOption); + notices.push(createUnavailableNotice(currentSelection.model)); + } + if (input.remoteModelOverride) { + options.unshift(createUseSessionModelOption()); + } + if (input.remoteModelState.refresh === 'error') { + notices.unshift({ + id: 'stale', + message: 'Showing the last model catalog because refresh failed.', + retry: true, + }); + } + if (catalog.truncated) { + notices.push({ + id: 'truncated', + message: 'This CLI returned a partial model catalog. Some models or variants may be missing.', + retry: false, + }); + } + if (currentSelection && currentSelection.model.providerID !== 'kilo') { + notices.push({ + id: 'local-provider', + message: input.organizationId + ? "This model runs through your CLI provider, outside Kilo Gateway billing and this organization's model restrictions." + : 'This model runs through your CLI provider, outside Kilo Gateway billing.', + retry: false, + }); + } + + const selectedVariant = + currentSelection?.variant && selectedOption?.variants.includes(currentSelection.variant) + ? currentSelection.variant + : ''; + + return { + source: 'remote-cli-catalog', + options, + selectedValue: selectedOption?.id ?? '', + selectedVariant, + pickerDisabled: false, + isLoading: false, + notices, + }; +} + +function createUseSessionModelOption(): SessionModelOption { + return { + id: 'remote-use-session-model', + name: 'Use session model', + displayId: 'Stop overriding the model selected by the CLI', + variants: [], + isPreferred: false, + provider: { id: 'session', name: 'Session' }, + showGatewayMetadata: false, + action: 'use-session-model', + }; +} + +function createUnavailableOption(modelRef: ModelRef): SessionModelOption { + return { + id: 'remote-unavailable-model', + name: modelRef.modelID, + displayId: modelRef.modelID, + variants: [], + isPreferred: false, + provider: { id: modelRef.providerID, name: modelRef.providerID }, + modelRef, + showGatewayMetadata: false, + unavailable: true, + }; +} + +function createUnavailableNotice(modelRef: ModelRef): SessionModelNotice { + return { + id: 'unavailable', + message: `${modelRef.modelID} is the session model but is not available in this catalog.`, + retry: false, + }; +} + +function createGatewayOption(model: ModelOption): SessionModelOption { + return { + ...model, + displayId: model.id, + showGatewayMetadata: true, + }; +} + +function getCurrentRemoteSelection( + input: BuildSessionModelOptionsInput, + remoteModelOverride = input.remoteModelOverride +): ModelSelection | null { + const defaultModel = input.remoteModelState.catalog?.defaultModel; + return ( + remoteModelOverride?.selection ?? + input.observedModel ?? + (defaultModel ? { model: defaultModel } : null) + ); +} + +// Mirrors the SDK's modelRefsEqual. Kept local because mobile imports only +// types from cloud-agent-sdk (vitest does not resolve the SDK value barrel). +function modelRefsEqual(left: ModelRef, right: ModelRef): boolean { + return left.providerID === right.providerID && left.modelID === right.modelID; +} + +export function createRemoteModelOverride( + option: SessionModelOption | undefined, + variant: string +): RemoteModelOverride | null { + if (!option?.modelRef || !option.overrideSource || option.unavailable) { + return null; + } + + const validVariant = option.variants.includes(variant) ? variant : undefined; + return { + source: option.overrideSource, + selection: { + model: option.modelRef, + ...(validVariant ? { variant: validVariant } : {}), + }, + }; +} diff --git a/apps/mobile/src/lib/model-picker-rows.test.ts b/apps/mobile/src/lib/model-picker-rows.test.ts index fc1f5dd9cf..d7fd5f1038 100644 --- a/apps/mobile/src/lib/model-picker-rows.test.ts +++ b/apps/mobile/src/lib/model-picker-rows.test.ts @@ -1,45 +1,88 @@ import { describe, expect, it } from 'vitest'; -import { type ModelOption } from '@/lib/hooks/use-available-models'; +import { type SessionModelOption } from '@/lib/hooks/use-session-model-options'; import { buildModelPickerRows } from './model-picker-rows'; -const models: ModelOption[] = [ +const gatewayModels: SessionModelOption[] = [ { id: 'anthropic/claude-sonnet-4', name: 'Claude Sonnet 4', + displayId: 'anthropic/claude-sonnet-4', variants: ['low'], isPreferred: true, + showGatewayMetadata: true, }, { id: 'openai/gpt-5', name: 'GPT-5', + displayId: 'openai/gpt-5', variants: ['medium'], isPreferred: false, + showGatewayMetadata: true, + }, +]; + +const remoteModels: SessionModelOption[] = [ + { + id: 'remote-model-0', + name: 'Workspace Claude', + displayId: 'shared/model.id', + variants: ['low', 'high'], + isPreferred: false, + provider: { id: 'anthropic-local', name: 'Anthropic Local' }, + modelRef: { providerID: 'anthropic-local', modelID: 'shared/model.id' }, + overrideSource: 'cli-catalog', + showGatewayMetadata: false, + }, + { + id: 'remote-model-1', + name: 'Internal Deployment', + displayId: 'shared/model.id', + variants: [], + isPreferred: false, + provider: { id: 'custom-openai', name: 'Custom OpenAI' }, + modelRef: { providerID: 'custom-openai', modelID: 'shared/model.id' }, + overrideSource: 'cli-catalog', + showGatewayMetadata: false, }, ]; describe('buildModelPickerRows', () => { - it('groups preferred models before all other models', () => { - expect(buildModelPickerRows({ models, search: '' })).toEqual([ + it('preserves Recommended and All groups for Gateway models', () => { + expect(buildModelPickerRows({ models: gatewayModels, search: '' })).toEqual([ { key: 'recommended', title: 'RECOMMENDED', type: 'header' }, - { key: 'model:anthropic/claude-sonnet-4', model: models[0], type: 'model' }, + { key: 'model:anthropic/claude-sonnet-4', model: gatewayModels[0], type: 'model' }, { key: 'all', title: 'ALL MODELS', type: 'header' }, - { key: 'model:openai/gpt-5', model: models[1], type: 'model' }, + { key: 'model:openai/gpt-5', model: gatewayModels[1], type: 'model' }, ]); }); - it('filters models by name', () => { - expect(buildModelPickerRows({ models, search: 'Sonnet 4' })).toEqual([ + it('filters Gateway models by display name and id', () => { + expect(buildModelPickerRows({ models: gatewayModels, search: 'Sonnet 4' })).toEqual([ { key: 'recommended', title: 'RECOMMENDED', type: 'header' }, - { key: 'model:anthropic/claude-sonnet-4', model: models[0], type: 'model' }, + { key: 'model:anthropic/claude-sonnet-4', model: gatewayModels[0], type: 'model' }, + ]); + expect(buildModelPickerRows({ models: gatewayModels, search: 'openai/' })).toEqual([ + { key: 'all', title: 'ALL MODELS', type: 'header' }, + { key: 'model:openai/gpt-5', model: gatewayModels[1], type: 'model' }, ]); }); - it('filters models by id', () => { - expect(buildModelPickerRows({ models, search: 'openai/' })).toEqual([ - { key: 'all', title: 'ALL MODELS', type: 'header' }, - { key: 'model:openai/gpt-5', model: models[1], type: 'model' }, + it('groups CLI models by provider and searches provider/model display data, never opaque keys', () => { + expect(buildModelPickerRows({ models: remoteModels, search: '' })).toEqual([ + { key: 'provider:anthropic-local', title: 'ANTHROPIC LOCAL', type: 'header' }, + { key: 'model:remote-model-0', model: remoteModels[0], type: 'model' }, + { key: 'provider:custom-openai', title: 'CUSTOM OPENAI', type: 'header' }, + { key: 'model:remote-model-1', model: remoteModels[1], type: 'model' }, + ]); + expect(buildModelPickerRows({ models: remoteModels, search: 'custom-openai' })).toEqual([ + { key: 'provider:custom-openai', title: 'CUSTOM OPENAI', type: 'header' }, + { key: 'model:remote-model-1', model: remoteModels[1], type: 'model' }, ]); + expect(buildModelPickerRows({ models: remoteModels, search: 'shared/model.id' })).toHaveLength( + 4 + ); + expect(buildModelPickerRows({ models: remoteModels, search: 'remote-model-0' })).toEqual([]); }); }); diff --git a/apps/mobile/src/lib/model-picker-rows.ts b/apps/mobile/src/lib/model-picker-rows.ts index 863a1d6ef4..e5842a9179 100644 --- a/apps/mobile/src/lib/model-picker-rows.ts +++ b/apps/mobile/src/lib/model-picker-rows.ts @@ -1,36 +1,62 @@ -import { type ModelOption } from '@/lib/hooks/use-available-models'; +import { type SessionModelOption } from '@/lib/hooks/use-session-model-options'; export type ModelPickerRow = | { key: string; title: string; type: 'header' } - | { key: string; model: ModelOption; type: 'model' }; + | { key: string; model: SessionModelOption; type: 'model' }; + +type ModelGroup = { + key: string; + title: string; + models: SessionModelOption[]; +}; export function buildModelPickerRows({ models, search, }: { - models: ModelOption[]; + models: SessionModelOption[]; search: string; }): ModelPickerRow[] { const query = search.toLowerCase().trim(); - const filtered = models.filter( - m => !query || m.name.toLowerCase().includes(query) || m.id.toLowerCase().includes(query) - ); - - const recommended = filtered.filter(m => m.isPreferred); - const all = filtered.filter(m => !m.isPreferred); - const result: ModelPickerRow[] = []; + const filtered = models.filter(model => !query || searchableText(model).includes(query)); + const groups = new Map(); - if (recommended.length > 0) { - result.push({ key: 'recommended', title: 'RECOMMENDED', type: 'header' }); - for (const modelOption of recommended) { - result.push({ key: `model:${modelOption.id}`, model: modelOption, type: 'model' }); + for (const model of filtered) { + const group = groupForModel(model); + const existing = groups.get(group.key); + if (existing) { + existing.models.push(model); + } else { + groups.set(group.key, { ...group, models: [model] }); } } - if (all.length > 0) { - result.push({ key: 'all', title: 'ALL MODELS', type: 'header' }); - for (const modelOption of all) { - result.push({ key: `model:${modelOption.id}`, model: modelOption, type: 'model' }); + + const rows: ModelPickerRow[] = []; + for (const group of groups.values()) { + rows.push({ key: group.key, title: group.title, type: 'header' }); + for (const model of group.models) { + rows.push({ key: `model:${model.id}`, model, type: 'model' }); } } - return result; + return rows; +} + +function searchableText(model: SessionModelOption): string { + return [model.name, model.displayId, model.provider?.name, model.provider?.id] + .filter(Boolean) + .join(' ') + .toLowerCase(); +} + +function groupForModel(model: SessionModelOption): Pick { + if (model.provider) { + return { + key: `provider:${model.provider.id}`, + title: model.provider.name.toUpperCase(), + }; + } + if (model.isPreferred) { + return { key: 'recommended', title: 'RECOMMENDED' }; + } + return { key: 'all', title: 'ALL MODELS' }; } diff --git a/apps/mobile/src/lib/picker-bridge.test.ts b/apps/mobile/src/lib/picker-bridge.test.ts new file mode 100644 index 0000000000..9189500e16 --- /dev/null +++ b/apps/mobile/src/lib/picker-bridge.test.ts @@ -0,0 +1,170 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import { + areModelPickerSelectionScopesEqual, + clearModelPickerBridge, + commitModelPickerSelection, + getModelPickerBridge, + resolveModelPickerSelection, + setModelPickerBridge, +} from './picker-bridge'; + +const remoteOption = { + id: 'remote-model-0', + name: 'Workspace Claude', + displayId: 'shared/model.id', + variants: ['low', 'high'], + isPreferred: false, + provider: { id: 'anthropic-local', name: 'Anthropic Local' }, + modelRef: { providerID: 'anthropic-local', modelID: 'shared/model.id' }, + overrideSource: 'cli-catalog' as const, + showGatewayMetadata: false, +}; + +const currentSelectionScope = { + selectionScope: { + sessionId: 'session-a', + ownerConnectionId: 'owner-a', + protocol: 'v1' as const, + catalogGenerationIdentity: {}, + }, + isSelectionCurrent: () => true, +}; + +describe('model picker bridge', () => { + beforeEach(() => { + clearModelPickerBridge(); + }); + + it('preserves exact model identity and override source while resetting an invalid variant', () => { + const onSelect = vi.fn(); + setModelPickerBridge({ + ...currentSelectionScope, + options: [remoteOption], + currentValue: remoteOption.id, + currentVariant: 'removed', + onSelect: selection => { + onSelect(selection); + }, + }); + + const bridge = getModelPickerBridge(); + expect(bridge).not.toBeNull(); + if (!bridge) { + throw new Error('Expected model picker bridge'); + } + + const selection = resolveModelPickerSelection(bridge, remoteOption.id, 'removed'); + if (!selection) { + throw new Error('Expected model picker selection'); + } + expect(selection).toEqual({ + option: remoteOption, + variant: 'low', + }); + expect(selection.option.modelRef).toEqual({ + providerID: 'anthropic-local', + modelID: 'shared/model.id', + }); + expect(selection.option.overrideSource).toBe('cli-catalog'); + }); + + it('treats session, owner, protocol, and catalog generation changes as stale scopes', () => { + const catalogGenerationIdentity = {}; + const scope = { + sessionId: 'session-a', + ownerConnectionId: 'owner-a', + protocol: 'v1' as const, + catalogGenerationIdentity, + }; + + expect(areModelPickerSelectionScopesEqual(scope, scope)).toBe(true); + expect(areModelPickerSelectionScopesEqual(scope, { ...scope, sessionId: 'session-b' })).toBe( + false + ); + expect( + areModelPickerSelectionScopesEqual(scope, { ...scope, ownerConnectionId: 'owner-b' }) + ).toBe(false); + expect(areModelPickerSelectionScopesEqual(scope, { ...scope, protocol: 'legacy' })).toBe(false); + expect( + areModelPickerSelectionScopesEqual(scope, { ...scope, catalogGenerationIdentity: {} }) + ).toBe(false); + }); + + it('discards a detached selection when its session catalog scope is stale', () => { + const onSelect = vi.fn(); + const catalogGenerationIdentity = {}; + const bridge = { + options: [remoteOption], + currentValue: remoteOption.id, + currentVariant: 'low', + selectionScope: { + sessionId: 'session-a', + ownerConnectionId: 'owner-a', + protocol: 'v1' as const, + catalogGenerationIdentity, + }, + isSelectionCurrent: vi.fn(() => false), + onSelect, + }; + + expect(commitModelPickerSelection(bridge, remoteOption.id, 'high')).toBe(false); + expect(bridge.isSelectionCurrent).toHaveBeenCalledWith(bridge.selectionScope); + expect(onSelect).not.toHaveBeenCalled(); + }); + + it('commits a detached selection while its session catalog scope is current', () => { + const onSelect = vi.fn(); + const bridge = { + options: [remoteOption], + currentValue: remoteOption.id, + currentVariant: 'low', + selectionScope: { + sessionId: 'session-a', + ownerConnectionId: 'owner-a', + protocol: 'v1' as const, + catalogGenerationIdentity: {}, + }, + isSelectionCurrent: vi.fn(() => true), + onSelect, + }; + + expect(commitModelPickerSelection(bridge, remoteOption.id, 'high')).toBe(true); + expect(onSelect).toHaveBeenCalledWith({ option: remoteOption, variant: 'high' }); + }); + + it('passes the Use session model action through without inventing an override', () => { + const resetOption = { + id: 'remote-use-session-model', + name: 'Use session model', + displayId: 'Stop overriding the model selected by the CLI', + variants: [], + isPreferred: false, + provider: { id: 'session', name: 'Session' }, + showGatewayMetadata: false, + action: 'use-session-model' as const, + }; + const onSelect = vi.fn(); + setModelPickerBridge({ + ...currentSelectionScope, + options: [resetOption], + currentValue: remoteOption.id, + currentVariant: 'high', + onSelect: selection => { + onSelect(selection); + }, + }); + + const bridge = getModelPickerBridge(); + if (!bridge) { + throw new Error('Expected model picker bridge'); + } + const selection = resolveModelPickerSelection(bridge, resetOption.id, 'high'); + if (!selection) { + throw new Error('Expected model picker selection'); + } + bridge.onSelect(selection); + + expect(onSelect).toHaveBeenCalledWith({ option: resetOption, variant: '' }); + }); +}); diff --git a/apps/mobile/src/lib/picker-bridge.ts b/apps/mobile/src/lib/picker-bridge.ts index 2bbe673d10..e9bc7c6a1a 100644 --- a/apps/mobile/src/lib/picker-bridge.ts +++ b/apps/mobile/src/lib/picker-bridge.ts @@ -1,13 +1,39 @@ import { type AgentMode } from '@/components/agents/mode-selector'; -import { type ModelOption } from '@/lib/hooks/use-available-models'; +import { type SessionModelOption } from '@/lib/hooks/use-session-model-options'; + +export type ModelPickerSelection = { + option: SessionModelOption; + variant: string; +}; + +export type ModelPickerSelectionScope = { + sessionId: string; + ownerConnectionId: string | null; + protocol: 'unknown' | 'legacy' | 'v1'; + catalogGenerationIdentity: object | null; +}; type ModelPickerBridge = { - options: ModelOption[]; + options: SessionModelOption[]; currentValue: string; currentVariant: string; - onSelect: (id: string, variant: string) => void; + selectionScope: ModelPickerSelectionScope; + isSelectionCurrent: (scope: ModelPickerSelectionScope) => boolean; + onSelect: (selection: ModelPickerSelection) => void; }; +export function areModelPickerSelectionScopesEqual( + left: ModelPickerSelectionScope, + right: ModelPickerSelectionScope +): boolean { + return ( + left.sessionId === right.sessionId && + left.ownerConnectionId === right.ownerConnectionId && + left.protocol === right.protocol && + left.catalogGenerationIdentity === right.catalogGenerationIdentity + ); +} + type ModePickerBridge = { currentValue: AgentMode; onSelect: (mode: AgentMode) => void; @@ -28,6 +54,40 @@ let modelBridge: ModelPickerBridge | null = null; let modeBridge: ModePickerBridge | null = null; let repoBridge: RepoPickerBridge | null = null; +export function resolveModelPickerSelection( + bridge: ModelPickerBridge, + value: string, + variant: string +): ModelPickerSelection | null { + const option = bridge.options.find(candidate => candidate.id === value); + if (!option) { + return null; + } + + return { + option, + variant: option.variants.includes(variant) ? variant : (option.variants[0] ?? ''), + }; +} + +export function commitModelPickerSelection( + bridge: ModelPickerBridge, + value: string, + variant: string +): boolean { + if (!bridge.isSelectionCurrent(bridge.selectionScope)) { + return false; + } + + const selection = resolveModelPickerSelection(bridge, value, variant); + if (!selection) { + return false; + } + + bridge.onSelect(selection); + return true; +} + export function setModelPickerBridge(bridge: ModelPickerBridge) { modelBridge = bridge; } diff --git a/apps/mobile/src/lib/use-session-model-options.test.ts b/apps/mobile/src/lib/use-session-model-options.test.ts new file mode 100644 index 0000000000..d5cc935e12 --- /dev/null +++ b/apps/mobile/src/lib/use-session-model-options.test.ts @@ -0,0 +1,290 @@ +import { describe, expect, it } from 'vitest'; + +import { + buildSessionModelOptions, + createRemoteModelOverride, + revalidateLegacyGatewayOverride, +} from './hooks/use-session-model-options'; + +const gatewayModels = [ + { + id: 'gateway/model', + name: 'Gateway Model', + variants: ['high'], + isPreferred: true, + isFree: true, + }, +]; + +describe('revalidateLegacyGatewayOverride', () => { + it('clears an override when its Gateway model is removed', () => { + expect( + revalidateLegacyGatewayOverride( + { + source: 'legacy-gateway', + selection: { + model: { providerID: 'kilo', modelID: 'removed/model' }, + variant: 'high', + }, + }, + gatewayModels + ) + ).toBeNull(); + }); + + it('retains a valid Gateway model and variant without replacing the override', () => { + const override = { + source: 'legacy-gateway' as const, + selection: { + model: { providerID: 'kilo', modelID: 'gateway/model' }, + variant: 'high', + }, + }; + + expect(revalidateLegacyGatewayOverride(override, gatewayModels)).toBe(override); + }); + + it('retains a Gateway model while removing a variant no longer offered', () => { + expect( + revalidateLegacyGatewayOverride( + { + source: 'legacy-gateway', + selection: { + model: { providerID: 'kilo', modelID: 'gateway/model' }, + variant: 'removed', + }, + }, + gatewayModels + ) + ).toEqual({ + source: 'legacy-gateway', + selection: { model: { providerID: 'kilo', modelID: 'gateway/model' } }, + }); + }); +}); + +describe('buildSessionModelOptions', () => { + it('uses provider-aware CLI rows with distinct opaque values for duplicate model IDs', () => { + const result = buildSessionModelOptions({ + activeSessionType: 'remote', + remoteModelState: { + ownerConnectionId: 'cli-owner', + protocol: 'v1', + refresh: 'idle', + catalog: { + protocolVersion: 1, + truncated: false, + providers: [ + { + id: 'anthropic-local', + name: 'Anthropic Local', + models: [ + { + id: 'shared/model.id', + name: 'Claude Workspace', + variants: ['low', 'high'], + capabilities: { attachment: true, reasoning: true }, + limits: { context: 200_000, output: 8192 }, + }, + ], + }, + { + id: 'custom-openai', + name: 'Custom OpenAI', + models: [ + { + id: 'shared/model.id', + name: 'Internal Deployment', + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4096 }, + }, + ], + }, + ], + }, + }, + observedModel: { + model: { providerID: 'anthropic-local', modelID: 'shared/model.id' }, + variant: 'high', + }, + remoteModelOverride: null, + gatewayModels, + gatewayModelsLoading: false, + organizationId: 'org-persisted', + }); + + expect(result.source).toBe('remote-cli-catalog'); + expect(result.options).toHaveLength(2); + expect(result.options[0]?.id).not.toBe(result.options[1]?.id); + expect(result.options.map(option => option.id)).not.toContain('shared/model.id'); + expect(result.options.map(option => option.displayId)).toEqual([ + 'shared/model.id', + 'shared/model.id', + ]); + expect(result.options.map(option => option.provider?.name)).toEqual([ + 'Anthropic Local', + 'Custom OpenAI', + ]); + expect(result.options[0]).toMatchObject({ + modelRef: { providerID: 'anthropic-local', modelID: 'shared/model.id' }, + overrideSource: 'cli-catalog', + variants: ['low', 'high'], + showGatewayMetadata: false, + }); + expect(result.selectedValue).toBe(result.options[0]?.id); + expect(result.selectedVariant).toBe('high'); + }); + + it('uses Gateway rows only for an exact legacy CLI and preserves override provenance', () => { + const result = buildSessionModelOptions({ + activeSessionType: 'remote', + remoteModelState: { + ownerConnectionId: 'legacy-owner', + protocol: 'legacy', + refresh: 'idle', + }, + observedModel: { + model: { providerID: 'local-provider', modelID: 'private-model' }, + variant: 'old-variant', + }, + remoteModelOverride: null, + gatewayModels, + gatewayModelsLoading: false, + organizationId: 'org-from-session', + }); + + const gatewayOption = result.options.find(option => option.overrideSource === 'legacy-gateway'); + const unavailableOption = result.options.find(option => option.unavailable); + + expect(result.source).toBe('remote-legacy-gateway'); + expect(gatewayOption).toMatchObject({ + id: 'gateway/model', + modelRef: { providerID: 'kilo', modelID: 'gateway/model' }, + overrideSource: 'legacy-gateway', + showGatewayMetadata: true, + }); + expect(unavailableOption).toMatchObject({ + displayId: 'private-model', + modelRef: { providerID: 'local-provider', modelID: 'private-model' }, + unavailable: true, + }); + expect(result.selectedValue).toBe(unavailableOption?.id); + expect(result.selectedVariant).toBe(''); + expect(result.notices.map(notice => notice.id)).toEqual(['legacy', 'unavailable']); + }); + + it('does not project a removed legacy Gateway override as the selected model', () => { + const result = buildSessionModelOptions({ + activeSessionType: 'remote', + remoteModelState: { + ownerConnectionId: 'legacy-owner', + protocol: 'legacy', + refresh: 'idle', + }, + observedModel: { + model: { providerID: 'kilo', modelID: 'gateway/model' }, + variant: 'high', + }, + remoteModelOverride: { + source: 'legacy-gateway', + selection: { + model: { providerID: 'kilo', modelID: 'removed/model' }, + variant: 'high', + }, + }, + gatewayModels, + gatewayModelsLoading: false, + organizationId: undefined, + }); + + expect(result.selectedValue).toBe('gateway/model'); + expect(result.selectedVariant).toBe('high'); + expect(result.options.some(option => option.unavailable)).toBe(false); + expect(result.options.some(option => option.action === 'use-session-model')).toBe(false); + }); + + it('disables model changes when remote discovery fails without exposing Gateway rows', () => { + const result = buildSessionModelOptions({ + activeSessionType: 'remote', + remoteModelState: { + ownerConnectionId: 'cli-owner', + protocol: 'unknown', + refresh: 'error', + error: 'catalog request failed', + }, + observedModel: null, + remoteModelOverride: null, + gatewayModels, + gatewayModelsLoading: false, + organizationId: 'org-persisted', + }); + + expect(result.source).toBe('remote-unavailable'); + expect(result.options).toEqual([ + expect.objectContaining({ name: 'Session model', unavailable: true }), + ]); + expect(result.options.some(option => option.id === gatewayModels[0]?.id)).toBe(false); + expect(result.pickerDisabled).toBe(true); + expect(result.notices).toEqual([expect.objectContaining({ id: 'error', retry: true })]); + }); + + it('retains a stale v1 catalog with reset, truncation, and local-provider notices', () => { + const result = buildSessionModelOptions({ + activeSessionType: 'remote', + remoteModelState: { + ownerConnectionId: 'cli-owner', + protocol: 'v1', + refresh: 'error', + error: 'refresh failed', + catalog: { + protocolVersion: 1, + truncated: true, + providers: [ + { + id: 'local-provider', + name: 'Local Provider', + models: [ + { + id: 'private-model', + variants: ['low', 'high'], + capabilities: { attachment: false, reasoning: true }, + limits: { context: 64_000, output: 4096 }, + }, + ], + }, + ], + }, + }, + observedModel: null, + remoteModelOverride: { + source: 'cli-catalog', + selection: { + model: { providerID: 'local-provider', modelID: 'private-model' }, + variant: 'high', + }, + }, + gatewayModels, + gatewayModelsLoading: false, + organizationId: 'org-persisted', + }); + + const resetOption = result.options.find(option => option.action === 'use-session-model'); + const cliOption = result.options.find(option => option.overrideSource === 'cli-catalog'); + + expect(result.notices.map(notice => notice.id)).toEqual([ + 'stale', + 'truncated', + 'local-provider', + ]); + expect(result.notices[2]?.message).toContain("organization's model restrictions"); + expect(resetOption).toMatchObject({ name: 'Use session model', action: 'use-session-model' }); + expect(createRemoteModelOverride(resetOption, 'high')).toBeNull(); + expect(createRemoteModelOverride(cliOption, 'removed')).toEqual({ + source: 'cli-catalog', + selection: { model: { providerID: 'local-provider', modelID: 'private-model' } }, + }); + expect(result.selectedValue).toBe(cliOption?.id); + expect(result.selectedVariant).toBe('high'); + }); +}); diff --git a/apps/web/src/components/cloud-agent-next/CloudAgentProvider.tsx b/apps/web/src/components/cloud-agent-next/CloudAgentProvider.tsx index 95d1c8654e..d7d7d02204 100644 --- a/apps/web/src/components/cloud-agent-next/CloudAgentProvider.tsx +++ b/apps/web/src/components/cloud-agent-next/CloudAgentProvider.tsx @@ -35,11 +35,14 @@ function normalizeTransportPayload(payload: TransportSendPayload): SendMessagePa if (payload.type === 'prompt') { if (!payload.mode) throw new Error('Cloud Agent mode is required'); if (!payload.model) throw new Error('Cloud Agent model is required'); + if (payload.model.providerID !== 'kilo') { + throw new Error('Cloud Agent only supports Kilo models'); + } return { type: 'prompt', prompt: payload.prompt, mode: payload.mode, - model: payload.model, + model: payload.model.modelID, variant: payload.variant, }; } @@ -130,13 +133,15 @@ export function CloudAgentProvider({ children, organizationId }: CloudAgentProvi trpcClient.cliSessionsV2.get.query({ session_id: id }), trpcClient.cliSessionsV2.getSessionMessages.query({ session_id: id }), ]); + // Zod .passthrough() adds index signatures that TS can't prove assignable to strict types. + // The tRPC/Zod layer has already validated the shape, so these casts are safe at this boundary. + const snapshotInfo = messagesResult.info as Partial; return { info: { - id: sessionData.session_id, - parentID: sessionData.parent_session_id ?? undefined, + id: snapshotInfo.id ?? sessionData.session_id, + parentID: snapshotInfo.parentID ?? sessionData.parent_session_id ?? undefined, + ...(snapshotInfo.model ? { model: snapshotInfo.model } : {}), }, - // Zod .passthrough() adds index signatures that TS can't prove assignable to strict types. - // The tRPC/Zod layer has already validated the shape, so this cast is safe at this boundary. messages: messagesResult.messages as SessionSnapshot['messages'], }; }, @@ -149,13 +154,11 @@ export function CloudAgentProvider({ children, organizationId }: CloudAgentProvi api: { send: async input => { - const normalizedPayload = normalizeTransportPayload(input.payload); - if (organizationId) { return trpcClient.organizations.cloudAgentNext.sendMessage.mutate( { cloudAgentSessionId: input.sessionId, - payload: normalizedPayload, + payload: input.payload, autoCommit: true, organizationId, messageId: input.messageId, @@ -167,7 +170,7 @@ export function CloudAgentProvider({ children, organizationId }: CloudAgentProvi return trpcClient.cloudAgentNext.sendMessage.mutate( { cloudAgentSessionId: input.sessionId, - payload: normalizedPayload, + payload: input.payload, autoCommit: true, messageId: input.messageId, attachments: input.attachments ?? input.images, diff --git a/apps/web/src/components/cloud-agent-next/CloudChatPage.tsx b/apps/web/src/components/cloud-agent-next/CloudChatPage.tsx index 44021bcbd3..4c82dd6861 100644 --- a/apps/web/src/components/cloud-agent-next/CloudChatPage.tsx +++ b/apps/web/src/components/cloud-agent-next/CloudChatPage.tsx @@ -5,7 +5,7 @@ import { useAtomValue, useSetAtom } from 'jotai'; import { useSearchParams } from 'next/navigation'; import { useMutation, useQueryClient } from '@tanstack/react-query'; import { useTRPC } from '@/lib/trpc/utils'; -import { ArrowDown, GitBranch } from 'lucide-react'; +import { ArrowDown, GitBranch, Info, RefreshCw } from 'lucide-react'; import { v4 as uuidv4 } from 'uuid'; import type { KiloSessionId } from '@/lib/cloud-agent-sdk'; @@ -37,7 +37,12 @@ import { terminalTabId, } from './terminal-tabs'; import { isMessageStreaming } from './types'; -import { useOrganizationModels } from './hooks/useOrganizationModels'; +import { + createRemoteModelOverride, + useSessionModels, + validateRemoteModelOverride, + type SessionModelNotice, +} from './hooks/useSessionModels'; import { ContextUsageIndicator } from './ContextUsageIndicator'; import { resolveContextWindow } from './model-context-lengths'; import { useSlashCommandSets } from '@/hooks/useSlashCommandSets'; @@ -142,6 +147,41 @@ type CloudChatPageProps = { organizationId?: string }; type TerminalStatusSummary = { status: TerminalStatus; statusText: string }; +function SessionModelNotices({ + notices, + onRetry, +}: { + notices: SessionModelNotice[]; + onRetry: () => void; +}) { + if (notices.length === 0) return null; + + return ( +
+ {notices.map(notice => ( +
+ + {notice.message} + {notice.retry && ( + + )} +
+ ))} +
+ ); +} + function TerminalPaneSlot({ terminalId, active, @@ -191,7 +231,7 @@ export default function CloudChatPage({ organizationId }: CloudChatPageProps) { const childSessionDrawerFocusTargetRef = useRef(null); // URL-driven session switching - const sessionIdFromParams = searchParams?.get('sessionId'); + const sessionIdFromParams = searchParams?.get('sessionId') ?? null; useEffect(() => { if (sessionIdFromParams) { childSessionDrawerFocusTargetRef.current = null; @@ -222,6 +262,10 @@ export default function CloudChatPage({ organizationId }: CloudChatPageProps) { const contextUsage = useAtomValue(manager.atoms.contextUsage); const getChildMessages = useAtomValue(manager.atoms.childMessages); const fetchedSessionData = useAtomValue(manager.atoms.fetchedSessionData); + const activeSessionType = useAtomValue(manager.atoms.activeSessionType); + const remoteModelState = useAtomValue(manager.atoms.remoteModelState); + const observedModel = useAtomValue(manager.atoms.observedModel); + const remoteModelOverride = useAtomValue(manager.atoms.remoteModelOverride); const setSessionConfig = useSetAtom(manager.atoms.sessionConfig); @@ -237,10 +281,38 @@ export default function CloudChatPage({ organizationId }: CloudChatPageProps) { setTerminalStatuses({}); }, [sessionId]); - // -- Organization models -------------------------------------------------- - const { modelOptions, isLoadingModels, contextLengthByModelId } = - useOrganizationModels(organizationId); - const contextWindow = resolveContextWindow(contextUsage, contextLengthByModelId); + // -- Session models ------------------------------------------------------- + const sessionModels = useSessionModels({ + activeSessionType, + remoteModelState, + observedModel, + remoteModelOverride, + gatewayModelId: sessionConfig?.model, + gatewayVariant: sessionConfig?.variant, + fetchedSessionData, + routeOrganizationId: organizationId, + sessionIdFromParams, + }); + const { modelOptions, isLoadingModels } = sessionModels; + + useEffect(() => { + if (sessionModels.source !== 'remote-legacy-gateway' || isLoadingModels) return; + + const validatedOverride = validateRemoteModelOverride( + remoteModelOverride, + modelOptions, + 'legacy-gateway' + ); + if (validatedOverride !== remoteModelOverride) { + manager.setRemoteModelOverride(validatedOverride); + } + }, [isLoadingModels, manager, modelOptions, remoteModelOverride, sessionModels.source]); + + const contextWindow = resolveContextWindow( + contextUsage, + sessionModels.gatewayContextLengthByModelId, + sessionModels.remoteContextLengthByProviderAndModel + ); const { availableCommands } = useSlashCommandSets(); // -- Sound effects -------------------------------------------------------- @@ -445,6 +517,10 @@ export default function CloudChatPage({ organizationId }: CloudChatPageProps) { void manager.interrupt(); }, [manager]); + const handleRetryRemoteModels = useCallback(() => { + manager.retryRemoteModels(); + }, [manager]); + const handleToggleSound = useCallback(() => { setSoundEnabled(prev => !prev); }, [setSoundEnabled]); @@ -565,11 +641,13 @@ export default function CloudChatPage({ organizationId }: CloudChatPageProps) { const agentVariantOverride = agentModelOverride ? selectedRuntimeAgent?.variant?.trim() || undefined : undefined; - const displayModel = agentModelOverride ?? sessionConfig?.model; - const modelPickerLocked = !!agentModelOverride; + const modelPickerLocked = activeSessionType === 'cloud-agent' && !!agentModelOverride; + const displayModel = modelPickerLocked ? agentModelOverride : sessionModels.selectedValue; const lockTooltip = modelPickerLocked ? `Locked by agent "${selectedRuntimeAgent?.name}"` - : undefined; + : sessionModels.modelPickerDisabled + ? 'Model changes are unavailable until this CLI model catalog is loaded.' + : undefined; const handleModeChange = useCallback( (mode: AgentMode) => { @@ -580,23 +658,50 @@ export default function CloudChatPage({ organizationId }: CloudChatPageProps) { const handleModelChange = useCallback( (model: string) => { + if (activeSessionType === 'remote') { + const option = modelOptions.find(candidate => candidate.id === model); + manager.setRemoteModelOverride( + createRemoteModelOverride(option, sessionModels.selectedVariant) + ); + return; + } if (!sessionConfig) return; - // Reset variant to first available (typically "none") when switching models if current is invalid - const newModelVariants = modelOptions.find(m => m.id === model)?.variants ?? []; + + const newModelVariants = + modelOptions.find(candidate => candidate.id === model)?.variants ?? []; const validVariant = sessionConfig.variant && newModelVariants.includes(sessionConfig.variant) ? sessionConfig.variant : newModelVariants[0]; setSessionConfig({ ...sessionConfig, model, variant: validVariant }); }, - [sessionConfig, setSessionConfig, modelOptions] + [ + activeSessionType, + manager, + modelOptions, + sessionConfig, + sessionModels.selectedVariant, + setSessionConfig, + ] ); const handleVariantChange = useCallback( (variant: string) => { + if (activeSessionType === 'remote') { + const option = modelOptions.find(candidate => candidate.id === sessionModels.selectedValue); + manager.setRemoteModelOverride(createRemoteModelOverride(option, variant)); + return; + } if (sessionConfig) setSessionConfig({ ...sessionConfig, variant }); }, - [sessionConfig, setSessionConfig] + [ + activeSessionType, + manager, + modelOptions, + sessionConfig, + sessionModels.selectedValue, + setSessionConfig, + ] ); // -- Delayed loading indicator (avoid flash for fast switches) ------------ @@ -612,18 +717,23 @@ export default function CloudChatPage({ organizationId }: CloudChatPageProps) { // -- Derived state -------------------------------------------------------- const showChatInterface = Boolean(sessionConfig) || Boolean(sessionIdFromParams); - const currentModelOption = modelOptions.find(m => m.id === sessionConfig?.model); + const currentModelOption = modelOptions.find(model => + activeSessionType === 'remote' + ? model.id === sessionModels.selectedValue + : model.id === sessionConfig?.model + ); const modelDisplayName = currentModelOption?.name ? formatShortModelDisplayName(currentModelOption.name) : undefined; - const availableVariants = currentModelOption?.variants ?? []; + const availableVariants = + activeSessionType === 'remote' + ? sessionModels.availableVariants + : (currentModelOption?.variants ?? []); // When an agent locks the model, swap the user's session variant for the // agent's variant (which may be undefined — i.e. no thinking-effort chip). // The variant picker is hidden in that case; it only shows when the user is // free to pick their own model. - const displayVariant = modelPickerLocked - ? agentVariantOverride - : (sessionConfig?.variant ?? undefined); + const displayVariant = modelPickerLocked ? agentVariantOverride : sessionModels.selectedVariant; const displayAvailableVariants = modelPickerLocked ? [] : availableVariants; const placeholder = isLoading @@ -798,6 +908,10 @@ export default function CloudChatPage({ organizationId }: CloudChatPageProps) { )}
+ { + it('retains the Gateway model only for Cloud Agent feedback', () => { + expect(feedbackModelForSession('cloud-agent', 'anthropic/claude-sonnet-4')).toBe( + 'anthropic/claude-sonnet-4' + ); + expect(feedbackModelForSession('remote', 'private-provider/private-model')).toBeUndefined(); + expect(feedbackModelForSession('read-only', 'private-provider/private-model')).toBeUndefined(); + expect(feedbackModelForSession(null, 'private-provider/private-model')).toBeUndefined(); + }); +}); diff --git a/apps/web/src/components/cloud-agent-next/FeedbackDialog.tsx b/apps/web/src/components/cloud-agent-next/FeedbackDialog.tsx index 5045faf11c..91202f5e9d 100644 --- a/apps/web/src/components/cloud-agent-next/FeedbackDialog.tsx +++ b/apps/web/src/components/cloud-agent-next/FeedbackDialog.tsx @@ -16,6 +16,7 @@ import { DialogTrigger, } from '@/components/ui/dialog'; import { useManager } from './CloudAgentProvider'; +import type { ResolvedSession } from '@/lib/cloud-agent-sdk'; import type { StoredMessage } from './types'; import { isTextPart } from './types'; @@ -33,6 +34,7 @@ export function FeedbackDialog({ organizationId, kiloSessionId }: FeedbackDialog const messages = useAtomValue(manager.atoms.messagesList); const isStreaming = useAtomValue(manager.atoms.isStreaming); const currentSessionId = useAtomValue(manager.atoms.sessionId); + const activeSessionType = useAtomValue(manager.atoms.activeSessionType); const sessionConfig = useAtomValue(manager.atoms.sessionConfig); const trpc = useTRPC(); @@ -71,7 +73,7 @@ export function FeedbackDialog({ organizationId, kiloSessionId }: FeedbackDialog kilo_session_id: kiloSessionId ?? undefined, organization_id: organizationId ?? undefined, feedback_text: feedbackText.trim(), - model: sessionConfig?.model || undefined, + model: feedbackModelForSession(activeSessionType, sessionConfig?.model), repository: sessionConfig?.repository || undefined, is_streaming: isStreaming, message_count: messages.length, @@ -82,6 +84,7 @@ export function FeedbackDialog({ organizationId, kiloSessionId }: FeedbackDialog currentSessionId, kiloSessionId, organizationId, + activeSessionType, sessionConfig, isStreaming, messages, @@ -151,6 +154,13 @@ export function FeedbackDialog({ organizationId, kiloSessionId }: FeedbackDialog ); } +export function feedbackModelForSession( + sessionType: ResolvedSession['type'] | null, + model: string | null | undefined +): string | undefined { + return sessionType === 'cloud-agent' ? model || undefined : undefined; +} + function buildRecentMessages( messages: StoredMessage[] ): { role: string; text: string; ts: number }[] { diff --git a/apps/web/src/components/cloud-agent-next/hooks/useOrganizationModels.ts b/apps/web/src/components/cloud-agent-next/hooks/useOrganizationModels.ts index 944809dc0a..64f5255bef 100644 --- a/apps/web/src/components/cloud-agent-next/hooks/useOrganizationModels.ts +++ b/apps/web/src/components/cloud-agent-next/hooks/useOrganizationModels.ts @@ -28,11 +28,17 @@ type UseOrganizationModelsReturn = { * If organizationId is provided, the models API applies org access policy. * * @param organizationId - Optional organization ID to filter models for + * @param enabled - Whether the Gateway model catalog should be fetched */ -export function useOrganizationModels(organizationId?: string): UseOrganizationModelsReturn { +export function useOrganizationModels( + organizationId?: string, + enabled = true +): UseOrganizationModelsReturn { // Fetch models for the model selector - const { data: openRouterModels, isLoading: isLoadingOpenRouter } = - useModelSelectorList(organizationId); + const { data: openRouterModels, isLoading: isLoadingOpenRouter } = useModelSelectorList( + organizationId, + enabled + ); const { data: defaultsData } = useOrganizationDefaults(organizationId); @@ -57,7 +63,7 @@ export function useOrganizationModels(organizationId?: string): UseOrganizationM return { modelOptions, - isLoadingModels: isLoadingOpenRouter, + isLoadingModels: enabled && isLoadingOpenRouter, contextLengthByModelId, defaultModel: defaultsData?.defaultModel, }; diff --git a/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.test.ts b/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.test.ts new file mode 100644 index 0000000000..be70abb985 --- /dev/null +++ b/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.test.ts @@ -0,0 +1,369 @@ +import { describe, expect, it } from '@jest/globals'; +import type { RemoteModelState } from '@/lib/cloud-agent-sdk'; +import type { ModelOption } from '@/components/shared/ModelCombobox'; +import { + buildSessionModels, + createRemoteModelOverride, + resolveGatewayOrganizationId, + validateRemoteModelOverride, + type SessionModelOption, +} from './useSessionModels'; + +const gatewayModels = [ + { + id: 'anthropic/claude-sonnet-4', + name: 'Claude Sonnet 4', + isFree: true, + hasUserByokAvailable: true, + variants: ['none', 'high'], + }, +] satisfies ModelOption[]; + +const emptyRemoteState = { + ownerConnectionId: null, + protocol: 'unknown', + refresh: 'idle', +} satisfies RemoteModelState; + +describe('resolveGatewayOrganizationId', () => { + it('uses the persisted session organization instead of the route organization', () => { + expect( + resolveGatewayOrganizationId({ organizationId: 'org-persisted' }, 'org-route', 'ses_existing') + ).toBe('org-persisted'); + }); + + it('keeps a persisted personal session personal on an organization route', () => { + expect( + resolveGatewayOrganizationId({ organizationId: null }, 'org-route', 'ses_existing') + ).toBeUndefined(); + }); + + it('does not use the route organization while an existing remote session loads', () => { + expect(resolveGatewayOrganizationId(null, 'org-route', 'ses_remote')).toBeUndefined(); + }); + + it('uses the route organization while creating a Cloud Agent', () => { + expect(resolveGatewayOrganizationId(null, 'org-route', null)).toBe('org-route'); + }); +}); + +describe('buildSessionModels', () => { + it('keeps Cloud Agent on the existing Gateway catalog and selection', () => { + const result = buildSessionModels({ + activeSessionType: 'cloud-agent', + remoteModelState: emptyRemoteState, + observedModel: null, + remoteModelOverride: null, + gatewayModels, + gatewayModelsLoading: false, + gatewayModelId: 'anthropic/claude-sonnet-4', + gatewayVariant: 'high', + gatewayOrganizationId: 'org-persisted', + }); + + expect(result.source).toBe('cloud-agent-gateway'); + expect(result.modelOptions).toBe(gatewayModels); + expect(result.selectedValue).toBe('anthropic/claude-sonnet-4'); + expect(result.selectedVariant).toBe('high'); + expect(result.availableVariants).toEqual(['none', 'high']); + expect(result.modelPickerDisabled).toBe(false); + expect(result.notices).toEqual([]); + expect(result.gatewayOrganizationId).toBe('org-persisted'); + }); + + it('projects a v1 CLI catalog into distinct opaque provider-aware options', () => { + const result = buildSessionModels({ + activeSessionType: 'remote', + remoteModelState: { + ownerConnectionId: 'cli-owner', + protocol: 'v1', + refresh: 'idle', + catalog: { + protocolVersion: 1, + truncated: false, + providers: [ + { + id: 'anthropic-local', + name: 'Anthropic Local', + models: [ + { + id: 'shared/model.id', + name: 'Claude Workspace', + variants: ['low', 'high'], + capabilities: { attachment: true, reasoning: true }, + limits: { context: 200_000, output: 8_192 }, + }, + ], + }, + { + id: 'custom-openai', + name: 'Custom OpenAI', + models: [ + { + id: 'shared/model.id', + name: 'Internal Deployment', + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4_096 }, + }, + ], + }, + ], + }, + }, + observedModel: { + model: { providerID: 'anthropic-local', modelID: 'shared/model.id' }, + variant: 'high', + }, + remoteModelOverride: null, + gatewayModels, + gatewayModelsLoading: false, + gatewayModelId: gatewayModels[0].id, + gatewayOrganizationId: 'org-persisted', + }); + + const cliOptions = result.modelOptions as SessionModelOption[]; + expect(result.source).toBe('remote-cli-catalog'); + expect(cliOptions).toHaveLength(2); + expect(cliOptions[0].id).not.toBe(cliOptions[1].id); + expect(cliOptions.map(option => option.displayId)).toEqual([ + 'shared/model.id', + 'shared/model.id', + ]); + expect(cliOptions.map(option => option.providerGroup?.label)).toEqual([ + 'Anthropic Local', + 'Custom OpenAI', + ]); + expect(cliOptions[0]).toMatchObject({ + modelRef: { providerID: 'anthropic-local', modelID: 'shared/model.id' }, + overrideSource: 'cli-catalog', + variants: ['low', 'high'], + supportsVision: true, + supportsReasoning: true, + showGatewayMetadata: false, + }); + expect(cliOptions[0].searchTerms).toEqual( + expect.arrayContaining([ + 'anthropic-local', + 'Anthropic Local', + 'shared/model.id', + 'Claude Workspace', + ]) + ); + expect(result.selectedValue).toBe(cliOptions[0].id); + expect(result.selectedVariant).toBe('high'); + expect(result.availableVariants).toEqual(['low', 'high']); + }); + + it('uses persisted-organization Gateway fallback only for an exact legacy CLI', () => { + const result = buildSessionModels({ + activeSessionType: 'remote', + remoteModelState: { + ownerConnectionId: 'legacy-owner', + protocol: 'legacy', + refresh: 'idle', + }, + observedModel: { + model: { providerID: 'local-provider', modelID: 'private-model' }, + variant: 'old-variant', + }, + remoteModelOverride: null, + gatewayModels, + gatewayModelsLoading: false, + gatewayModelId: gatewayModels[0].id, + gatewayOrganizationId: 'org-from-fetched-session', + }); + + const gatewayOption = result.modelOptions.find( + option => option.modelRef?.providerID === 'kilo' + ); + const unavailableObserved = result.modelOptions.find(option => option.unavailable); + + expect(result.source).toBe('remote-legacy-gateway'); + expect(result.gatewayOrganizationId).toBe('org-from-fetched-session'); + expect(result.notices.map(notice => notice.id)).toEqual(['legacy']); + expect(gatewayOption).toMatchObject({ + id: 'anthropic/claude-sonnet-4', + modelRef: { providerID: 'kilo', modelID: 'anthropic/claude-sonnet-4' }, + overrideSource: 'legacy-gateway', + isFree: true, + hasUserByokAvailable: true, + }); + expect(unavailableObserved).toMatchObject({ + displayId: 'private-model', + modelRef: { providerID: 'local-provider', modelID: 'private-model' }, + unavailable: true, + }); + expect(result.selectedValue).toBe(unavailableObserved?.id); + expect(createRemoteModelOverride(gatewayOption, 'stale-variant')).toEqual({ + source: 'legacy-gateway', + selection: { + model: { providerID: 'kilo', modelID: 'anthropic/claude-sonnet-4' }, + }, + }); + }); + + it('does not expose Gateway models when remote capability discovery fails', () => { + const result = buildSessionModels({ + activeSessionType: 'remote', + remoteModelState: { + ownerConnectionId: 'cli-owner', + protocol: 'unknown', + refresh: 'error', + error: 'catalog request failed', + }, + observedModel: null, + remoteModelOverride: null, + gatewayModels, + gatewayModelsLoading: false, + gatewayModelId: gatewayModels[0].id, + gatewayOrganizationId: 'org-persisted', + }); + + expect(result.source).toBe('remote-unavailable'); + expect(result.modelOptions).toEqual([ + expect.objectContaining({ + name: 'Session model', + unavailable: true, + }), + ]); + expect(result.modelOptions.some(option => option.id === gatewayModels[0].id)).toBe(false); + expect(result.selectedValue).toBe(result.modelOptions[0].id); + expect(result.modelPickerDisabled).toBe(true); + expect(result.notices).toEqual([expect.objectContaining({ id: 'error', retry: true })]); + }); + + it('keeps a stale v1 catalog with reset, truncation, and local-provider disclosure', () => { + const result = buildSessionModels({ + activeSessionType: 'remote', + remoteModelState: { + ownerConnectionId: 'cli-owner', + protocol: 'v1', + refresh: 'error', + error: 'refresh failed', + catalog: { + protocolVersion: 1, + truncated: true, + providers: [ + { + id: 'local-provider', + name: 'Local Provider', + models: [ + { + id: 'private-model', + variants: ['low', 'high'], + capabilities: { attachment: false, reasoning: true }, + limits: { context: 64_000, output: 4_096 }, + }, + ], + }, + ], + }, + }, + observedModel: null, + remoteModelOverride: { + source: 'cli-catalog', + selection: { + model: { providerID: 'local-provider', modelID: 'private-model' }, + variant: 'high', + }, + }, + gatewayModels, + gatewayModelsLoading: false, + gatewayOrganizationId: 'org-persisted', + }); + + const resetOption = result.modelOptions.find(option => option.action === 'use-session-model'); + const cliOption = result.modelOptions.find(option => option.overrideSource === 'cli-catalog'); + + expect(result.source).toBe('remote-cli-catalog'); + expect(result.notices.map(notice => notice.id)).toEqual([ + 'stale', + 'truncated', + 'local-provider', + ]); + expect(result.notices[2].message).toContain("organization's model restrictions"); + expect(resetOption).toMatchObject({ name: 'Use session model', action: 'use-session-model' }); + expect(createRemoteModelOverride(resetOption, 'high')).toBeNull(); + expect(createRemoteModelOverride(cliOption, 'removed-variant')).toEqual({ + source: 'cli-catalog', + selection: { + model: { providerID: 'local-provider', modelID: 'private-model' }, + }, + }); + expect(result.selectedValue).toBe(cliOption?.id); + expect(result.selectedVariant).toBe('high'); + }); +}); + +describe('validateRemoteModelOverride', () => { + const legacyModelOptions = gatewayModels.map(model => ({ + ...model, + modelRef: { providerID: 'kilo', modelID: model.id }, + overrideSource: 'legacy-gateway' as const, + })); + + it('clears a same-source override when its exact model is no longer available', () => { + expect( + validateRemoteModelOverride( + { + source: 'legacy-gateway', + selection: { + model: { providerID: 'kilo', modelID: 'removed/model' }, + variant: 'high', + }, + }, + legacyModelOptions, + 'legacy-gateway' + ) + ).toBeNull(); + }); + + it('keeps a valid model and drops a variant removed from the same source', () => { + expect( + validateRemoteModelOverride( + { + source: 'legacy-gateway', + selection: { + model: { providerID: 'kilo', modelID: 'anthropic/claude-sonnet-4' }, + variant: 'removed-variant', + }, + }, + legacyModelOptions, + 'legacy-gateway' + ) + ).toEqual({ + source: 'legacy-gateway', + selection: { + model: { providerID: 'kilo', modelID: 'anthropic/claude-sonnet-4' }, + }, + }); + }); + + it('preserves a valid override by reference', () => { + const override = { + source: 'legacy-gateway', + selection: { + model: { providerID: 'kilo', modelID: 'anthropic/claude-sonnet-4' }, + variant: 'high', + }, + } as const; + + expect(validateRemoteModelOverride(override, legacyModelOptions, 'legacy-gateway')).toBe( + override + ); + }); + + it('does not validate an override from a different source', () => { + const override = { + source: 'cli-catalog', + selection: { + model: { providerID: 'local-provider', modelID: 'private-model' }, + }, + } as const; + + expect(validateRemoteModelOverride(override, legacyModelOptions, 'legacy-gateway')).toBe( + override + ); + }); +}); diff --git a/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.ts b/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.ts new file mode 100644 index 0000000000..a943aa3309 --- /dev/null +++ b/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.ts @@ -0,0 +1,428 @@ +import { useMemo } from 'react'; +import { + createModelRefKeyMap, + modelRefsEqual, + type FetchedSessionData, + type ModelRef, + type ModelSelection, + type RemoteModelOverride, + type RemoteModelState, + type ResolvedSession, +} from '@/lib/cloud-agent-sdk'; +import type { ModelOption } from '@/components/shared/ModelCombobox'; +import { + buildContextLengthByProviderAndModel, + type ContextLengthByProviderAndModel, +} from '../model-context-lengths'; +import { useOrganizationModels } from './useOrganizationModels'; + +type ActiveSessionType = ResolvedSession['type']; + +type SessionModelSource = + | 'cloud-agent-gateway' + | 'remote-cli-catalog' + | 'remote-legacy-gateway' + | 'remote-unavailable'; + +type SessionModelNotice = { + id: 'loading' | 'legacy' | 'error' | 'stale' | 'truncated' | 'local-provider'; + message: string; + retry: boolean; +}; + +type SessionModelOption = ModelOption & { + displayId?: string; + providerGroup?: { id: string; label: string }; + searchTerms?: string[]; + supportsReasoning?: boolean; + showGatewayMetadata?: boolean; + modelRef?: ModelRef; + overrideSource?: RemoteModelOverride['source']; + unavailable?: boolean; + action?: 'use-session-model'; +}; + +type BuildSessionModelsInput = { + activeSessionType: ActiveSessionType | null; + remoteModelState: RemoteModelState; + observedModel: ModelSelection | null; + remoteModelOverride: RemoteModelOverride | null; + gatewayModels: ModelOption[]; + gatewayModelsLoading: boolean; + gatewayModelId?: string; + gatewayVariant?: string | null; + gatewayOrganizationId?: string; +}; + +type UseSessionModelsInput = Omit< + BuildSessionModelsInput, + 'gatewayModels' | 'gatewayModelsLoading' | 'gatewayOrganizationId' +> & { + fetchedSessionData: Pick | null; + routeOrganizationId?: string; + sessionIdFromParams: string | null; +}; + +type SessionModels = { + source: SessionModelSource; + modelOptions: SessionModelOption[]; + selectedValue?: string; + selectedVariant?: string; + availableVariants: string[]; + modelPickerDisabled: boolean; + isLoadingModels: boolean; + notices: SessionModelNotice[]; + gatewayOrganizationId?: string; +}; + +type UseSessionModelsResult = SessionModels & { + gatewayContextLengthByModelId: ReadonlyMap; + remoteContextLengthByProviderAndModel?: ContextLengthByProviderAndModel; +}; + +export function resolveGatewayOrganizationId( + fetchedSessionData: Pick | null, + routeOrganizationId: string | undefined, + sessionIdFromParams: string | null +): string | undefined { + if (fetchedSessionData) return fetchedSessionData.organizationId ?? undefined; + return sessionIdFromParams ? undefined : routeOrganizationId; +} + +export function useSessionModels(input: UseSessionModelsInput): UseSessionModelsResult { + const { + activeSessionType, + remoteModelState, + observedModel, + remoteModelOverride, + gatewayModelId, + gatewayVariant, + fetchedSessionData, + routeOrganizationId, + sessionIdFromParams, + } = input; + const gatewayOrganizationId = resolveGatewayOrganizationId( + fetchedSessionData, + routeOrganizationId, + sessionIdFromParams + ); + const usesGateway = activeSessionType !== 'remote' || remoteModelState.protocol === 'legacy'; + const gatewayModels = useOrganizationModels(gatewayOrganizationId, usesGateway); + const models = useMemo( + () => + buildSessionModels({ + activeSessionType, + remoteModelState, + observedModel, + remoteModelOverride, + gatewayModels: gatewayModels.modelOptions, + gatewayModelsLoading: gatewayModels.isLoadingModels, + gatewayModelId, + gatewayVariant, + gatewayOrganizationId, + }), + [ + activeSessionType, + gatewayModelId, + gatewayModels.isLoadingModels, + gatewayModels.modelOptions, + gatewayOrganizationId, + gatewayVariant, + observedModel, + remoteModelOverride, + remoteModelState, + ] + ); + const remoteContextLengthByProviderAndModel = useMemo( + () => + models.source === 'remote-cli-catalog' && remoteModelState.catalog + ? buildContextLengthByProviderAndModel(remoteModelState.catalog.providers) + : undefined, + [models.source, remoteModelState.catalog] + ); + + return { + ...models, + gatewayContextLengthByModelId: gatewayModels.contextLengthByModelId, + remoteContextLengthByProviderAndModel, + }; +} + +export function buildSessionModels(input: BuildSessionModelsInput): SessionModels { + if (input.activeSessionType === 'remote') { + if (input.remoteModelState.protocol === 'v1' && input.remoteModelState.catalog) { + return buildCliCatalogModels(input); + } + if (input.remoteModelState.protocol === 'legacy') { + return buildLegacyGatewayModels(input); + } + return buildUnavailableRemoteModels(input); + } + + const selectedOption = input.gatewayModels.find(model => model.id === input.gatewayModelId); + + return { + source: 'cloud-agent-gateway', + modelOptions: input.gatewayModels, + selectedValue: input.gatewayModelId, + selectedVariant: input.gatewayVariant ?? undefined, + availableVariants: selectedOption?.variants ?? [], + modelPickerDisabled: false, + isLoadingModels: input.gatewayModelsLoading, + notices: [], + gatewayOrganizationId: input.gatewayOrganizationId, + }; +} + +function buildUnavailableRemoteModels(input: BuildSessionModelsInput): SessionModels { + const currentSelection = currentRemoteSelection(input); + const unavailableOption = currentSelection + ? createUnavailableOption(currentSelection.model) + : ({ + id: 'remote-session-model-unavailable', + name: 'Session model', + showGatewayMetadata: false, + unavailable: true, + } satisfies SessionModelOption); + const loading = input.remoteModelState.refresh === 'loading'; + + return { + source: 'remote-unavailable', + modelOptions: [unavailableOption], + selectedValue: unavailableOption.id, + selectedVariant: currentSelection?.variant, + availableVariants: [], + modelPickerDisabled: true, + isLoadingModels: loading, + notices: [ + loading + ? { + id: 'loading', + message: 'Checking this CLI for available models.', + retry: true, + } + : { + id: 'error', + message: + "Models from this CLI couldn't be loaded. Sending still uses the session model.", + retry: true, + }, + ], + gatewayOrganizationId: input.gatewayOrganizationId, + }; +} + +function buildLegacyGatewayModels(input: BuildSessionModelsInput): SessionModels { + const modelOptions: SessionModelOption[] = input.gatewayModels.map(model => ({ + ...model, + modelRef: { providerID: 'kilo', modelID: model.id }, + overrideSource: 'legacy-gateway' as const, + })); + const currentSelection = currentRemoteSelection(input); + let selectedOption = currentSelection + ? modelOptions.find( + option => option.modelRef && modelRefsEqual(option.modelRef, currentSelection.model) + ) + : undefined; + + if (currentSelection && !selectedOption) { + const unavailableOption = createUnavailableOption(currentSelection.model); + modelOptions.unshift(unavailableOption); + selectedOption = unavailableOption; + } + if (input.remoteModelOverride) { + modelOptions.unshift(createUseSessionModelOption()); + } + + const selectedVariant = + currentSelection?.variant && selectedOption?.variants?.includes(currentSelection.variant) + ? currentSelection.variant + : undefined; + + return { + source: 'remote-legacy-gateway', + modelOptions, + selectedValue: selectedOption?.id, + selectedVariant, + availableVariants: selectedOption?.variants ?? [], + modelPickerDisabled: input.gatewayModelsLoading, + isLoadingModels: input.gatewayModelsLoading, + notices: [ + { + id: 'legacy', + message: + 'This CLI uses Gateway model fallback. Upgrade Kilo CLI to use its configured providers and models.', + retry: false, + }, + ], + gatewayOrganizationId: input.gatewayOrganizationId, + }; +} + +function buildCliCatalogModels(input: BuildSessionModelsInput): SessionModels { + const catalog = input.remoteModelState.catalog; + if (!catalog) throw new Error('CLI catalog is required for v1 model options'); + + const keyMap = createModelRefKeyMap(); + const modelOptions: SessionModelOption[] = catalog.providers.flatMap(provider => + provider.models.map(model => { + const modelRef = { providerID: provider.id, modelID: model.id }; + return { + id: keyMap.getOrCreateKey(modelRef), + name: model.name ?? model.id, + displayId: model.id, + providerGroup: { id: provider.id, label: provider.name ?? provider.id }, + searchTerms: [provider.id, provider.name, model.id, model.name].filter( + (term): term is string => term !== undefined + ), + supportsVision: model.capabilities.attachment, + supportsReasoning: model.capabilities.reasoning, + showGatewayMetadata: false, + variants: model.variants, + modelRef, + overrideSource: 'cli-catalog' as const, + } satisfies SessionModelOption; + }) + ); + const currentSelection = currentRemoteSelection(input); + let selectedOption = currentSelection + ? modelOptions.find( + option => option.modelRef && modelRefsEqual(option.modelRef, currentSelection.model) + ) + : undefined; + + if (currentSelection && !selectedOption) { + const unavailableOption = createUnavailableOption(currentSelection.model); + modelOptions.unshift(unavailableOption); + selectedOption = unavailableOption; + } + if (input.remoteModelOverride) { + modelOptions.unshift(createUseSessionModelOption()); + } + + const selectedVariant = + currentSelection?.variant && selectedOption?.variants?.includes(currentSelection.variant) + ? currentSelection.variant + : undefined; + const notices: SessionModelNotice[] = []; + if (input.remoteModelState.refresh === 'error') { + notices.push({ + id: 'stale', + message: 'Showing the last model catalog because refresh failed.', + retry: true, + }); + } + if (catalog.truncated) { + notices.push({ + id: 'truncated', + message: 'This CLI returned a partial model catalog. Some models or variants may be missing.', + retry: false, + }); + } + if (currentSelection && currentSelection.model.providerID !== 'kilo') { + notices.push({ + id: 'local-provider', + message: input.gatewayOrganizationId + ? "This model runs through your CLI provider, outside Kilo Gateway billing and this organization's model restrictions." + : 'This model runs through your CLI provider, outside Kilo Gateway billing.', + retry: false, + }); + } + + return { + source: 'remote-cli-catalog', + modelOptions, + selectedValue: selectedOption?.id, + selectedVariant, + availableVariants: selectedOption?.variants ?? [], + modelPickerDisabled: false, + isLoadingModels: false, + notices, + gatewayOrganizationId: input.gatewayOrganizationId, + }; +} + +function currentRemoteSelection(input: BuildSessionModelsInput): ModelSelection | null { + const defaultModel = input.remoteModelState.catalog?.defaultModel; + return ( + input.remoteModelOverride?.selection ?? + input.observedModel ?? + (defaultModel ? { model: defaultModel } : null) + ); +} + +function createUseSessionModelOption(): SessionModelOption { + return { + id: 'remote-use-session-model', + name: 'Use session model', + displayId: 'Stop overriding the model selected by the CLI', + providerGroup: { id: 'session', label: 'Session' }, + searchTerms: ['session', 'default', 'observed'], + showGatewayMetadata: false, + action: 'use-session-model', + }; +} + +function createUnavailableOption(modelRef: ModelRef): SessionModelOption { + const keyMap = createModelRefKeyMap(); + return { + id: `unavailable-${keyMap.getOrCreateKey(modelRef)}`, + name: modelRef.modelID, + displayId: modelRef.modelID, + providerGroup: { id: modelRef.providerID, label: modelRef.providerID }, + searchTerms: [modelRef.providerID, modelRef.modelID], + showGatewayMetadata: false, + modelRef, + unavailable: true, + }; +} + +export function validateRemoteModelOverride( + override: RemoteModelOverride | null, + modelOptions: readonly SessionModelOption[], + source: RemoteModelOverride['source'] +): RemoteModelOverride | null { + if (!override || override.source !== source) return override; + + const selectedOption = modelOptions.find( + option => + !option.unavailable && + option.overrideSource === source && + option.modelRef !== undefined && + modelRefsEqual(option.modelRef, override.selection.model) + ); + + if (!selectedOption) return null; + + const variant = override.selection.variant; + if (!variant || selectedOption.variants?.includes(variant)) return override; + + return { + source: override.source, + selection: { model: override.selection.model }, + }; +} + +export function createRemoteModelOverride( + option: SessionModelOption | undefined, + variant?: string +): RemoteModelOverride | null { + if (!option?.modelRef || !option.overrideSource || option.unavailable) return null; + const validVariant = variant && option.variants?.includes(variant) ? variant : undefined; + return { + source: option.overrideSource, + selection: { + model: option.modelRef, + ...(validVariant ? { variant: validVariant } : {}), + }, + }; +} + +export type { + BuildSessionModelsInput, + SessionModelNotice, + SessionModelOption, + SessionModels, + UseSessionModelsInput, + UseSessionModelsResult, +}; diff --git a/apps/web/src/components/cloud-agent-next/model-context-lengths.test.ts b/apps/web/src/components/cloud-agent-next/model-context-lengths.test.ts index e958d5f16d..31d1ad6b99 100644 --- a/apps/web/src/components/cloud-agent-next/model-context-lengths.test.ts +++ b/apps/web/src/components/cloud-agent-next/model-context-lengths.test.ts @@ -1,5 +1,9 @@ import type { ContextUsage } from '@/lib/cloud-agent-sdk/context-usage'; -import { buildContextLengthByModelId, resolveContextWindow } from './model-context-lengths'; +import { + buildContextLengthByModelId, + buildContextLengthByProviderAndModel, + resolveContextWindow, +} from './model-context-lengths'; const contextUsage = { contextTokens: 32_418, @@ -59,6 +63,28 @@ describe('buildContextLengthByModelId', () => { }); }); +describe('buildContextLengthByProviderAndModel', () => { + it('keeps duplicate model ids distinct across CLI providers', () => { + expect( + buildContextLengthByProviderAndModel([ + { + id: 'anthropic-local', + models: [{ id: 'shared/model', limits: { context: 200_000 } }], + }, + { + id: 'custom-openai', + models: [{ id: 'shared/model', limits: { context: 32_000 } }], + }, + ]) + ).toEqual( + new Map([ + ['anthropic-local', new Map([['shared/model', 200_000]])], + ['custom-openai', new Map([['shared/model', 32_000]])], + ]) + ); + }); +}); + describe('resolveContextWindow', () => { it('resolves a known kilo response by exact emitted model id', () => { expect(resolveContextWindow(contextUsage, new Map([[contextUsage.modelID, 200_000]]))).toBe( @@ -66,7 +92,35 @@ describe('resolveContextWindow', () => { ); }); - it('returns undefined for missing usage or a non-kilo provider', () => { + it('resolves remote usage by exact provider and model identity', () => { + const remoteLengths = buildContextLengthByProviderAndModel([ + { + id: 'anthropic-local', + models: [{ id: 'shared/model', limits: { context: 200_000 } }], + }, + { + id: 'custom-openai', + models: [{ id: 'shared/model', limits: { context: 32_000 } }], + }, + ]); + + expect( + resolveContextWindow( + { ...contextUsage, providerID: 'anthropic-local', modelID: 'shared/model' }, + new Map(), + remoteLengths + ) + ).toBe(200_000); + expect( + resolveContextWindow( + { ...contextUsage, providerID: 'custom-openai', modelID: 'shared/model' }, + new Map(), + remoteLengths + ) + ).toBe(32_000); + }); + + it('returns undefined for missing usage or a non-kilo provider without a CLI catalog', () => { expect( resolveContextWindow(undefined, new Map([[contextUsage.modelID, 200_000]])) ).toBeUndefined(); diff --git a/apps/web/src/components/cloud-agent-next/model-context-lengths.ts b/apps/web/src/components/cloud-agent-next/model-context-lengths.ts index b156e08857..4456c804ec 100644 --- a/apps/web/src/components/cloud-agent-next/model-context-lengths.ts +++ b/apps/web/src/components/cloud-agent-next/model-context-lengths.ts @@ -5,6 +5,37 @@ type ModelContextLength = { context_length?: number | null; }; +type ProviderModelContextLength = { + id: string; + models: readonly { + id: string; + limits: { context: number }; + }[]; +}; + +export type ContextLengthByProviderAndModel = ReadonlyMap>; + +// First positive value wins; a later conflicting value blacklists the id so a +// model with inconsistent context lengths is treated as unknown rather than +// resolving to an arbitrary one. +function recordUniqueContextLength( + lengths: Map, + conflicts: Set, + id: string, + contextLength: number +): void { + if (!Number.isFinite(contextLength) || contextLength <= 0) return; + if (conflicts.has(id)) return; + + const existingContextLength = lengths.get(id); + if (existingContextLength === undefined) { + lengths.set(id, contextLength); + } else if (existingContextLength !== contextLength) { + lengths.delete(id); + conflicts.add(id); + } +} + export function buildContextLengthByModelId( models: readonly ModelContextLength[] ): ReadonlyMap { @@ -14,31 +45,50 @@ export function buildContextLengthByModelId( for (const model of models) { const contextLength = model.context_length; if (contextLength === undefined || contextLength === null) continue; - if (!Number.isFinite(contextLength) || contextLength <= 0) continue; - if (conflictingModelIds.has(model.id)) continue; + recordUniqueContextLength(contextLengthByModelId, conflictingModelIds, model.id, contextLength); + } + + return contextLengthByModelId; +} - const existingContextLength = contextLengthByModelId.get(model.id); - if (existingContextLength === undefined) { - contextLengthByModelId.set(model.id, contextLength); - continue; +export function buildContextLengthByProviderAndModel( + providers: readonly ProviderModelContextLength[] +): ContextLengthByProviderAndModel { + const lengths = new Map>(); + const conflicts = new Map>(); + + for (const provider of providers) { + let providerLengths = lengths.get(provider.id); + if (!providerLengths) { + providerLengths = new Map(); + lengths.set(provider.id, providerLengths); + } + let providerConflicts = conflicts.get(provider.id); + if (!providerConflicts) { + providerConflicts = new Set(); + conflicts.set(provider.id, providerConflicts); } - if (existingContextLength !== contextLength) { - contextLengthByModelId.delete(model.id); - conflictingModelIds.add(model.id); + for (const model of provider.models) { + recordUniqueContextLength(providerLengths, providerConflicts, model.id, model.limits.context); } } - return contextLengthByModelId; + return lengths; } export function resolveContextWindow( contextUsage: ContextUsage | undefined, - contextLengthByModelId: ReadonlyMap + contextLengthByModelId: ReadonlyMap, + contextLengthByProviderAndModel?: ContextLengthByProviderAndModel ): number | undefined { - if (contextUsage?.providerID !== 'kilo') return undefined; + if (!contextUsage) return undefined; - const contextWindow = contextLengthByModelId.get(contextUsage.modelID); + const contextWindow = contextLengthByProviderAndModel + ? contextLengthByProviderAndModel.get(contextUsage.providerID)?.get(contextUsage.modelID) + : contextUsage.providerID === 'kilo' + ? contextLengthByModelId.get(contextUsage.modelID) + : undefined; if (contextWindow === undefined || !Number.isFinite(contextWindow) || contextWindow <= 0) { return undefined; } diff --git a/apps/web/src/components/shared/ModelCombobox.tsx b/apps/web/src/components/shared/ModelCombobox.tsx index dd35bcf614..4245d17746 100644 --- a/apps/web/src/components/shared/ModelCombobox.tsx +++ b/apps/web/src/components/shared/ModelCombobox.tsx @@ -13,9 +13,13 @@ import { CommandItem, CommandList, } from '@/components/ui/command'; -import { BookOpenCheck, ChevronsUpDown, Check, Image } from 'lucide-react'; +import { BookOpenCheck, Brain, ChevronsUpDown, Check, Image } from 'lucide-react'; import { cn } from '@/lib/utils'; -import { preferredModels } from '@/lib/ai-gateway/models'; +import { + buildModelOptionGroups, + getModelOptionKeywords, + type ModelOptionGroup, +} from './model-combobox-options'; import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip'; import { formatShortModelDisplayName } from '@/lib/format-model-name'; import { @@ -31,10 +35,19 @@ import { export type ModelOption = { id: string; // e.g., "anthropic/claude-sonnet-4.5" name: string; // e.g., "Claude Sonnet 4.5" + /** Exact user-facing ID when `id` is an opaque selection value. */ + displayId?: string; + /** Optional provider group for provider-aware catalogs. */ + providerGroup?: { id: string; label: string }; + /** Additional user-facing search terms. Opaque selection values stay excluded. */ + searchTerms?: string[]; supportsVision?: boolean; + supportsReasoning?: boolean; isFree?: boolean; mayTrainOnYourPrompts?: boolean; hasUserByokAvailable?: boolean; + showGatewayMetadata?: boolean; + unavailable?: boolean; /** Ordered list of variant key names (e.g., ["none","low","medium","high","max"]) */ variants?: string[]; }; @@ -95,35 +108,13 @@ export function ModelCombobox({ listRef.current?.scrollTo({ top: 0 }); }, []); - // Sort models: preferred models first (in preferredModels order), then others alphabetically - // This must be called before any early returns to follow Rules of Hooks - const sortedModels = useMemo(() => { - const preferred: ModelOption[] = []; - const others: ModelOption[] = []; - - models.forEach(model => { - if (preferredModels.includes(model.id)) { - preferred.push(model); - } else { - others.push(model); - } - }); - - // Sort preferred by their index in preferredModels array - preferred.sort((a, b) => { - return preferredModels.indexOf(a.id) - preferredModels.indexOf(b.id); - }); - - // Sort others alphabetically by name - others.sort((a, b) => a.name.localeCompare(b.name)); - - return { preferred, others }; - }, [models]); + const modelGroups = useMemo(() => buildModelOptionGroups(models), [models]); const selectedModel = models.find(model => model.id === value); const isCompact = variant === 'compact'; const showLabel = !isCompact && label; - const selectedCollectsData = mayTrainOnYourPrompts(selectedModel); + const selectedCollectsData = + selectedModel?.showGatewayMetadata !== false && mayTrainOnYourPrompts(selectedModel); if (isLoading) { if (isCompact) { @@ -236,82 +227,14 @@ export function ModelCombobox({ {noResultsText} - {sortedModels.preferred.length > 0 && ( - - {sortedModels.preferred.map(model => ( - { - onValueChange(model.id); - setOpen(false); - }} - className="flex items-center gap-2" - > -
-
- {model.name} - {model.supportsVision === true && ( - - - - - Supports vision - - )} - -
- {model.id} -
- -
- ))} -
- )} - {sortedModels.others.length > 0 && ( - - {sortedModels.others.map(model => ( - { - onValueChange(model.id); - setOpen(false); - }} - className="flex items-center gap-2" - > -
-
- {model.name} - {model.supportsVision === true && ( - - - - - Supports vision - - )} - -
- {model.id} -
- -
- ))} -
- )} + { + onValueChange(modelId); + setOpen(false); + }} + />
@@ -360,82 +283,14 @@ export function ModelCombobox({ {noResultsText} - {sortedModels.preferred.length > 0 && ( - - {sortedModels.preferred.map(model => ( - { - onValueChange(model.id); - setOpen(false); - }} - className="flex items-center gap-2" - > -
-
- {model.name} - {model.supportsVision === true && ( - - - - - Supports vision - - )} - -
- {model.id} -
- -
- ))} -
- )} - {sortedModels.others.length > 0 && ( - - {sortedModels.others.map(model => ( - { - onValueChange(model.id); - setOpen(false); - }} - className="flex items-center gap-2" - > -
-
- {model.name} - {model.supportsVision === true && ( - - - - - Supports vision - - )} - -
- {model.id} -
- -
- ))} -
- )} + { + onValueChange(modelId); + setOpen(false); + }} + />
@@ -445,6 +300,71 @@ export function ModelCombobox({ ); } +function ModelOptionGroups({ + groups, + value, + onSelect, +}: { + groups: ModelOptionGroup[]; + value?: string; + onSelect: (value: string) => void; +}) { + return groups.map(group => ( + + {group.models.map(model => { + const keywords = getModelOptionKeywords(model); + return ( + onSelect(model.id)} + className="flex items-center gap-2" + > +
+
+ {model.name} + {model.supportsVision === true && ( + + + + + Supports vision + + )} + {model.supportsReasoning === true && ( + + + + + Supports reasoning + + )} + {model.showGatewayMetadata !== false && } + {model.unavailable && ( + + Unavailable + + )} +
+ + {model.displayId ?? model.id} + +
+ +
+ ); + })} +
+ )); +} + function FreeModelDataIcon({ compact = false }: { compact?: boolean }) { return ( diff --git a/apps/web/src/components/shared/model-combobox-options.test.ts b/apps/web/src/components/shared/model-combobox-options.test.ts new file mode 100644 index 0000000000..168a2327e2 --- /dev/null +++ b/apps/web/src/components/shared/model-combobox-options.test.ts @@ -0,0 +1,52 @@ +import { describe, expect, it } from '@jest/globals'; +import { preferredModels } from '@/lib/ai-gateway/models'; +import type { ModelOption } from './ModelCombobox'; +import { buildModelOptionGroups, getModelOptionKeywords } from './model-combobox-options'; + +describe('model combobox options', () => { + it('groups CLI options by provider and searches provider/model names and exact ids', () => { + const options = [ + { + id: 'remote-model-0', + name: 'Workspace Claude', + displayId: 'shared/model.id', + providerGroup: { id: 'anthropic-local', label: 'Anthropic Local' }, + searchTerms: ['anthropic-local', 'Anthropic Local', 'shared/model.id'], + showGatewayMetadata: false, + }, + { + id: 'remote-model-1', + name: 'Internal Deployment', + displayId: 'shared/model.id', + providerGroup: { id: 'custom-openai', label: 'Custom OpenAI' }, + searchTerms: ['custom-openai', 'Custom OpenAI', 'shared/model.id'], + showGatewayMetadata: false, + }, + ] satisfies ModelOption[]; + + expect(buildModelOptionGroups(options)).toEqual([ + { id: 'provider:anthropic-local', heading: 'Anthropic Local', models: [options[0]] }, + { id: 'provider:custom-openai', heading: 'Custom OpenAI', models: [options[1]] }, + ]); + expect(getModelOptionKeywords(options[0])).toEqual( + expect.arrayContaining([ + 'Workspace Claude', + 'shared/model.id', + 'anthropic-local', + 'Anthropic Local', + ]) + ); + expect(getModelOptionKeywords(options[0])).not.toContain('remote-model-0'); + }); + + it('keeps existing Gateway options in Recommended and All Models groups', () => { + const preferred = { id: preferredModels[0], name: 'Preferred Gateway model' }; + const other = { id: 'provider/other-model', name: 'Other Gateway model' }; + + expect(buildModelOptionGroups([other, preferred])).toEqual([ + { id: 'recommended', heading: 'Recommended', models: [preferred] }, + { id: 'all-models', heading: 'All Models', models: [other] }, + ]); + expect(getModelOptionKeywords(other)).toContain('provider/other-model'); + }); +}); diff --git a/apps/web/src/components/shared/model-combobox-options.ts b/apps/web/src/components/shared/model-combobox-options.ts new file mode 100644 index 0000000000..9e9b2bbad5 --- /dev/null +++ b/apps/web/src/components/shared/model-combobox-options.ts @@ -0,0 +1,66 @@ +import { preferredModels } from '@/lib/ai-gateway/models'; +import type { ModelOption } from './ModelCombobox'; + +export type ModelOptionGroup = { + id: string; + heading: string; + models: ModelOption[]; +}; + +export function buildModelOptionGroups(models: ModelOption[]): ModelOptionGroup[] { + const groups: ModelOptionGroup[] = []; + const groupIndexes = new Map(); + const ungrouped: ModelOption[] = []; + + for (const model of models) { + if (!model.providerGroup) { + ungrouped.push(model); + continue; + } + + const groupId = `provider:${model.providerGroup.id}`; + const existingIndex = groupIndexes.get(groupId); + if (existingIndex !== undefined) { + groups[existingIndex].models.push(model); + continue; + } + + groupIndexes.set(groupId, groups.length); + groups.push({ id: groupId, heading: model.providerGroup.label, models: [model] }); + } + + if (ungrouped.length === 0) return groups; + + const preferred: ModelOption[] = []; + const others: ModelOption[] = []; + for (const model of ungrouped) { + if (preferredModels.includes(model.id)) preferred.push(model); + else others.push(model); + } + preferred.sort( + (left, right) => preferredModels.indexOf(left.id) - preferredModels.indexOf(right.id) + ); + others.sort((left, right) => left.name.localeCompare(right.name)); + + if (preferred.length > 0) { + groups.push({ id: 'recommended', heading: 'Recommended', models: preferred }); + } + if (others.length > 0) { + groups.push({ id: 'all-models', heading: 'All Models', models: others }); + } + return groups; +} + +export function getModelOptionKeywords(model: ModelOption): string[] { + return Array.from( + new Set( + [ + model.name, + model.displayId ?? (model.providerGroup ? undefined : model.id), + model.providerGroup?.id, + model.providerGroup?.label, + ...(model.searchTerms ?? []), + ].filter((term): term is string => Boolean(term)) + ) + ); +} diff --git a/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.test.ts b/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.test.ts index 0dee16ae91..324169efb5 100644 --- a/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.test.ts +++ b/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.test.ts @@ -1,10 +1,89 @@ import type { ChatEvent, ServiceEvent } from './normalizer'; import { createCliLiveTransport } from './cli-live-transport'; -import type { UserWebCliEvent, UserWebConnection, UserWebSystemEvent } from './user-web-connection'; +import type { + RemoteModelCatalogV1, + RemoteModelCatalogWireV1, + RemoteModelState, +} from './remote-model-catalog'; +import { + UserWebCommandError, + type UserWebCliEvent, + type UserWebConnection, + type UserWebSystemEvent, +} from './user-web-connection'; import type { KiloSessionId, SessionSnapshot } from './types'; import { kiloId, makeSnapshot, stubTextPart, stubUserMessage } from './test-helpers'; const KILO_SESSION_ID = kiloId('kilo-ses-1'); +const WIRE_CATALOG = { + all: [ + { + id: 'anthropic', + name: 'Anthropic', + source: 'env', + env: [], + options: {}, + models: { + 'claude-sonnet-4': { + id: 'claude-sonnet-4', + providerID: 'anthropic', + api: { id: 'claude-sonnet-4', url: '', npm: '' }, + name: 'Claude Sonnet 4', + capabilities: { + temperature: true, + reasoning: true, + attachment: true, + toolcall: true, + input: { text: true, audio: false, image: true, video: false, pdf: true }, + output: { text: true, audio: false, image: false, video: false, pdf: false }, + interleaved: false, + }, + cost: { input: 0, output: 0, cache: { read: 0, write: 0 } }, + limit: { context: 200_000, output: 64_000 }, + status: 'active', + options: {}, + headers: {}, + release_date: '', + variants: { high: {} }, + }, + }, + }, + ], + default: { anthropic: 'claude-sonnet-4' }, + connected: ['anthropic'], + failed: [], + protocolVersion: 1, + currentModel: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, + defaultModel: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + truncated: false, +} satisfies RemoteModelCatalogWireV1; +const REMOTE_CATALOG = { + protocolVersion: 1, + providers: [ + { + id: 'anthropic', + name: 'Anthropic', + models: [ + { + id: 'claude-sonnet-4', + name: 'Claude Sonnet 4', + variants: ['high'], + capabilities: { attachment: true, reasoning: true }, + limits: { context: 200_000, output: 64_000 }, + }, + ], + }, + ], + currentModel: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, + defaultModel: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + truncated: false, +} satisfies RemoteModelCatalogV1; type FakeUserWebConnection = UserWebConnection & { emitCli: (event: UserWebCliEvent) => void; @@ -49,6 +128,8 @@ function createTransportWithSinks(opts?: { connection?: FakeUserWebConnection; fetchSnapshot?: (kiloSessionId: KiloSessionId) => Promise; onError?: (message: string) => void; + onRemoteModelStateChange?: (state: RemoteModelState) => void; + onCapabilityChange?: () => void; }) { const userWebConnection = opts?.connection ?? createConnection(); const chatEvents: ChatEvent[] = []; @@ -58,6 +139,8 @@ function createTransportWithSinks(opts?: { userWebConnection, fetchSnapshot: opts?.fetchSnapshot, onError: opts?.onError, + onRemoteModelStateChange: opts?.onRemoteModelStateChange, + onCapabilityChange: opts?.onCapabilityChange, })({ onChatEvent: event => chatEvents.push(event), onServiceEvent: event => serviceEvents.push(event), @@ -65,6 +148,15 @@ function createTransportWithSinks(opts?: { return { userWebConnection, transport, chatEvents, serviceEvents }; } +function emitOwner(connection: FakeUserWebConnection, connectionId = 'owner'): void { + connection.emitSystem({ + event: 'sessions.list', + data: { + sessions: [{ id: KILO_SESSION_ID, status: 'active', title: 'Tracked', connectionId }], + }, + }); +} + function emitMessageUpdated(connection: FakeUserWebConnection, sessionId = KILO_SESSION_ID): void { connection.emitCli({ sessionId, @@ -76,6 +168,328 @@ function emitMessageUpdated(connection: FakeUserWebConnection, sessionId = KILO_ } describe('CliLiveTransport unified user web connection', () => { + it('discovers and publishes a v1 catalog for the current owner', async () => { + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockImplementation((_sessionId, command) => + Promise.resolve(command === 'list_models' ? WIRE_CATALOG : { ok: true }) + ); + const states: RemoteModelState[] = []; + const { transport } = createTransportWithSinks({ + connection, + onRemoteModelStateChange: state => states.push(state), + }); + + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + + expect(connection.sendCommand).toHaveBeenCalledWith( + KILO_SESSION_ID, + 'list_models', + { protocolVersion: 1 }, + 'owner' + ); + expect(states.at(-1)).toEqual({ + ownerConnectionId: 'owner', + protocol: 'v1', + catalog: REMOTE_CATALOG, + refresh: 'idle', + }); + expect(transport.canSend?.()).toBe(true); + transport.destroy(); + }); + + it('classifies only the exact unknown list_models command error as legacy', async () => { + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockRejectedValueOnce(new Error('unknown command: list_models')); + const states: RemoteModelState[] = []; + const { transport } = createTransportWithSinks({ + connection, + onRemoteModelStateChange: state => states.push(state), + }); + + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + + expect(states.at(-1)).toEqual({ + ownerConnectionId: 'owner', + protocol: 'legacy', + refresh: 'idle', + }); + + jest + .mocked(connection.sendCommand) + .mockRejectedValueOnce(new Error('prefix: unknown command: list_models')); + transport.retryRemoteModels?.(); + await Promise.resolve(); + await Promise.resolve(); + + expect(states.at(-1)).toEqual({ + ownerConnectionId: 'owner', + protocol: 'legacy', + refresh: 'error', + error: 'prefix: unknown command: list_models', + }); + transport.destroy(); + }); + + it('keeps protocol unknown and owner send capability after a malformed initial catalog', async () => { + const connection = createConnection(); + jest.mocked(connection.sendCommand).mockResolvedValueOnce({ + protocolVersion: 1, + providers: 'invalid', + truncated: false, + }); + const states: RemoteModelState[] = []; + const { transport } = createTransportWithSinks({ + connection, + onRemoteModelStateChange: state => states.push(state), + }); + + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + + expect(states.at(-1)).toEqual({ + ownerConnectionId: 'owner', + protocol: 'unknown', + refresh: 'error', + error: 'Invalid remote model catalog', + }); + expect(states).not.toContainEqual(expect.objectContaining({ protocol: 'legacy' })); + expect(transport.canSend?.()).toBe(true); + transport.destroy(); + }); + + it('keeps protocol unknown and owner send capability after a transient initial catalog error', async () => { + const connection = createConnection(); + jest.mocked(connection.sendCommand).mockRejectedValueOnce(new Error('catalog timed out')); + const states: RemoteModelState[] = []; + const { transport } = createTransportWithSinks({ + connection, + onRemoteModelStateChange: state => states.push(state), + }); + + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + + expect(states.at(-1)).toEqual({ + ownerConnectionId: 'owner', + protocol: 'unknown', + refresh: 'error', + error: 'catalog timed out', + }); + expect(states).not.toContainEqual(expect.objectContaining({ protocol: 'legacy' })); + expect(transport.canSend?.()).toBe(true); + transport.destroy(); + }); + + it('ignores a late catalog response from a replaced owner', async () => { + const connection = createConnection(); + let resolveFirstCatalog: ((catalog: RemoteModelCatalogWireV1) => void) | undefined; + const firstCatalog = new Promise(resolve => { + resolveFirstCatalog = resolve; + }); + const sourceProvider = WIRE_CATALOG.all[0]; + const sourceModel = sourceProvider.models['claude-sonnet-4']; + const replacementWireCatalog = { + ...WIRE_CATALOG, + all: [ + { + ...sourceProvider, + id: 'replacement-provider', + models: { + 'claude-sonnet-4': { ...sourceModel, providerID: 'replacement-provider' }, + }, + }, + ], + default: { 'replacement-provider': 'claude-sonnet-4' }, + connected: ['replacement-provider'], + } satisfies RemoteModelCatalogWireV1; + const replacementCatalog = { + ...REMOTE_CATALOG, + providers: [{ ...REMOTE_CATALOG.providers[0], id: 'replacement-provider' }], + } satisfies RemoteModelCatalogV1; + jest + .mocked(connection.sendCommand) + .mockReturnValueOnce(firstCatalog) + .mockResolvedValueOnce(replacementWireCatalog); + const states: RemoteModelState[] = []; + const { transport } = createTransportWithSinks({ + connection, + onRemoteModelStateChange: state => states.push(state), + }); + + transport.connect(); + emitOwner(connection, 'owner-a'); + emitOwner(connection, 'owner-b'); + await Promise.resolve(); + await Promise.resolve(); + + resolveFirstCatalog?.(WIRE_CATALOG); + await Promise.resolve(); + await Promise.resolve(); + + expect(states.at(-1)).toEqual({ + ownerConnectionId: 'owner-b', + protocol: 'v1', + catalog: replacementCatalog, + refresh: 'idle', + }); + expect(states).not.toContainEqual( + expect.objectContaining({ ownerConnectionId: 'owner-a', catalog: REMOTE_CATALOG }) + ); + transport.destroy(); + }); + + it('keeps one catalog request in flight for an owner', async () => { + const connection = createConnection(); + let resolveCatalog: ((catalog: RemoteModelCatalogWireV1) => void) | undefined; + jest.mocked(connection.sendCommand).mockReturnValue( + new Promise(resolve => { + resolveCatalog = resolve; + }) + ); + const { transport } = createTransportWithSinks({ connection }); + + transport.connect(); + emitOwner(connection); + transport.retryRemoteModels?.(); + connection.emitReconnect(); + + expect(connection.sendCommand).toHaveBeenCalledTimes(1); + + resolveCatalog?.(WIRE_CATALOG); + await Promise.resolve(); + await Promise.resolve(); + transport.destroy(); + }); + + it('retains a v1 catalog when a same-owner reconnect refresh fails', async () => { + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockResolvedValueOnce(WIRE_CATALOG) + .mockRejectedValueOnce(new Error('catalog refresh timed out')); + const states: RemoteModelState[] = []; + const { transport } = createTransportWithSinks({ + connection, + onRemoteModelStateChange: state => states.push(state), + }); + + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + await Promise.resolve(); + + connection.emitReconnect(); + await Promise.resolve(); + await Promise.resolve(); + + expect(states.at(-1)).toEqual({ + ownerConnectionId: 'owner', + protocol: 'v1', + catalog: REMOTE_CATALOG, + refresh: 'error', + error: 'catalog refresh timed out', + }); + expect(connection.sendCommand).toHaveBeenCalledTimes(2); + transport.destroy(); + }); + + it('clears owner-scoped catalog state and rediscovers after session reappearance', async () => { + const connection = createConnection(); + jest.mocked(connection.sendCommand).mockResolvedValue(WIRE_CATALOG); + const states: RemoteModelState[] = []; + const { transport } = createTransportWithSinks({ + connection, + onRemoteModelStateChange: state => states.push(state), + }); + + transport.connect(); + emitOwner(connection, 'owner-a'); + await Promise.resolve(); + await Promise.resolve(); + await Promise.resolve(); + + connection.emitSystem({ event: 'sessions.list', data: { sessions: [] } }); + expect(states.at(-1)).toEqual({ + ownerConnectionId: null, + protocol: 'unknown', + refresh: 'idle', + }); + + emitOwner(connection, 'owner-b'); + await Promise.resolve(); + await Promise.resolve(); + + expect(connection.sendCommand).toHaveBeenLastCalledWith( + KILO_SESSION_ID, + 'list_models', + { protocolVersion: 1 }, + 'owner-b' + ); + expect(connection.sendCommand).toHaveBeenCalledTimes(2); + transport.destroy(); + }); + + it('changes send readiness with owner presence and owner-fence failures', async () => { + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockImplementation((_sessionId, command) => + Promise.resolve(command === 'list_models' ? WIRE_CATALOG : { ok: true }) + ); + const capabilityChanges = jest.fn(); + const { transport } = createTransportWithSinks({ + connection, + onCapabilityChange: capabilityChanges, + }); + + transport.connect(); + expect(transport.canSend?.()).toBe(false); + + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + expect(transport.canSend?.()).toBe(true); + + jest.mocked(connection.sendCommand).mockRejectedValueOnce( + new UserWebCommandError({ + code: 'SESSION_OWNER_CHANGED', + message: 'Session owner changed', + }) + ); + await expect(transport.interrupt?.()).rejects.toMatchObject({ + code: 'SESSION_OWNER_CHANGED', + }); + expect(transport.canSend?.()).toBe(false); + + emitOwner(connection, 'replacement-owner'); + await Promise.resolve(); + await Promise.resolve(); + expect(transport.canSend?.()).toBe(true); + expect(capabilityChanges).toHaveBeenCalled(); + + connection.emitSystem({ + event: 'sessions.list', + data: { sessions: [] }, + }); + expect(transport.canSend?.()).toBe(false); + transport.destroy(); + }); + it('takes a session subscription lease without starting or destroying the injected connection', () => { const { userWebConnection, transport } = createTransportWithSinks(); @@ -190,6 +604,64 @@ describe('CliLiveTransport unified user web connection', () => { transport.destroy(); }); + it('applies a live session.updated after stale snapshot metadata', async () => { + let resolveSnapshot: ((snapshot: SessionSnapshot) => void) | undefined; + const fetchSnapshot = jest.fn( + () => + new Promise(resolve => { + resolveSnapshot = resolve; + }) + ); + const { userWebConnection, transport, serviceEvents } = createTransportWithSinks({ + fetchSnapshot, + }); + transport.connect(); + expect(fetchSnapshot).toHaveBeenCalledTimes(1); + + userWebConnection.emitCli({ + sessionId: KILO_SESSION_ID, + event: 'session.updated', + data: { + info: { + id: KILO_SESSION_ID, + model: { providerID: 'anthropic', id: 'live-model', variant: 'high' }, + }, + }, + }); + expect(serviceEvents).toHaveLength(0); + + resolveSnapshot?.({ + info: { + id: KILO_SESSION_ID, + model: { providerID: 'openai', id: 'stale-snapshot-model' }, + }, + messages: [], + }); + await Promise.resolve(); + await Promise.resolve(); + + const observedModels = serviceEvents.flatMap(event => { + if ( + (event.type === 'session.created' || event.type === 'session.updated') && + event.info.model + ) { + return [event.info.model]; + } + return []; + }); + expect(serviceEvents.map(event => event.type)).toEqual(['session.created', 'session.updated']); + expect(observedModels).toEqual([ + { providerID: 'openai', id: 'stale-snapshot-model' }, + { providerID: 'anthropic', id: 'live-model', variant: 'high' }, + ]); + expect(observedModels.at(-1)).toEqual({ + providerID: 'anthropic', + id: 'live-model', + variant: 'high', + }); + transport.destroy(); + }); + it('reports initial snapshot failure, drains buffered chat, and stays subscribed', async () => { const onError = jest.fn(); const { userWebConnection, transport, chatEvents } = createTransportWithSinks({ @@ -237,6 +709,336 @@ describe('CliLiveTransport unified user web connection', () => { transport.destroy(); }); + it('sends a v1 CLI-catalog override as a structured model with a catalog-valid variant', async () => { + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockImplementation((_sessionId, command) => + Promise.resolve(command === 'list_models' ? WIRE_CATALOG : { ok: true }) + ); + const { transport } = createTransportWithSinks({ connection }); + + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + await Promise.resolve(); + jest.mocked(connection.sendCommand).mockClear(); + + await transport.send?.({ + payload: { + type: 'prompt', + prompt: 'hello', + model: { providerID: 'kilo', modelID: 'stale-model' }, + variant: 'stale-variant', + }, + remoteModelOverride: { + source: 'cli-catalog', + selection: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, + }, + }); + + expect(connection.sendCommand).toHaveBeenCalledWith( + KILO_SESSION_ID, + 'send_message', + { + sessionID: KILO_SESSION_ID, + parts: [{ type: 'text', text: 'hello' }], + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, + 'owner' + ); + transport.destroy(); + }); + + it('rejects a CLI-catalog variant that is not advertised for the selected model', async () => { + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockImplementation((_sessionId, command) => + Promise.resolve(command === 'list_models' ? WIRE_CATALOG : { ok: true }) + ); + const { transport } = createTransportWithSinks({ connection }); + + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + await Promise.resolve(); + jest.mocked(connection.sendCommand).mockClear(); + + await expect( + transport.send?.({ + payload: { type: 'prompt', prompt: 'hello' }, + remoteModelOverride: { + source: 'cli-catalog', + selection: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'not-advertised', + }, + }, + }) + ).rejects.toThrow('Selected remote model variant is not available in the current CLI catalog'); + + expect(connection.sendCommand).not.toHaveBeenCalled(); + transport.destroy(); + }); + + it('rejects a CLI-catalog override whose model is absent from the current catalog', async () => { + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockImplementation((_sessionId, command) => + Promise.resolve(command === 'list_models' ? WIRE_CATALOG : { ok: true }) + ); + const { transport } = createTransportWithSinks({ connection }); + + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + jest.mocked(connection.sendCommand).mockClear(); + + await expect( + transport.send?.({ + payload: { type: 'prompt', prompt: 'hello' }, + remoteModelOverride: { + source: 'cli-catalog', + selection: { model: { providerID: 'anthropic', modelID: 'removed-model' } }, + }, + }) + ).rejects.toThrow('Selected remote model is not available in the current CLI catalog'); + + expect(connection.sendCommand).not.toHaveBeenCalled(); + transport.destroy(); + }); + + it('rejects an explicit override while catalog protocol is unknown', async () => { + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockImplementation((_sessionId, command) => + command === 'list_models' ? new Promise(() => {}) : Promise.resolve({ ok: true }) + ); + const { transport } = createTransportWithSinks({ connection }); + + transport.connect(); + emitOwner(connection); + jest.mocked(connection.sendCommand).mockClear(); + + await expect( + transport.send?.({ + payload: { type: 'prompt', prompt: 'hello' }, + remoteModelOverride: { + source: 'cli-catalog', + selection: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, + }, + }) + ).rejects.toThrow( + 'Selected remote model override is incompatible with the connected CLI model protocol' + ); + + expect(connection.sendCommand).not.toHaveBeenCalled(); + transport.destroy(); + }); + + it('rejects a legacy Gateway override while the connected CLI uses v1 catalogs', async () => { + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockImplementation((_sessionId, command) => + Promise.resolve(command === 'list_models' ? WIRE_CATALOG : { ok: true }) + ); + const { transport } = createTransportWithSinks({ connection }); + + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + jest.mocked(connection.sendCommand).mockClear(); + + await expect( + transport.send?.({ + payload: { type: 'prompt', prompt: 'hello' }, + remoteModelOverride: { + source: 'legacy-gateway', + selection: { model: { providerID: 'kilo', modelID: 'anthropic/claude-sonnet-4' } }, + }, + }) + ).rejects.toThrow( + 'Selected remote model override is incompatible with the connected CLI model protocol' + ); + + expect(connection.sendCommand).not.toHaveBeenCalled(); + transport.destroy(); + }); + + it('omits model and variant when no explicit override is provided', async () => { + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockImplementation((_sessionId, command) => + Promise.resolve(command === 'list_models' ? WIRE_CATALOG : { ok: true }) + ); + const { transport } = createTransportWithSinks({ connection }); + + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + await Promise.resolve(); + jest.mocked(connection.sendCommand).mockClear(); + + await transport.send?.({ + payload: { + type: 'prompt', + prompt: 'use session precedence', + model: { providerID: 'kilo', modelID: 'observed-only' }, + variant: 'stale', + }, + }); + + expect(connection.sendCommand).toHaveBeenCalledWith( + KILO_SESSION_ID, + 'send_message', + { + sessionID: KILO_SESSION_ID, + parts: [{ type: 'text', text: 'use session precedence' }], + }, + 'owner' + ); + transport.destroy(); + }); + + it.each([ + ['Kilo', { providerID: 'kilo', modelID: 'anthropic/claude-sonnet-4' }], + ['non-Kilo', { providerID: 'anthropic', modelID: 'claude-sonnet-4' }], + ])( + 'omits an observed %s model for a legacy CLI when no override exists', + async (_label, model) => { + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockImplementation((_sessionId, command) => + command === 'list_models' + ? Promise.reject(new Error('unknown command: list_models')) + : Promise.resolve({ ok: true }) + ); + const { transport } = createTransportWithSinks({ connection }); + + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + jest.mocked(connection.sendCommand).mockClear(); + + await transport.send?.({ + payload: { + type: 'prompt', + prompt: 'use the session model', + model, + variant: 'observed-variant', + }, + }); + + expect(connection.sendCommand).toHaveBeenCalledWith( + KILO_SESSION_ID, + 'send_message', + { + sessionID: KILO_SESSION_ID, + parts: [{ type: 'text', text: 'use the session model' }], + }, + 'owner' + ); + transport.destroy(); + } + ); + + it('sends an explicit legacy Gateway override as a Kilo model string', async () => { + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockImplementation((_sessionId, command) => + command === 'list_models' + ? Promise.reject(new Error('unknown command: list_models')) + : Promise.resolve({ ok: true }) + ); + const { transport } = createTransportWithSinks({ connection }); + + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + jest.mocked(connection.sendCommand).mockClear(); + + await transport.send?.({ + payload: { + type: 'prompt', + prompt: 'hello', + model: { providerID: 'kilo', modelID: 'observed-model' }, + }, + remoteModelOverride: { + source: 'legacy-gateway', + selection: { + model: { providerID: 'kilo', modelID: 'anthropic/claude-sonnet-4' }, + variant: 'high', + }, + }, + }); + + expect(connection.sendCommand).toHaveBeenCalledWith( + KILO_SESSION_ID, + 'send_message', + { + sessionID: KILO_SESSION_ID, + parts: [{ type: 'text', text: 'hello' }], + model: 'anthropic/claude-sonnet-4', + variant: 'high', + }, + 'owner' + ); + transport.destroy(); + }); + + it('rejects a legacy Gateway override that does not use the Kilo provider', async () => { + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockImplementation((_sessionId, command) => + command === 'list_models' + ? Promise.reject(new Error('unknown command: list_models')) + : Promise.resolve({ ok: true }) + ); + const { transport } = createTransportWithSinks({ connection }); + + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + jest.mocked(connection.sendCommand).mockClear(); + + await expect( + transport.send?.({ + payload: { type: 'prompt', prompt: 'hello' }, + remoteModelOverride: { + source: 'legacy-gateway', + selection: { model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' } }, + }, + }) + ).rejects.toThrow( + 'Selected remote model override is incompatible with the connected CLI model protocol' + ); + expect(connection.sendCommand).not.toHaveBeenCalled(); + transport.destroy(); + }); + it.each([ [ 'send', @@ -245,17 +1047,25 @@ describe('CliLiveTransport unified user web connection', () => { type: 'prompt' as const, prompt: 'hello', mode: 'code', - model: 'm', - variant: 'v', + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, + remoteModelOverride: { + source: 'cli-catalog' as const, + selection: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, }, }), + 'send_message', { sessionID: KILO_SESSION_ID, parts: [{ type: 'text', text: 'hello' }], agent: 'code', - model: 'm', - variant: 'v', + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', }, ], ['interrupt', () => undefined, 'interrupt', {}], @@ -282,14 +1092,31 @@ describe('CliLiveTransport unified user web connection', () => { ])( 'delegates %s commands through the injected connection', async (method, input, command, data) => { - const { userWebConnection, transport } = createTransportWithSinks(); + const connection = createConnection(); + jest + .mocked(connection.sendCommand) + .mockImplementation((_sessionId, commandName) => + Promise.resolve(commandName === 'list_models' ? WIRE_CATALOG : { ok: true }) + ); + const { userWebConnection, transport } = createTransportWithSinks({ connection }); const invoke = transport[method as keyof typeof transport] as ( value?: unknown ) => Promise; + transport.connect(); + emitOwner(connection); + await Promise.resolve(); + await Promise.resolve(); + jest.mocked(userWebConnection.sendCommand).mockClear(); + await invoke(input()); - expect(userWebConnection.sendCommand).toHaveBeenCalledWith(KILO_SESSION_ID, command, data); + expect(userWebConnection.sendCommand).toHaveBeenCalledWith( + KILO_SESSION_ID, + command, + data, + 'owner' + ); transport.destroy(); } ); diff --git a/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.ts b/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.ts index 1140ae4448..c5cbe1d5aa 100644 --- a/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.ts +++ b/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.ts @@ -3,16 +3,28 @@ * one remote CLI session into normalized transport events and commands. */ import { normalizeCliEvent, isChatEvent } from './normalizer'; -import { cliConnectionDataSchema, heartbeatDataSchema, sessionsListDataSchema } from './schemas'; -import type { TransportFactory, TransportSendPayload, TransportSink } from './transport'; +import { + cliConnectionDataSchema, + heartbeatDataSchema, + remoteModelCatalogV1Schema, + sessionsListDataSchema, +} from './schemas'; +import type { RemoteModelState } from './remote-model-catalog'; +import type { TransportFactory, TransportSendInput, TransportSink } from './transport'; import type { KiloSessionId, SessionSnapshot } from './types'; -import type { UserWebCliEvent, UserWebConnection } from './user-web-connection'; +import { + UserWebCommandError, + type UserWebCliEvent, + type UserWebConnection, +} from './user-web-connection'; type CliLiveTransportConfig = { kiloSessionId: KiloSessionId; userWebConnection: UserWebConnection; fetchSnapshot?: (kiloSessionId: KiloSessionId) => Promise; onError?: (message: string) => void; + onRemoteModelStateChange?: (state: RemoteModelState) => void; + onCapabilityChange?: () => void; }; function createCliLiveTransport(config: CliLiveTransportConfig): TransportFactory { @@ -21,6 +33,138 @@ function createCliLiveTransport(config: CliLiveTransportConfig): TransportFactor let cleanup: (() => void) | null = null; let sessionStopped = false; let ownerConnectionId: string | null = null; + let catalogRequestGeneration = 0; + let catalogRequestInFlight: { ownerConnectionId: string; generation: number } | null = null; + let remoteModelState: RemoteModelState = { + ownerConnectionId: null, + protocol: 'unknown', + refresh: 'idle', + }; + + function publishRemoteModelState(next: RemoteModelState): void { + remoteModelState = next; + config.onRemoteModelStateChange?.(next); + } + + function setOwnerConnectionId(nextOwnerConnectionId: string | null): void { + if (ownerConnectionId === nextOwnerConnectionId) return; + + ownerConnectionId = nextOwnerConnectionId; + catalogRequestGeneration += 1; + catalogRequestInFlight = null; + publishRemoteModelState({ + ownerConnectionId: nextOwnerConnectionId, + protocol: 'unknown', + refresh: nextOwnerConnectionId ? 'loading' : 'idle', + }); + config.onCapabilityChange?.(); + + if (nextOwnerConnectionId) discoverModels(nextOwnerConnectionId); + } + + function handleCatalogFailure( + error: unknown, + expectedOwnerConnectionId: string, + expectedGeneration: number, + expectedRequestGeneration: number + ): void { + if ( + expectedGeneration !== generation || + expectedRequestGeneration !== catalogRequestGeneration || + ownerConnectionId !== expectedOwnerConnectionId + ) { + return; + } + + if (error instanceof UserWebCommandError && error.code === 'SESSION_OWNER_CHANGED') { + setOwnerConnectionId(null); + return; + } + + if (error instanceof Error && error.message === 'unknown command: list_models') { + publishRemoteModelState({ + ownerConnectionId: expectedOwnerConnectionId, + protocol: 'legacy', + refresh: 'idle', + }); + return; + } + + publishRemoteModelState({ + ownerConnectionId: expectedOwnerConnectionId, + protocol: remoteModelState.protocol, + ...(remoteModelState.catalog ? { catalog: remoteModelState.catalog } : {}), + refresh: 'error', + error: error instanceof Error ? error.message : 'Failed to discover remote models', + }); + } + + function discoverModels(expectedOwnerConnectionId: string): void { + if (catalogRequestInFlight?.ownerConnectionId === expectedOwnerConnectionId) return; + + catalogRequestGeneration += 1; + const expectedRequestGeneration = catalogRequestGeneration; + const expectedGeneration = generation; + catalogRequestInFlight = { + ownerConnectionId: expectedOwnerConnectionId, + generation: expectedRequestGeneration, + }; + publishRemoteModelState({ + ownerConnectionId: expectedOwnerConnectionId, + protocol: remoteModelState.protocol, + ...(remoteModelState.catalog ? { catalog: remoteModelState.catalog } : {}), + refresh: 'loading', + }); + + void config.userWebConnection + .sendCommand( + config.kiloSessionId, + 'list_models', + { protocolVersion: 1 }, + expectedOwnerConnectionId + ) + .then( + result => { + if ( + expectedGeneration !== generation || + expectedRequestGeneration !== catalogRequestGeneration || + ownerConnectionId !== expectedOwnerConnectionId + ) { + return; + } + + const parsed = remoteModelCatalogV1Schema.safeParse(result); + if (!parsed.success) { + handleCatalogFailure( + new Error('Invalid remote model catalog'), + expectedOwnerConnectionId, + expectedGeneration, + expectedRequestGeneration + ); + return; + } + + publishRemoteModelState({ + ownerConnectionId: expectedOwnerConnectionId, + protocol: 'v1', + catalog: parsed.data, + refresh: 'idle', + }); + }, + error => + handleCatalogFailure( + error, + expectedOwnerConnectionId, + expectedGeneration, + expectedRequestGeneration + ) + ) + .finally(() => { + if (catalogRequestInFlight?.generation === expectedRequestGeneration) { + catalogRequestInFlight = null; + } + }); + } function replaySnapshot(snapshot: SessionSnapshot): void { sink.onServiceEvent({ type: 'session.created', info: snapshot.info }); @@ -62,6 +206,7 @@ function createCliLiveTransport(config: CliLiveTransportConfig): TransportFactor if (event === 'cli.disconnected') { const parsed = cliConnectionDataSchema.safeParse(data); if (parsed.success && ownerConnectionId === parsed.data.connectionId) { + setOwnerConnectionId(null); stopForDisconnectedSession(); } return; @@ -73,11 +218,12 @@ function createCliLiveTransport(config: CliLiveTransportConfig): TransportFactor const session = parsed.data.sessions.find(item => item.id === config.kiloSessionId); if (session) { - ownerConnectionId = session.connectionId; + setOwnerConnectionId(session.connectionId); sessionStopped = false; return; } + setOwnerConnectionId(null); stopForDisconnectedSession(); return; } @@ -88,19 +234,85 @@ function createCliLiveTransport(config: CliLiveTransportConfig): TransportFactor const session = parsed.data.sessions.find(item => item.id === config.kiloSessionId); if (session) { - ownerConnectionId = parsed.data.connectionId; + setOwnerConnectionId(parsed.data.connectionId); sessionStopped = false; return; } if (ownerConnectionId === parsed.data.connectionId) { + setOwnerConnectionId(null); stopForDisconnectedSession(); } } } - function sendCommand(command: string, data: unknown): Promise { - return config.userWebConnection.sendCommand(config.kiloSessionId, command, data); + async function sendCommand(command: string, data: unknown): Promise { + const expectedOwnerConnectionId = ownerConnectionId; + if (!expectedOwnerConnectionId) throw new Error('Remote session has no connected owner'); + + try { + return await config.userWebConnection.sendCommand( + config.kiloSessionId, + command, + data, + expectedOwnerConnectionId + ); + } catch (error) { + if (error instanceof UserWebCommandError && error.code === 'SESSION_OWNER_CHANGED') { + setOwnerConnectionId(null); + } + throw error; + } + } + + function getRemoteModelFields(input: TransportSendInput): + | { kind: 'none' } + | { + kind: 'structured'; + model: { providerID: string; modelID: string }; + variant?: string; + } + | { kind: 'legacy'; model: string; variant?: string } { + const override = input.remoteModelOverride; + if (!override) return { kind: 'none' }; + + if (remoteModelState.protocol === 'v1' && override.source === 'cli-catalog') { + const provider = remoteModelState.catalog?.providers.find( + item => item.id === override.selection.model.providerID + ); + const model = provider?.models.find(item => item.id === override.selection.model.modelID); + if (!model) { + throw new Error('Selected remote model is not available in the current CLI catalog'); + } + + const variant = override.selection.variant; + if (variant && !model.variants.includes(variant)) { + throw new Error( + 'Selected remote model variant is not available in the current CLI catalog' + ); + } + return { + kind: 'structured', + model: override.selection.model, + ...(variant ? { variant } : {}), + }; + } + + if ( + remoteModelState.protocol === 'legacy' && + override.source === 'legacy-gateway' && + override.selection.model.providerID === 'kilo' + ) { + return { + kind: 'legacy', + model: override.selection.model.modelID, + ...(override.selection.variant ? { variant: override.selection.variant } : {}), + }; + } + + throw new Error( + 'Selected remote model override is incompatible with the connected CLI model protocol' + ); } function releaseConnection(): void { @@ -115,6 +327,14 @@ function createCliLiveTransport(config: CliLiveTransportConfig): TransportFactor releaseConnection(); sessionStopped = false; ownerConnectionId = null; + catalogRequestGeneration += 1; + catalogRequestInFlight = null; + publishRemoteModelState({ + ownerConnectionId: null, + protocol: 'unknown', + refresh: 'idle', + }); + config.onCapabilityChange?.(); let bufferedCliEvents: UserWebCliEvent[] | null = []; let bufferedEventsFromSupersededSnapshot: UserWebCliEvent[] = []; @@ -182,7 +402,12 @@ function createCliLiveTransport(config: CliLiveTransportConfig): TransportFactor replayCurrentSnapshot(true); const offCli = config.userWebConnection.onCliEvent(config.kiloSessionId, msg => { const normalized = normalizeCliEvent(msg.event, msg.data); - if (normalized && isChatEvent(normalized) && bufferedCliEvents !== null) { + const shouldBufferForSnapshot = + normalized && + (isChatEvent(normalized) || + normalized.type === 'session.created' || + normalized.type === 'session.updated'); + if (shouldBufferForSnapshot && bufferedCliEvents !== null) { bufferedCliEvents.push(msg); return; } @@ -193,6 +418,7 @@ function createCliLiveTransport(config: CliLiveTransportConfig): TransportFactor }); const offReconnect = config.userWebConnection.onReconnect(() => { replayCurrentSnapshot(false); + if (ownerConnectionId) discoverModels(ownerConnectionId); }); const releaseSubscription = config.userWebConnection.subscribeToCliSession( config.kiloSessionId @@ -208,19 +434,28 @@ function createCliLiveTransport(config: CliLiveTransportConfig): TransportFactor }; }, - send: (input: { payload: TransportSendPayload }) => { + canSend: () => ownerConnectionId !== null, + retryRemoteModels: () => { + if (ownerConnectionId) discoverModels(ownerConnectionId); + }, + send: async (input: TransportSendInput) => { if (input.payload.type === 'command') { return Promise.reject( new Error('Slash commands are not supported on the CLI live transport yet') ); } const payload = input.payload; + const remoteModel = getRemoteModelFields(input); return sendCommand('send_message', { sessionID: config.kiloSessionId, parts: [{ type: 'text', text: payload.prompt }], ...(payload.mode ? { agent: payload.mode } : {}), - ...(payload.model ? { model: payload.model } : {}), - ...(payload.variant ? { variant: payload.variant } : {}), + ...(remoteModel.kind === 'none' + ? {} + : { + model: remoteModel.model, + ...(remoteModel.variant ? { variant: remoteModel.variant } : {}), + }), }); }, interrupt: () => sendCommand('interrupt', {}), @@ -250,11 +485,13 @@ function createCliLiveTransport(config: CliLiveTransportConfig): TransportFactor disconnect() { generation += 1; + setOwnerConnectionId(null); releaseConnection(); }, destroy() { generation += 1; + setOwnerConnectionId(null); releaseConnection(); }, }; diff --git a/apps/web/src/lib/cloud-agent-sdk/cloud-agent-transport.test.ts b/apps/web/src/lib/cloud-agent-sdk/cloud-agent-transport.test.ts index 86fbd1f037..dfb87e5313 100644 --- a/apps/web/src/lib/cloud-agent-sdk/cloud-agent-transport.test.ts +++ b/apps/web/src/lib/cloud-agent-sdk/cloud-agent-transport.test.ts @@ -392,12 +392,17 @@ describe('CloudAgentTransport lifecycle', () => { }); describe('CloudAgentTransport command delegation', () => { - it('send() delegates to api.send with bound sessionId', () => { + it('converts a Kilo model ref before delegating to api.send', async () => { const api = createMockApi(); const { transport } = createTransportWithSinks(undefined, undefined, api); - void transport.send!({ - payload: { type: 'prompt', prompt: 'hello', mode: 'code', model: 'gpt-4' }, + await transport.send!({ + payload: { + type: 'prompt', + prompt: 'hello', + mode: 'code', + model: { providerID: 'kilo', modelID: 'gpt-4' }, + }, }); expect(api.send).toHaveBeenCalledWith({ @@ -408,7 +413,26 @@ describe('CloudAgentTransport command delegation', () => { transport.destroy(); }); - it('send() delegates canonical document attachments to api.send', () => { + it('rejects a non-Kilo model ref before calling the Cloud Agent API', async () => { + const api = createMockApi(); + const { transport } = createTransportWithSinks(undefined, undefined, api); + + await expect( + transport.send!({ + payload: { + type: 'prompt', + prompt: 'hello', + mode: 'code', + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + }, + }) + ).rejects.toThrow('Cloud Agent only supports Kilo models'); + expect(api.send).not.toHaveBeenCalled(); + + transport.destroy(); + }); + + it('converts Kilo model refs while preserving canonical document attachments', async () => { const api = createMockApi(); const { transport } = createTransportWithSinks(undefined, undefined, api); const attachments = { @@ -416,8 +440,13 @@ describe('CloudAgentTransport command delegation', () => { files: ['87654321-4321-4321-8321-cba987654321.pdf'], }; - void transport.send!({ - payload: { type: 'prompt', prompt: 'read it', mode: 'code', model: 'gpt-4' }, + await transport.send!({ + payload: { + type: 'prompt', + prompt: 'read it', + mode: 'code', + model: { providerID: 'kilo', modelID: 'gpt-4' }, + }, attachments, }); diff --git a/apps/web/src/lib/cloud-agent-sdk/cloud-agent-transport.ts b/apps/web/src/lib/cloud-agent-sdk/cloud-agent-transport.ts index 60f1715483..1dfe3cb47b 100644 --- a/apps/web/src/lib/cloud-agent-sdk/cloud-agent-transport.ts +++ b/apps/web/src/lib/cloud-agent-sdk/cloud-agent-transport.ts @@ -13,11 +13,30 @@ import type { ServiceEvent } from './normalizer'; import type { CloudAgentSessionId, KiloSessionId, SessionSnapshot } from './types'; import type { CloudAgentApi, + CloudAgentSendPayload, CloudAgentStreamTicketResult, TransportFactory, + TransportSendPayload, TransportSink, } from './transport'; +function normalizeCloudAgentPayload(payload: TransportSendPayload): CloudAgentSendPayload { + if (payload.type === 'command') return payload; + if (!payload.mode) throw new Error('Cloud Agent mode is required'); + if (!payload.model) throw new Error('Cloud Agent model is required'); + if (payload.model.providerID !== 'kilo') { + throw new Error('Cloud Agent only supports Kilo models'); + } + + return { + type: 'prompt', + prompt: payload.prompt, + mode: payload.mode, + model: payload.model.modelID, + ...(payload.variant ? { variant: payload.variant } : {}), + }; +} + type CloudAgentTransportConfig = { sessionId: CloudAgentSessionId; kiloSessionId: KiloSessionId; @@ -181,7 +200,14 @@ function createCloudAgentTransport(config: CloudAgentTransportConfig): Transport closeConnection('destroy'); }, - send: payload => config.api.send({ sessionId: config.sessionId, ...payload }), + send: async input => + config.api.send({ + sessionId: config.sessionId, + payload: normalizeCloudAgentPayload(input.payload), + ...(input.messageId ? { messageId: input.messageId } : {}), + ...(input.attachments ? { attachments: input.attachments } : {}), + ...(input.images ? { images: input.images } : {}), + }), interrupt: () => config.api.interrupt({ sessionId: config.sessionId }), answer: payload => config.api.answer({ sessionId: config.sessionId, ...payload }), reject: payload => config.api.reject({ sessionId: config.sessionId, ...payload }), diff --git a/apps/web/src/lib/cloud-agent-sdk/index.ts b/apps/web/src/lib/cloud-agent-sdk/index.ts index 23b4432dac..632cbdda82 100644 --- a/apps/web/src/lib/cloud-agent-sdk/index.ts +++ b/apps/web/src/lib/cloud-agent-sdk/index.ts @@ -59,7 +59,7 @@ export type { CliHistoricalTransportConfig } from './cli-historical-transport'; export { createCliLiveTransport } from './cli-live-transport'; export type { CliLiveTransportConfig } from './cli-live-transport'; -export { createUserWebConnection } from './user-web-connection'; +export { createUserWebConnection, UserWebCommandError } from './user-web-connection'; export type { UserWebConnection, UserWebConnectionConfig, @@ -70,6 +70,30 @@ export type { UserWebSystemEvent, } from './user-web-connection'; +export { + REMOTE_MODEL_IDENTITY_MAX_LENGTH, + REMOTE_MODEL_MAX_MODELS_PER_PROVIDER, + REMOTE_MODEL_MAX_MODELS_TOTAL, + REMOTE_MODEL_MAX_PROVIDERS, + REMOTE_MODEL_MAX_VARIANTS_PER_MODEL, + REMOTE_MODEL_MAX_VARIANTS_TOTAL, + createModelRefKeyMap, + modelRefSchema, + modelRefsEqual, + modelSelectionSchema, + remoteModelCatalogV1Schema, + remoteModelCatalogWireV1Schema, +} from './remote-model-catalog'; +export type { + ModelRef, + ModelRefKeyMap, + ModelSelection, + RemoteModelCatalogV1, + RemoteModelCatalogWireV1, + RemoteModelOverride, + RemoteModelState, +} from './remote-model-catalog'; + export type { CloudAgentApi, CloudAgentStreamTicket, diff --git a/apps/web/src/lib/cloud-agent-sdk/normalizer.test.ts b/apps/web/src/lib/cloud-agent-sdk/normalizer.test.ts index 4b77bbd02c..3852376276 100644 --- a/apps/web/src/lib/cloud-agent-sdk/normalizer.test.ts +++ b/apps/web/src/lib/cloud-agent-sdk/normalizer.test.ts @@ -447,6 +447,45 @@ describe('normalize', () => { }); }); + it('preserves session model identity from session.created', () => { + const result = normalize( + createRaw('session.created', { + info: { + id: 'ses-1', + model: { + providerID: 'custom-provider', + id: 'deployment/model.v1', + variant: 'thinking', + }, + }, + }) + ); + + expect(result).toEqual({ + type: 'session.created', + info: { + id: 'ses-1', + parentID: undefined, + model: { + providerID: 'custom-provider', + id: 'deployment/model.v1', + variant: 'thinking', + }, + }, + }); + }); + + it('omits malformed session model metadata without dropping the event', () => { + const result = normalizeCliEvent('session.created', { + info: { id: 'ses-1', model: { providerID: 'openai', id: 42 } }, + }); + + expect(result).toEqual({ + type: 'session.created', + info: { id: 'ses-1', parentID: undefined }, + }); + }); + it('returns null when info is missing', () => { expect(normalize(createRaw('session.created', {}))).toBeNull(); }); @@ -467,6 +506,24 @@ describe('normalize', () => { }); }); + it('preserves session model identity from session.updated', () => { + const result = normalizeCliEvent('session.updated', { + info: { + id: 'ses-1', + model: { providerID: 'openai', id: 'gpt-5.1' }, + }, + }); + + expect(result).toEqual({ + type: 'session.updated', + info: { + id: 'ses-1', + parentID: undefined, + model: { providerID: 'openai', id: 'gpt-5.1' }, + }, + }); + }); + it('returns null when info is missing', () => { expect(normalize(createRaw('session.updated', {}))).toBeNull(); }); diff --git a/apps/web/src/lib/cloud-agent-sdk/normalizer.ts b/apps/web/src/lib/cloud-agent-sdk/normalizer.ts index ae4a03c84d..c0c510fced 100644 --- a/apps/web/src/lib/cloud-agent-sdk/normalizer.ts +++ b/apps/web/src/lib/cloud-agent-sdk/normalizer.ts @@ -3,6 +3,7 @@ * wire data and typed internal code. Validates shape via Zod schemas then uses * boundary `as` casts so downstream code receives properly typed NormalizedEvents. */ +import { z } from 'zod'; import type { Part, SessionStatus, QuestionInfo, Message } from '@/types/opencode.gen'; import type { SessionInfo, CloudStatus, SuggestionAction, SlashCommandInfo } from './types'; import { @@ -177,6 +178,21 @@ function extractErrorMessage(rawError: unknown): string { return 'Unknown error'; } +const sessionModelSchema = z.object({ + providerID: z.string(), + id: z.string(), + variant: z.string().optional(), +}); + +function normalizeSessionInfo(rawInfo: { id: string; [key: string]: unknown }): SessionInfo { + const model = sessionModelSchema.safeParse(rawInfo.model); + return { + id: rawInfo.id, + parentID: rawInfo.parentID != null ? String(rawInfo.parentID) : undefined, + ...(model.success ? { model: model.data } : {}), + }; +} + function normalizeInnerEvent(eventType: string, data: unknown): NormalizedEvent | null { switch (eventType) { case 'message.updated': { @@ -231,10 +247,7 @@ function normalizeInnerEvent(eventType: string, data: unknown): NormalizedEvent const rawCreated = r.data.info; return { type: 'session.created', - info: { - id: rawCreated.id, - parentID: rawCreated.parentID != null ? String(rawCreated.parentID) : undefined, - }, + info: normalizeSessionInfo(rawCreated), }; } @@ -244,10 +257,7 @@ function normalizeInnerEvent(eventType: string, data: unknown): NormalizedEvent const rawUpdated = r.data.info; return { type: 'session.updated', - info: { - id: rawUpdated.id, - parentID: rawUpdated.parentID != null ? String(rawUpdated.parentID) : undefined, - }, + info: normalizeSessionInfo(rawUpdated), }; } diff --git a/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.test.ts b/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.test.ts new file mode 100644 index 0000000000..558c9f3cf9 --- /dev/null +++ b/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.test.ts @@ -0,0 +1,393 @@ +import { + REMOTE_MODEL_CATALOG_MAX_SERIALIZED_BYTES, + REMOTE_MODEL_IDENTITY_MAX_LENGTH, + REMOTE_MODEL_MAX_MODELS_PER_PROVIDER, + REMOTE_MODEL_MAX_PROVIDERS, + REMOTE_MODEL_MAX_VARIANTS_PER_MODEL, + createModelRefKeyMap, + modelRefSchema, + modelRefsEqual, + remoteModelCatalogV1Schema, + remoteModelCatalogWireV1Schema, +} from './remote-model-catalog'; + +function createSdkModel(providerID: string, id: string, variants: string[] = [], name = id) { + return { + id, + providerID, + api: { id, url: '', npm: '' }, + name, + capabilities: { + temperature: true, + reasoning: true, + attachment: true, + toolcall: true, + input: { text: true, audio: false, image: true, video: false, pdf: true }, + output: { text: true, audio: false, image: false, video: false, pdf: false }, + interleaved: false, + }, + cost: { input: 0, output: 0, cache: { read: 0, write: 0 } }, + limit: { context: 128_000, output: 16_000 }, + status: 'active' as const, + options: {}, + headers: {}, + release_date: '', + variants: Object.fromEntries(variants.map(variant => [variant, {}])), + }; +} + +function createSdkProvider( + id: string, + models: ReturnType[] = [createSdkModel(id, `model-${id}`)] +) { + return { + id, + name: id, + source: 'custom' as const, + env: [], + options: {}, + models: Object.fromEntries(models.map(model => [model.id, model])), + }; +} + +function createWireCatalog(all: ReturnType[]) { + return { + all, + default: Object.fromEntries( + all.flatMap(provider => { + const modelID = Object.keys(provider.models)[0]; + return modelID ? [[provider.id, modelID]] : []; + }) + ), + connected: all.map(provider => provider.id), + failed: [], + protocolVersion: 1 as const, + truncated: false, + }; +} + +function getSerializedByteLength(value: unknown): number { + return new TextEncoder().encode(JSON.stringify(value)).byteLength; +} + +function createCatalogWithSerializedBytes(targetBytes: number) { + for (let count = 256; count <= 2_048; count += 64) { + const models = Array.from({ length: count }, (_, index) => + createSdkModel( + `provider-${Math.floor(index / REMOTE_MODEL_MAX_MODELS_PER_PROVIDER)}`, + `model-${index}`, + [], + '' + ) + ); + const providers = Array.from( + { length: Math.ceil(count / REMOTE_MODEL_MAX_MODELS_PER_PROVIDER) }, + (_, providerIndex) => + createSdkProvider( + `provider-${providerIndex}`, + models.slice( + providerIndex * REMOTE_MODEL_MAX_MODELS_PER_PROVIDER, + (providerIndex + 1) * REMOTE_MODEL_MAX_MODELS_PER_PROVIDER + ) + ) + ); + const catalog = createWireCatalog(providers); + let remainingBytes = targetBytes - getSerializedByteLength(catalog); + if (remainingBytes < 0 || remainingBytes > count * REMOTE_MODEL_IDENTITY_MAX_LENGTH) continue; + + for (const model of models) { + const addedBytes = Math.min(remainingBytes, REMOTE_MODEL_IDENTITY_MAX_LENGTH); + model.name = 'x'.repeat(addedBytes); + remainingBytes -= addedBytes; + if (remainingBytes === 0) break; + } + if (getSerializedByteLength(catalog) === targetBytes) return catalog; + } + throw new Error(`Cannot create a catalog with ${targetBytes} serialized bytes`); +} + +function createUtf8OversizedCatalog() { + for (let count = 256; count <= 2_048; count += 64) { + const models = Array.from({ length: count }, (_, index) => + createSdkModel( + `provider-${Math.floor(index / REMOTE_MODEL_MAX_MODELS_PER_PROVIDER)}`, + `model-${index}`, + [], + 'é'.repeat(REMOTE_MODEL_IDENTITY_MAX_LENGTH) + ) + ); + const providers = Array.from( + { length: Math.ceil(count / REMOTE_MODEL_MAX_MODELS_PER_PROVIDER) }, + (_, providerIndex) => + createSdkProvider( + `provider-${providerIndex}`, + models.slice( + providerIndex * REMOTE_MODEL_MAX_MODELS_PER_PROVIDER, + (providerIndex + 1) * REMOTE_MODEL_MAX_MODELS_PER_PROVIDER + ) + ) + ); + const catalog = createWireCatalog(providers); + if ( + JSON.stringify(catalog).length < REMOTE_MODEL_CATALOG_MAX_SERIALIZED_BYTES && + getSerializedByteLength(catalog) > REMOTE_MODEL_CATALOG_MAX_SERIALIZED_BYTES + ) { + return catalog; + } + } + throw new Error('Cannot create a UTF-8 oversized catalog'); +} + +describe('remoteModelCatalogV1Schema', () => { + it('normalizes the SDK ProviderListResponse shape without rewriting model identities', () => { + const connected = createSdkProvider('custom/provider:v1', [ + createSdkModel('custom/provider:v1', 'team/model.v2-beta', ['reasoning/high'], 'Team model'), + ]); + connected.name = 'Private deployment'; + const disconnected = createSdkProvider('disconnected'); + const wire = { + ...createWireCatalog([connected, disconnected]), + connected: ['custom/provider:v1'], + currentModel: { + model: { providerID: 'custom/provider:v1', modelID: 'team/model.v2-beta' }, + variant: 'reasoning/high', + }, + defaultModel: { providerID: 'custom/provider:v1', modelID: 'team/model.v2-beta' }, + }; + + expect(remoteModelCatalogV1Schema.parse(wire)).toEqual({ + protocolVersion: 1, + providers: [ + { + id: 'custom/provider:v1', + name: 'Private deployment', + models: [ + { + id: 'team/model.v2-beta', + name: 'Team model', + variants: ['reasoning/high'], + capabilities: { attachment: true, reasoning: true }, + limits: { context: 128_000, output: 16_000 }, + }, + ], + }, + ], + currentModel: wire.currentModel, + defaultModel: wire.defaultModel, + truncated: false, + }); + }); + + it('rejects duplicate provider IDs and inconsistent model identities', () => { + const duplicate = createSdkProvider('provider'); + expect( + remoteModelCatalogWireV1Schema.safeParse(createWireCatalog([duplicate, duplicate])).success + ).toBe(false); + + const wrongKey = createSdkProvider('provider'); + const model = wrongKey.models['model-provider']; + if (!model) throw new Error('Expected model fixture'); + wrongKey.models = { 'wrong-key': model }; + + const wrongProvider = createSdkProvider('provider'); + wrongProvider.models['model-provider'] = { ...model, providerID: 'other-provider' }; + + const wrongApiID = createSdkProvider('provider'); + wrongApiID.models['model-provider'] = { ...model, api: { ...model.api, id: 'other-model' } }; + + for (const provider of [wrongKey, wrongProvider, wrongApiID]) { + expect(remoteModelCatalogWireV1Schema.safeParse(createWireCatalog([provider])).success).toBe( + false + ); + } + }); + + it('rejects dangling connected and default references, including inherited property names', () => { + const provider = createSdkProvider('provider'); + const base = createWireCatalog([provider]); + + expect( + remoteModelCatalogWireV1Schema.safeParse({ ...base, connected: ['missing'] }).success + ).toBe(false); + expect( + remoteModelCatalogWireV1Schema.safeParse({ + ...base, + default: { provider: 'toString' }, + }).success + ).toBe(false); + }); + + it('rejects credential-bearing provider and model configuration', () => { + const provider = createSdkProvider('provider'); + const base = createWireCatalog([provider]); + const model = provider.models['model-provider']; + if (!model) throw new Error('Expected model fixture'); + + expect( + remoteModelCatalogWireV1Schema.safeParse({ + ...base, + all: [{ ...provider, key: 'secret' }], + }).success + ).toBe(false); + expect( + remoteModelCatalogWireV1Schema.safeParse({ + ...base, + all: [{ ...provider, env: ['PRIVATE_API_KEY'] }], + }).success + ).toBe(false); + expect( + remoteModelCatalogWireV1Schema.safeParse({ + ...base, + all: [{ ...provider, options: { apiKey: 'secret' } }], + }).success + ).toBe(false); + expect( + remoteModelCatalogWireV1Schema.safeParse({ + ...base, + all: [ + { + ...provider, + models: { [model.id]: { ...model, headers: { authorization: 'secret' } } }, + }, + ], + }).success + ).toBe(false); + expect( + remoteModelCatalogWireV1Schema.safeParse({ + ...base, + all: [ + { + ...provider, + models: { [model.id]: { ...model, api: { ...model.api, url: 'https://private' } } }, + }, + ], + }).success + ).toBe(false); + expect( + remoteModelCatalogWireV1Schema.safeParse({ + ...base, + all: [ + { + ...provider, + models: { [model.id]: { ...model, variants: { precise: { apiKey: 'secret' } } } }, + }, + ], + }).success + ).toBe(false); + + const otherUnsafeModels = [ + { ...model, options: { apiKey: 'secret' } }, + { ...model, api: { ...model.api, npm: 'file:///private/provider' } }, + { ...model, cost: { ...model.cost, input: 1 } }, + { ...model, release_date: 'private-release-metadata' }, + ]; + for (const unsafeModel of otherUnsafeModels) { + expect( + remoteModelCatalogWireV1Schema.safeParse({ + ...base, + all: [{ ...provider, models: { [model.id]: unsafeModel } }], + }).success + ).toBe(false); + } + }); + + it('accepts exactly 512 KiB and rejects one serialized byte over the limit', () => { + const atLimit = createCatalogWithSerializedBytes(REMOTE_MODEL_CATALOG_MAX_SERIALIZED_BYTES); + const overLimit = createCatalogWithSerializedBytes( + REMOTE_MODEL_CATALOG_MAX_SERIALIZED_BYTES + 1 + ); + + expect(getSerializedByteLength(atLimit)).toBe(REMOTE_MODEL_CATALOG_MAX_SERIALIZED_BYTES); + expect(remoteModelCatalogWireV1Schema.safeParse(atLimit).success).toBe(true); + expect(getSerializedByteLength(overLimit)).toBe(REMOTE_MODEL_CATALOG_MAX_SERIALIZED_BYTES + 1); + expect(remoteModelCatalogWireV1Schema.safeParse(overLimit).success).toBe(false); + }); + + it('measures the serialized catalog limit in UTF-8 bytes', () => { + const catalog = createUtf8OversizedCatalog(); + + expect(JSON.stringify(catalog).length).toBeLessThan(REMOTE_MODEL_CATALOG_MAX_SERIALIZED_BYTES); + expect(getSerializedByteLength(catalog)).toBeGreaterThan( + REMOTE_MODEL_CATALOG_MAX_SERIALIZED_BYTES + ); + expect(remoteModelCatalogWireV1Schema.safeParse(catalog).success).toBe(false); + }); + + it('enforces provider, per-provider model, and per-model variant count bounds', () => { + const tooManyProviders = Array.from({ length: REMOTE_MODEL_MAX_PROVIDERS + 1 }, (_, index) => + createSdkProvider(`provider-${index}`) + ); + const tooManyModels = Array.from( + { length: REMOTE_MODEL_MAX_MODELS_PER_PROVIDER + 1 }, + (_, index) => createSdkModel('provider', `model-${index}`) + ); + const tooManyVariants = Array.from( + { length: REMOTE_MODEL_MAX_VARIANTS_PER_MODEL + 1 }, + (_, index) => `variant-${index}` + ); + + expect( + remoteModelCatalogWireV1Schema.safeParse(createWireCatalog(tooManyProviders)).success + ).toBe(false); + expect( + remoteModelCatalogWireV1Schema.safeParse( + createWireCatalog([createSdkProvider('provider', tooManyModels)]) + ).success + ).toBe(false); + expect( + remoteModelCatalogWireV1Schema.safeParse( + createWireCatalog([ + createSdkProvider('provider', [createSdkModel('provider', 'model', tooManyVariants)]), + ]) + ).success + ).toBe(false); + }); + + it('requires exact non-empty identities within the v1 length bound', () => { + const validIdentity = 'p'.repeat(REMOTE_MODEL_IDENTITY_MAX_LENGTH); + + expect( + modelRefSchema.parse({ providerID: 'provider/with/slash', modelID: validIdentity }) + ).toEqual({ providerID: 'provider/with/slash', modelID: validIdentity }); + expect(modelRefSchema.safeParse({ providerID: '', modelID: 'model' }).success).toBe(false); + expect( + modelRefSchema.safeParse({ + providerID: 'provider', + modelID: 'm'.repeat(REMOTE_MODEL_IDENTITY_MAX_LENGTH + 1), + }).success + ).toBe(false); + }); +}); + +describe('modelRefsEqual', () => { + it('compares exact provider and model identities without parsing separators', () => { + const model = { providerID: 'custom/provider', modelID: 'family/model:v1' }; + + expect(modelRefsEqual(model, { ...model })).toBe(true); + expect(modelRefsEqual(model, { providerID: 'other/provider', modelID: model.modelID })).toBe( + false + ); + expect(modelRefsEqual(model, { providerID: model.providerID, modelID: 'model:v1' })).toBe( + false + ); + }); +}); + +describe('createModelRefKeyMap', () => { + it('round-trips exact refs through opaque keys without provider/model collisions', () => { + const keyMap = createModelRefKeyMap(); + const first = { providerID: 'provider/one', modelID: 'shared/model' }; + const second = { providerID: 'provider/two', modelID: 'shared/model' }; + + const firstKey = keyMap.getOrCreateKey(first); + const secondKey = keyMap.getOrCreateKey(second); + + expect(firstKey).not.toBe(secondKey); + expect(firstKey).not.toContain(first.providerID); + expect(firstKey).not.toContain(first.modelID); + expect(keyMap.getOrCreateKey({ ...first })).toBe(firstKey); + expect(keyMap.getModelRef(firstKey)).toEqual(first); + expect(keyMap.getModelRef(secondKey)).toEqual(second); + expect(keyMap.getModelRef('unknown-key')).toBeUndefined(); + }); +}); diff --git a/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.ts b/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.ts new file mode 100644 index 0000000000..05981c522f --- /dev/null +++ b/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.ts @@ -0,0 +1,69 @@ +import type { ModelRef, ModelSelection, RemoteModelCatalogV1 } from './schemas'; + +export { + REMOTE_MODEL_CATALOG_MAX_SERIALIZED_BYTES, + REMOTE_MODEL_IDENTITY_MAX_LENGTH, + REMOTE_MODEL_MAX_MODELS_PER_PROVIDER, + REMOTE_MODEL_MAX_MODELS_TOTAL, + REMOTE_MODEL_MAX_PROVIDERS, + REMOTE_MODEL_MAX_VARIANTS_PER_MODEL, + REMOTE_MODEL_MAX_VARIANTS_TOTAL, + modelRefSchema, + modelSelectionSchema, + remoteModelCatalogV1Schema, + remoteModelCatalogWireV1Schema, + type ModelRef, + type ModelSelection, + type RemoteModelCatalogV1, + type RemoteModelCatalogWireV1, +} from './schemas'; + +// Catalog strings are user/plugin-controlled metadata and may be private. +// Treat them as display data, never executable config or independent telemetry. +export type RemoteModelState = { + ownerConnectionId: string | null; + protocol: 'unknown' | 'legacy' | 'v1'; + catalog?: RemoteModelCatalogV1; + refresh: 'idle' | 'loading' | 'error'; + error?: string; +}; + +export type RemoteModelOverride = + | { source: 'cli-catalog'; selection: ModelSelection } + | { source: 'legacy-gateway'; selection: ModelSelection }; + +export type ModelRefKeyMap = { + getOrCreateKey: (modelRef: ModelRef) => string; + getModelRef: (key: string) => ModelRef | undefined; +}; + +export function modelRefsEqual(left: ModelRef, right: ModelRef): boolean { + return left.providerID === right.providerID && left.modelID === right.modelID; +} + +export function createModelRefKeyMap(): ModelRefKeyMap { + const keysByProviderAndModel = new Map>(); + const modelRefsByKey = new Map(); + + return { + getOrCreateKey(modelRef) { + let keysByModel = keysByProviderAndModel.get(modelRef.providerID); + if (!keysByModel) { + keysByModel = new Map(); + keysByProviderAndModel.set(modelRef.providerID, keysByModel); + } + + const existingKey = keysByModel.get(modelRef.modelID); + if (existingKey) return existingKey; + + const key = `remote-model-${modelRefsByKey.size}`; + const storedModelRef = { providerID: modelRef.providerID, modelID: modelRef.modelID }; + keysByModel.set(modelRef.modelID, key); + modelRefsByKey.set(key, storedModelRef); + return key; + }, + getModelRef(key) { + return modelRefsByKey.get(key); + }, + }; +} diff --git a/apps/web/src/lib/cloud-agent-sdk/schemas.ts b/apps/web/src/lib/cloud-agent-sdk/schemas.ts index fd9aa0b958..313614dbd4 100644 --- a/apps/web/src/lib/cloud-agent-sdk/schemas.ts +++ b/apps/web/src/lib/cloud-agent-sdk/schemas.ts @@ -85,6 +85,246 @@ export const permissionPayloadSchema = z .passthrough(); export type PermissionState = z.infer; +// --------------------------------------------------------------------------- +// Remote CLI model catalog +// --------------------------------------------------------------------------- + +export const REMOTE_MODEL_MAX_PROVIDERS = 64; +export const REMOTE_MODEL_MAX_MODELS_PER_PROVIDER = 512; +export const REMOTE_MODEL_MAX_MODELS_TOTAL = 2_048; +export const REMOTE_MODEL_MAX_VARIANTS_PER_MODEL = 32; +export const REMOTE_MODEL_MAX_VARIANTS_TOTAL = 8_192; +export const REMOTE_MODEL_IDENTITY_MAX_LENGTH = 255; +export const REMOTE_MODEL_CATALOG_MAX_SERIALIZED_BYTES = 512 * 1024; + +const remoteModelIdentitySchema = z.string().min(1).max(REMOTE_MODEL_IDENTITY_MAX_LENGTH); +const remoteModelDisplayNameSchema = z.string().max(REMOTE_MODEL_IDENTITY_MAX_LENGTH); + +export const modelRefSchema = z + .object({ + providerID: remoteModelIdentitySchema, + modelID: remoteModelIdentitySchema, + }) + .strict(); +export type ModelRef = z.infer; + +export const modelSelectionSchema = z + .object({ + model: modelRefSchema, + variant: remoteModelIdentitySchema.optional(), + }) + .strict(); +export type ModelSelection = z.infer; + +const emptyRemoteModelRecordSchema = z.object({}).strict(); +const remoteModelModalitiesSchema = z + .object({ + text: z.boolean(), + audio: z.boolean(), + image: z.boolean(), + video: z.boolean(), + pdf: z.boolean(), + }) + .strict(); +const remoteSdkModelSchema = z + .object({ + id: remoteModelIdentitySchema, + providerID: remoteModelIdentitySchema, + api: z + .object({ + id: remoteModelIdentitySchema, + url: z.literal(''), + npm: z.literal(''), + }) + .strict(), + name: remoteModelDisplayNameSchema, + capabilities: z + .object({ + temperature: z.boolean(), + reasoning: z.boolean(), + attachment: z.boolean(), + toolcall: z.boolean(), + input: remoteModelModalitiesSchema, + output: remoteModelModalitiesSchema, + interleaved: z.union([ + z.boolean(), + z.object({ field: z.enum(['reasoning_content', 'reasoning_details']) }).strict(), + ]), + }) + .strict(), + cost: z + .object({ + input: z.literal(0), + output: z.literal(0), + cache: z.object({ read: z.literal(0), write: z.literal(0) }).strict(), + }) + .strict(), + limit: z + .object({ + context: z.number().finite().nonnegative(), + input: z.number().finite().nonnegative().optional(), + output: z.number().finite().nonnegative(), + }) + .strict(), + status: z.enum(['alpha', 'beta', 'deprecated', 'active']), + options: emptyRemoteModelRecordSchema, + headers: emptyRemoteModelRecordSchema, + release_date: z.literal(''), + variants: z.record(remoteModelIdentitySchema, emptyRemoteModelRecordSchema).optional(), + }) + .strict(); +const remoteSdkProviderSchema = z + .object({ + id: remoteModelIdentitySchema, + name: remoteModelDisplayNameSchema, + source: z.enum(['env', 'config', 'custom', 'api']), + env: z.array(z.never()).max(0), + options: emptyRemoteModelRecordSchema, + models: z.record(remoteModelIdentitySchema, remoteSdkModelSchema), + }) + .strict(); + +export const remoteModelCatalogWireV1Schema = z + .object({ + all: z.array(remoteSdkProviderSchema).max(REMOTE_MODEL_MAX_PROVIDERS), + default: z.record(remoteModelIdentitySchema, remoteModelIdentitySchema), + connected: z.array(remoteModelIdentitySchema).max(REMOTE_MODEL_MAX_PROVIDERS), + failed: z.array(remoteModelIdentitySchema).max(REMOTE_MODEL_MAX_PROVIDERS), + protocolVersion: z.literal(1), + currentModel: modelSelectionSchema.optional(), + defaultModel: modelRefSchema.optional(), + truncated: z.boolean(), + }) + .strict() + .superRefine((catalog, context) => { + let modelCount = 0; + let variantCount = 0; + const providers = new Map(catalog.all.map(provider => [provider.id, provider])); + if (providers.size !== catalog.all.length) { + context.addIssue({ code: 'custom', message: 'Provider ID must be unique', path: ['all'] }); + } + if (new Set(catalog.connected).size !== catalog.connected.length) { + context.addIssue({ + code: 'custom', + message: 'Connected provider ID must be unique', + path: ['connected'], + }); + } + for (const [providerIndex, provider] of catalog.all.entries()) { + const models = Object.entries(provider.models); + modelCount += models.length; + if (models.length > REMOTE_MODEL_MAX_MODELS_PER_PROVIDER) { + context.addIssue({ + code: 'custom', + message: `Provider cannot contain more than ${REMOTE_MODEL_MAX_MODELS_PER_PROVIDER} models`, + path: ['all', providerIndex, 'models'], + }); + } + for (const [modelKey, model] of models) { + if ( + modelKey !== model.id || + model.providerID !== provider.id || + model.api.id !== model.id + ) { + context.addIssue({ + code: 'custom', + message: 'Model record identity must match its provider and record key', + path: ['all', providerIndex, 'models', modelKey], + }); + } + const variants = Object.keys(model.variants ?? {}); + variantCount += variants.length; + if (variants.length > REMOTE_MODEL_MAX_VARIANTS_PER_MODEL) { + context.addIssue({ + code: 'custom', + message: `Model cannot contain more than ${REMOTE_MODEL_MAX_VARIANTS_PER_MODEL} variants`, + path: ['all', providerIndex, 'models', modelKey, 'variants'], + }); + } + } + } + for (const providerId of catalog.connected) { + if (!providers.has(providerId)) { + context.addIssue({ + code: 'custom', + message: 'Connected provider must exist in all', + path: ['connected'], + }); + } + } + for (const [providerId, modelId] of Object.entries(catalog.default)) { + const provider = providers.get(providerId); + if (!provider || !Object.hasOwn(provider.models, modelId)) { + context.addIssue({ + code: 'custom', + message: 'Default model must exist in all', + path: ['default', providerId], + }); + } + } + if (modelCount > REMOTE_MODEL_MAX_MODELS_TOTAL) { + context.addIssue({ + code: 'custom', + message: `Catalog cannot contain more than ${REMOTE_MODEL_MAX_MODELS_TOTAL} models`, + path: ['all'], + }); + } + if (variantCount > REMOTE_MODEL_MAX_VARIANTS_TOTAL) { + context.addIssue({ + code: 'custom', + message: `Catalog cannot contain more than ${REMOTE_MODEL_MAX_VARIANTS_TOTAL} variants`, + path: ['all'], + }); + } + const serializedBytes = new TextEncoder().encode(JSON.stringify(catalog)).byteLength; + if (serializedBytes > REMOTE_MODEL_CATALOG_MAX_SERIALIZED_BYTES) { + context.addIssue({ + code: 'custom', + message: `Catalog cannot exceed ${REMOTE_MODEL_CATALOG_MAX_SERIALIZED_BYTES} serialized bytes`, + }); + } + }); +export type RemoteModelCatalogWireV1 = z.input; + +export const remoteModelCatalogV1Schema = remoteModelCatalogWireV1Schema.transform(catalog => { + const connected = new Set(catalog.connected); + return { + protocolVersion: 1 as const, + providers: catalog.all + .filter(provider => connected.has(provider.id)) + .map(provider => ({ + id: provider.id, + ...(provider.name ? { name: provider.name } : {}), + models: Object.values(provider.models).map(model => ({ + id: model.id, + ...(model.name ? { name: model.name } : {}), + variants: Object.keys(model.variants ?? {}), + capabilities: { + attachment: model.capabilities.attachment, + reasoning: model.capabilities.reasoning, + }, + limits: { + context: model.limit.context, + output: model.limit.output, + }, + })), + })), + ...(catalog.currentModel ? { currentModel: catalog.currentModel } : {}), + ...(catalog.defaultModel ? { defaultModel: catalog.defaultModel } : {}), + truncated: catalog.truncated, + }; +}); +export type RemoteModelCatalogV1 = z.output; + +export const userWebCommandErrorDataSchema = z + .object({ + source: z.literal('relay'), + code: z.string(), + message: z.string(), + }) + .strict(); +export type UserWebCommandErrorData = z.infer; + // --------------------------------------------------------------------------- // WebSocket inbound message (CLI live transport) // --------------------------------------------------------------------------- diff --git a/apps/web/src/lib/cloud-agent-sdk/session-manager.test.ts b/apps/web/src/lib/cloud-agent-sdk/session-manager.test.ts index fe28aff0f8..3676f798cb 100644 --- a/apps/web/src/lib/cloud-agent-sdk/session-manager.test.ts +++ b/apps/web/src/lib/cloud-agent-sdk/session-manager.test.ts @@ -10,7 +10,15 @@ import { createCloudAgentSession } from './session'; import type { JotaiSessionStorage } from './storage/jotai'; import type { AssistantMessage, UserMessage } from '@/types/opencode.gen'; import { kiloId, cloudAgentId, stubUserMessage, stubTextPart, makeSnapshot } from './test-helpers'; -import type { CloudStatus, MessageDeliveryState, ResolvedSession, SessionActivity } from './types'; +import type { + CloudStatus, + MessageDeliveryState, + ResolvedSession, + SessionActivity, + SessionInfo, +} from './types'; +import type { RemoteModelState } from './remote-model-catalog'; +import type { NormalizedEvent } from './normalizer'; // --------------------------------------------------------------------------- // Mock createCloudAgentSession — prevents real WebSocket connections @@ -27,6 +35,7 @@ const mockSession = { respondToPermission: jest.fn(), acceptSuggestion: jest.fn(), dismissSuggestion: jest.fn(), + retryRemoteModels: jest.fn(), canSend: true, canInterrupt: true, state: { @@ -47,7 +56,9 @@ const mockSession = { }; const mockSessionCallbacks: { - onSessionCreated?: (info: { id: string; parentID: string | null }) => void; + onSessionCreated?: (info: SessionInfo) => void; + onSessionUpdated?: (info: SessionInfo) => void; + onQuestionAsked?: (...args: unknown[]) => void; onQuestionResolved?: (...args: unknown[]) => void; onPermissionAsked?: (...args: unknown[]) => void; @@ -55,6 +66,9 @@ const mockSessionCallbacks: { onSuggestionAsked?: (...args: unknown[]) => void; onSuggestionResolved?: (...args: unknown[]) => void; onResolved?: (resolved: ResolvedSession) => void; + onRemoteModelStateChange?: (state: RemoteModelState) => void; + onTransportCapabilityChange?: () => void; + onEvent?: (event: NormalizedEvent) => void; onMessageQueued?: (messageId: string) => void; onMessageCompleted?: (messageId: string) => void; onMessageFailed?: ( @@ -71,7 +85,9 @@ jest.mock('./session', () => ({ (sessionConfig: { kiloSessionId: string; storage: JotaiSessionStorage; - onSessionCreated?: (info: { id: string; parentID: string | null }) => void; + onSessionCreated?: (info: SessionInfo) => void; + onSessionUpdated?: (info: SessionInfo) => void; + onQuestionAsked?: (...args: unknown[]) => void; onQuestionResolved?: (...args: unknown[]) => void; onPermissionAsked?: (...args: unknown[]) => void; @@ -79,6 +95,9 @@ jest.mock('./session', () => ({ onSuggestionAsked?: (...args: unknown[]) => void; onSuggestionResolved?: (...args: unknown[]) => void; onResolved?: (resolved: ResolvedSession) => void; + onRemoteModelStateChange?: (state: RemoteModelState) => void; + onTransportCapabilityChange?: () => void; + onEvent?: (event: NormalizedEvent) => void; onMessageQueued?: (messageId: string) => void; onMessageCompleted?: (messageId: string) => void; onMessageFailed?: ( @@ -98,9 +117,10 @@ jest.mock('./session', () => ({ kiloSessionId: kiloId(sessionConfig.kiloSessionId), cloudAgentSessionId: cloudAgentId('agent-1'), }); - sessionConfig.onSessionCreated?.({ id: sessionConfig.kiloSessionId, parentID: null }); + sessionConfig.onSessionCreated?.({ id: sessionConfig.kiloSessionId }); }); mockSessionCallbacks.onSessionCreated = sessionConfig.onSessionCreated; + mockSessionCallbacks.onSessionUpdated = sessionConfig.onSessionUpdated; mockSessionCallbacks.onQuestionAsked = sessionConfig.onQuestionAsked; mockSessionCallbacks.onQuestionResolved = sessionConfig.onQuestionResolved; mockSessionCallbacks.onPermissionAsked = sessionConfig.onPermissionAsked; @@ -108,6 +128,9 @@ jest.mock('./session', () => ({ mockSessionCallbacks.onSuggestionAsked = sessionConfig.onSuggestionAsked; mockSessionCallbacks.onSuggestionResolved = sessionConfig.onSuggestionResolved; mockSessionCallbacks.onResolved = sessionConfig.onResolved; + mockSessionCallbacks.onRemoteModelStateChange = sessionConfig.onRemoteModelStateChange; + mockSessionCallbacks.onTransportCapabilityChange = sessionConfig.onTransportCapabilityChange; + mockSessionCallbacks.onEvent = sessionConfig.onEvent; mockSessionCallbacks.onMessageQueued = sessionConfig.onMessageQueued; mockSessionCallbacks.onMessageCompleted = sessionConfig.onMessageCompleted; mockSessionCallbacks.onMessageFailed = sessionConfig.onMessageFailed; @@ -140,6 +163,26 @@ const defaultFetchedSession = { associatedPr: null, } satisfies FetchedSessionData; +const remoteCatalog = { + protocolVersion: 1, + providers: [ + { + id: 'anthropic', + name: 'Anthropic', + models: [ + { + id: 'claude-sonnet-4', + name: 'Claude Sonnet 4', + variants: ['high'], + capabilities: { attachment: true, reasoning: true }, + limits: { context: 200_000, output: 64_000 }, + }, + ], + }, + ], + truncated: false, +} satisfies NonNullable; + function createMockConfig(overrides: Partial = {}): SessionManagerConfig { return { store: createStore(), @@ -275,7 +318,11 @@ describe('createSessionManager', () => { mockSessionCallbacks.onPermissionAsked = undefined; mockSessionCallbacks.onPermissionResolved = undefined; mockSessionCallbacks.onSessionCreated = undefined; + mockSessionCallbacks.onSessionUpdated = undefined; mockSessionCallbacks.onResolved = undefined; + mockSessionCallbacks.onRemoteModelStateChange = undefined; + mockSessionCallbacks.onTransportCapabilityChange = undefined; + mockSessionCallbacks.onEvent = undefined; mockSessionCallbacks.onMessageQueued = undefined; mockSessionCallbacks.onMessageCompleted = undefined; mockSessionCallbacks.onMessageFailed = undefined; @@ -522,6 +569,427 @@ describe('createSessionManager', () => { ).toBeNull(); }); + it('exposes active session type and remote model state from the live transport', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + expect(atomValue(config.store, mgr.atoms.activeSessionType)).toBe('cloud-agent'); + + mockSessionCallbacks.onResolved?.({ type: 'remote', kiloSessionId: kiloId('ses-1') }); + const remoteState = { + ownerConnectionId: 'owner', + protocol: 'v1', + catalog: { + protocolVersion: 1, + providers: [], + currentModel: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, + defaultModel: { providerID: 'kilo', modelID: 'kilo-auto' }, + truncated: false, + }, + refresh: 'idle', + } satisfies RemoteModelState; + mockSessionCallbacks.onRemoteModelStateChange?.(remoteState); + + expect(atomValue(config.store, mgr.atoms.activeSessionType)).toBe('remote'); + expect(atomValue(config.store, mgr.atoms.remoteModelState)).toEqual(remoteState); + expect(atomValue(config.store, mgr.atoms.observedModel)).toEqual( + remoteState.catalog.currentModel + ); + }); + + it('replaces a catalog-derived observation when the session owner changes', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onResolved?.({ type: 'remote', kiloSessionId: kiloId('ses-1') }); + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'v1', + catalog: { + protocolVersion: 1, + providers: [], + currentModel: { model: { providerID: 'provider-a', modelID: 'model-a' } }, + truncated: false, + }, + refresh: 'idle', + }); + expect(atomValue(config.store, mgr.atoms.observedModel)).toEqual({ + model: { providerID: 'provider-a', modelID: 'model-a' }, + }); + + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-b', + protocol: 'unknown', + refresh: 'loading', + }); + expect(atomValue(config.store, mgr.atoms.observedModel)).toBeNull(); + + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-b', + protocol: 'v1', + catalog: { + protocolVersion: 1, + providers: [], + currentModel: { model: { providerID: 'provider-b', modelID: 'model-b' } }, + truncated: false, + }, + refresh: 'idle', + }); + expect(atomValue(config.store, mgr.atoms.observedModel)).toEqual({ + model: { providerID: 'provider-b', modelID: 'model-b' }, + }); + }); + + it('recomputes remote send capability without marking a disconnected owner read-only', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onResolved?.({ type: 'remote', kiloSessionId: kiloId('ses-1') }); + + mockSession.canSend = false; + mockSessionCallbacks.onTransportCapabilityChange?.(); + expect(atomValue(config.store, mgr.atoms.canSend)).toBe(false); + expect(atomValue(config.store, mgr.atoms.isReadOnly)).toBe(false); + + mockSession.canSend = true; + mockSessionCallbacks.onTransportCapabilityChange?.(); + expect(atomValue(config.store, mgr.atoms.canSend)).toBe(true); + }); + + it('delegates remote model retries to the active session', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + mgr.retryRemoteModels(); + + expect(mockSession.retryRemoteModels).toHaveBeenCalledTimes(1); + }); + + it('keeps session metadata authoritative over replayed root message models', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onSessionCreated?.({ + id: 'ses-1', + model: { providerID: 'openai', id: 'gpt-5', variant: 'high' }, + }); + expect(atomValue(config.store, mgr.atoms.observedModel)).toEqual({ + model: { providerID: 'openai', modelID: 'gpt-5' }, + variant: 'high', + }); + + mockSessionCallbacks.onEvent?.({ + type: 'message.updated', + info: stubUserMessage({ + id: 'msg-root', + sessionID: 'ses-1', + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'max', + }), + }); + expect(atomValue(config.store, mgr.atoms.observedModel)).toEqual({ + model: { providerID: 'openai', modelID: 'gpt-5' }, + variant: 'high', + }); + + mockSessionCallbacks.onEvent?.({ + type: 'message.updated', + info: createStoredAssistantMessage('msg-assistant', 'ses-1', { + providerID: 'custom-provider', + modelID: 'custom/model', + variant: 'fast', + }).info, + }); + expect(atomValue(config.store, mgr.atoms.observedModel)).toEqual({ + model: { providerID: 'openai', modelID: 'gpt-5' }, + variant: 'high', + }); + }); + + it('uses a replayed root message model when session metadata has no model', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + expect(atomValue(config.store, mgr.atoms.observedModel)).toBeNull(); + + mockSessionCallbacks.onEvent?.({ + type: 'message.updated', + info: stubUserMessage({ + id: 'msg-root', + sessionID: 'ses-1', + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }), + }); + + expect(atomValue(config.store, mgr.atoms.observedModel)).toEqual({ + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }); + }); + + it('keeps session metadata above the catalog current model', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onSessionCreated?.({ + id: 'ses-1', + model: { providerID: 'openai', id: 'gpt-5', variant: 'high' }, + }); + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'v1', + catalog: { + ...remoteCatalog, + currentModel: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + }, + defaultModel: { providerID: 'kilo', modelID: 'kilo-auto' }, + }, + refresh: 'idle', + }); + + expect(atomValue(config.store, mgr.atoms.observedModel)).toEqual({ + model: { providerID: 'openai', modelID: 'gpt-5' }, + variant: 'high', + }); + }); + + it('applies a live session.updated model while retaining the explicit override', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + const override = { + source: 'cli-catalog', + selection: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, + } as const; + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onResolved?.({ type: 'remote', kiloSessionId: kiloId('ses-1') }); + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'v1', + catalog: remoteCatalog, + refresh: 'idle', + }); + mockSessionCallbacks.onSessionCreated?.({ + id: 'ses-1', + model: { providerID: 'openai', id: 'gpt-5' }, + }); + mgr.setRemoteModelOverride(override); + + mockSessionCallbacks.onEvent?.({ + type: 'message.updated', + info: createStoredAssistantMessage('msg-history', 'ses-1', { + providerID: 'historical-provider', + modelID: 'historical-model', + }).info, + }); + mockSessionCallbacks.onSessionUpdated?.({ + id: 'ses-1', + model: { providerID: 'anthropic', id: 'claude-sonnet-4', variant: 'high' }, + }); + + expect(atomValue(config.store, mgr.atoms.observedModel)).toEqual({ + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }); + expect(atomValue(config.store, mgr.atoms.remoteModelOverride)).toEqual(override); + }); + + it('keeps explicit override separate from observations and clears it on owner change', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + const catalog = { + protocolVersion: 1, + providers: [], + truncated: false, + } satisfies NonNullable; + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onResolved?.({ type: 'remote', kiloSessionId: kiloId('ses-1') }); + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'v1', + catalog, + refresh: 'idle', + }); + const override = { + source: 'cli-catalog', + selection: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, + } as const; + mgr.setRemoteModelOverride(override); + + mockSessionCallbacks.onEvent?.({ + type: 'message.updated', + info: createStoredAssistantMessage('msg-assistant', 'ses-1', { + providerID: 'openai', + modelID: 'gpt-5', + }).info, + }); + + expect(atomValue(config.store, mgr.atoms.observedModel)).toEqual({ + model: { providerID: 'openai', modelID: 'gpt-5' }, + }); + expect(atomValue(config.store, mgr.atoms.remoteModelOverride)).toEqual(override); + + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-b', + protocol: 'unknown', + refresh: 'loading', + }); + expect(atomValue(config.store, mgr.atoms.remoteModelOverride)).toBeNull(); + expect(atomValue(config.store, mgr.atoms.remoteModelState)).toEqual({ + ownerConnectionId: 'owner-b', + protocol: 'unknown', + refresh: 'loading', + }); + }); + + it('clears an explicit override when the same owner changes to an incompatible protocol', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onResolved?.({ type: 'remote', kiloSessionId: kiloId('ses-1') }); + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'v1', + catalog: remoteCatalog, + refresh: 'idle', + }); + mgr.setRemoteModelOverride({ + source: 'cli-catalog', + selection: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, + }); + + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'legacy', + refresh: 'idle', + }); + + expect(atomValue(config.store, mgr.atoms.remoteModelOverride)).toBeNull(); + }); + + it('clears an explicit override when a same-owner catalog no longer contains its model', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onResolved?.({ type: 'remote', kiloSessionId: kiloId('ses-1') }); + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'v1', + catalog: remoteCatalog, + refresh: 'idle', + }); + mgr.setRemoteModelOverride({ + source: 'cli-catalog', + selection: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, + }); + + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'v1', + catalog: { + ...remoteCatalog, + providers: [{ ...remoteCatalog.providers[0], models: [] }], + }, + refresh: 'idle', + }); + + expect(atomValue(config.store, mgr.atoms.remoteModelOverride)).toBeNull(); + }); + + it('keeps a same-owner v1 model override but drops a removed variant', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onResolved?.({ type: 'remote', kiloSessionId: kiloId('ses-1') }); + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'v1', + catalog: remoteCatalog, + refresh: 'idle', + }); + mgr.setRemoteModelOverride({ + source: 'cli-catalog', + selection: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, + }); + + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'v1', + catalog: { + ...remoteCatalog, + providers: [ + { + ...remoteCatalog.providers[0], + models: [{ ...remoteCatalog.providers[0].models[0], variants: [] }], + }, + ], + }, + refresh: 'idle', + }); + + expect(atomValue(config.store, mgr.atoms.remoteModelOverride)).toEqual({ + source: 'cli-catalog', + selection: { model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' } }, + }); + }); + + it('clears remote model state and override immediately when switching sessions', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onResolved?.({ type: 'remote', kiloSessionId: kiloId('ses-1') }); + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'legacy', + refresh: 'idle', + }); + mgr.setRemoteModelOverride({ + source: 'legacy-gateway', + selection: { model: { providerID: 'kilo', modelID: 'kilo-auto' } }, + }); + + const switching = mgr.switchSession(kiloId('ses-2')); + expect(atomValue(config.store, mgr.atoms.remoteModelState)).toEqual({ + ownerConnectionId: null, + protocol: 'unknown', + refresh: 'idle', + }); + expect(atomValue(config.store, mgr.atoms.remoteModelOverride)).toBeNull(); + await switching; + }); + it('allows attachments only for a resolved Cloud Agent session', async () => { const config = createMockConfig(); const mgr = createSessionManager(config); @@ -659,7 +1127,7 @@ describe('createSessionManager', () => { type: 'prompt', prompt: 'Queue this follow-up', mode: 'code', - model: 'claude-3-5-sonnet', + model: { providerID: 'kilo', modelID: 'claude-3-5-sonnet' }, }, images: undefined, }); @@ -679,7 +1147,73 @@ describe('createSessionManager', () => { expect(atomValue(config.store, mgr.atoms.messagesList)).toHaveLength(0); expect(mockSession.send).toHaveBeenCalledWith({ messageId: expect.stringMatching(/^msg_/), - payload: { type: 'prompt', prompt: 'Hello', mode: 'code', model: 'claude-3-5-sonnet' }, + payload: { + type: 'prompt', + prompt: 'Hello', + mode: 'code', + model: { providerID: 'kilo', modelID: 'claude-3-5-sonnet' }, + }, + images: undefined, + }); + }); + + it('sends only the explicit remote override and omits stale session model fields after clear', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + const override = { + source: 'cli-catalog', + selection: { + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + variant: 'high', + }, + } as const; + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onResolved?.({ type: 'remote', kiloSessionId: kiloId('ses-1') }); + mgr.setRemoteModelOverride(override); + mockSession.send.mockResolvedValue(undefined); + + await mgr.send({ + payload: { + type: 'prompt', + prompt: 'with override', + mode: 'code', + model: 'stale-session-model', + variant: 'stale-session-variant', + }, + }); + + expect(mockSession.send).toHaveBeenLastCalledWith({ + messageId: expect.stringMatching(/^msg_/), + payload: { + type: 'prompt', + prompt: 'with override', + mode: 'code', + model: override.selection.model, + variant: 'high', + }, + remoteModelOverride: override, + images: undefined, + }); + + mgr.setRemoteModelOverride(null); + await mgr.send({ + payload: { + type: 'prompt', + prompt: 'without override', + mode: 'code', + model: 'stale-session-model', + variant: 'stale-session-variant', + }, + }); + + expect(mockSession.send).toHaveBeenLastCalledWith({ + messageId: expect.stringMatching(/^msg_/), + payload: { + type: 'prompt', + prompt: 'without override', + mode: 'code', + }, images: undefined, }); }); @@ -698,7 +1232,8 @@ describe('createSessionManager', () => { expect(mockSession.send).toHaveBeenCalledWith({ messageId: expect.stringMatching(/^msg_/), - payload: { type: 'prompt', prompt: 'Hello', mode: 'code', model: 'claude-3-5-sonnet' }, + payload: { type: 'prompt', prompt: 'Hello', mode: 'code' }, + images: undefined, }); expect(atomValue(config.store, mgr.atoms.messagesList)).toHaveLength(0); }); @@ -877,7 +1412,7 @@ describe('createSessionManager', () => { type: 'prompt', prompt: 'Hello', mode: 'code', - model: 'claude-3-5-sonnet', + model: { providerID: 'kilo', modelID: 'claude-3-5-sonnet' }, variant: 'high', }, images: undefined, @@ -900,7 +1435,12 @@ describe('createSessionManager', () => { expect(accepted).toBe(true); expect(mockSession.send).toHaveBeenCalledWith({ messageId: expect.stringMatching(/^msg_/), - payload: { type: 'prompt', prompt: 'Hello', mode: 'code', model: 'claude-3-5-sonnet' }, + payload: { + type: 'prompt', + prompt: 'Hello', + mode: 'code', + model: { providerID: 'kilo', modelID: 'claude-3-5-sonnet' }, + }, images, }); }); @@ -924,7 +1464,12 @@ describe('createSessionManager', () => { expect(accepted).toBe(true); expect(mockSession.send).toHaveBeenCalledWith({ messageId: expect.stringMatching(/^msg_/), - payload: { type: 'prompt', prompt: 'Hello', mode: 'code', model: 'claude-3-5-sonnet' }, + payload: { + type: 'prompt', + prompt: 'Hello', + mode: 'code', + model: { providerID: 'kilo', modelID: 'claude-3-5-sonnet' }, + }, attachments, images: undefined, }); @@ -995,7 +1540,13 @@ describe('createSessionManager', () => { expect(mockSession.send).toHaveBeenCalledWith({ messageId: expect.stringMatching(/^msg_/), - payload: { type: 'prompt', prompt: 'Hello', mode: 'code', model: 'claude-3-5-sonnet' }, + payload: { + type: 'prompt', + prompt: 'Hello', + mode: 'code', + model: { providerID: 'kilo', modelID: 'claude-3-5-sonnet' }, + }, + images: undefined, }); }); @@ -1034,7 +1585,7 @@ describe('createSessionManager', () => { if (!storage) throw new Error('expected session storage'); storage.upsertMessage(rootMessage.info); storage.upsertMessage(childMessage.info); - mockSessionCallbacks.onSessionCreated?.({ id: 'ses-root', parentID: null }); + mockSessionCallbacks.onSessionCreated?.({ id: 'ses-root' }); }); await mgr.switchSession(kiloId('ses-root')); @@ -1055,7 +1606,7 @@ describe('createSessionManager', () => { const childMessage = createStoredMessage('msg-child', 'child-2', 'assistant'); mockSession.connect.mockImplementation(() => { - mockSessionCallbacks.onSessionCreated?.({ id: 'ses-active', parentID: null }); + mockSessionCallbacks.onSessionCreated?.({ id: 'ses-active' }); }); await mgr.switchSession(kiloId('ses-active')); @@ -1078,7 +1629,7 @@ describe('createSessionManager', () => { const childOneSecond = createStoredMessage('msg-child-1b', 'child-1', 'user'); mockSession.connect.mockImplementation(() => { - mockSessionCallbacks.onSessionCreated?.({ id: 'ses-root', parentID: null }); + mockSessionCallbacks.onSessionCreated?.({ id: 'ses-root' }); }); await mgr.switchSession(kiloId('ses-root')); @@ -1692,7 +2243,7 @@ describe('createSessionManager', () => { // Simulate a session.created event that reports a different root // session ID than the one switchSession was called with. const realRootId = 'ses-real-root'; - mockSessionCallbacks.onSessionCreated?.({ id: realRootId, parentID: null }); + mockSessionCallbacks.onSessionCreated?.({ id: realRootId }); if (!latestStorage) throw new Error('expected session storage'); const rootMessage = createStoredMessage('msg-1', realRootId, 'assistant'); @@ -2360,6 +2911,7 @@ describe('isReadOnly during connecting phase', () => { mockSession.storage = latestStorage; latestStorage = null; mockSessionCallbacks.onSessionCreated = undefined; + mockSessionCallbacks.onSessionUpdated = undefined; mockSessionCallbacks.onQuestionAsked = undefined; mockSessionCallbacks.onQuestionResolved = undefined; mockSessionCallbacks.onPermissionAsked = undefined; @@ -2432,6 +2984,7 @@ describe('isReadOnly during connecting phase', () => { expect(atomValue(config.store, mgr.atoms.isReadOnly)).toBe(false); // Transport resolves but canSend stays false (read-only session) + mockSessionCallbacks.onResolved?.({ type: 'read-only', kiloSessionId: kiloId('ses-1') }); mockSession.state.getActivity.mockReturnValue({ type: 'idle' as const }); subscriberCallbackRef.current?.(); diff --git a/apps/web/src/lib/cloud-agent-sdk/session-manager.ts b/apps/web/src/lib/cloud-agent-sdk/session-manager.ts index f1bf9c8b4d..8d0047d55b 100644 --- a/apps/web/src/lib/cloud-agent-sdk/session-manager.ts +++ b/apps/web/src/lib/cloud-agent-sdk/session-manager.ts @@ -1,7 +1,14 @@ import type { CloudAgentAttachments } from '@/lib/cloud-agent/constants'; import type { Images } from '@/lib/images-schema'; import { errorShapeSchema } from './schemas'; -import type { TransportSendPayload } from './transport'; +import type { SendCommandPayload, SendPromptPayload, TransportSendPayload } from './transport'; +import { modelRefsEqual } from './remote-model-catalog'; +import type { + ModelRef, + ModelSelection, + RemoteModelOverride, + RemoteModelState, +} from './remote-model-catalog'; import { atom } from 'jotai'; import type { Atom, WritableAtom } from 'jotai'; import { createCloudAgentSession } from './session'; @@ -41,6 +48,8 @@ import type { ContextUsage } from './context-usage'; // --------------------------------------------------------------------------- type StoredMessage = { info: MessageInfo; parts: Part[] }; +type SessionManagerPromptPayload = Omit & { model?: string }; +type SessionManagerSendPayload = SessionManagerPromptPayload | SendCommandPayload; type SessionStatusIndicator = { type: 'error' | 'warning' | 'info' | 'progress'; message: string; @@ -56,6 +65,7 @@ type SessionConfig = { runtimeAgents?: Array<{ slug: string; name: string; model?: string; variant?: string }>; }; type ActiveSessionType = ResolvedSession['type']; +type ObservedModelSource = 'session' | 'message' | 'catalog'; type StandaloneQuestion = { requestId: string; questions: QuestionInfo[] }; type StandalonePermission = { requestId: string; @@ -81,6 +91,12 @@ const IDLE_CHILD_SESSION_HYDRATION_STATE = { status: 'idle', } satisfies ChildSessionHydrationState; +const EMPTY_REMOTE_MODEL_STATE = { + ownerConnectionId: null, + protocol: 'unknown', + refresh: 'idle', +} satisfies RemoteModelState; + type AssociatedPrData = { url: string; number: number; @@ -163,6 +179,10 @@ type SessionManagerAtoms = { isReadOnly: W; /** Active resolved transport can deliver canonical Cloud Agent attachments. */ supportsAttachments: W; + activeSessionType: W; + remoteModelState: W; + observedModel: W; + remoteModelOverride: W; canSend: W; canInterrupt: W; statusIndicator: W; @@ -198,10 +218,12 @@ type SessionManager = { switchSession(kiloSessionId: KiloSessionId): Promise; hydrateChildSession(childSessionId: KiloSessionId): Promise; send(input: { - payload: TransportSendPayload; + payload: SessionManagerSendPayload; attachments?: CloudAgentAttachments; images?: Images; }): Promise; + setRemoteModelOverride(override: RemoteModelOverride | null): void; + retryRemoteModels(): void; interrupt(): Promise; answerQuestion(requestId: string, answers: string[][]): Promise; rejectQuestion(requestId: string): Promise; @@ -306,6 +328,16 @@ function indicatorForStatus(s: AgentStatus): SessionStatusIndicator | null { return null; } +function toModelSelection(model: ModelRef, variant?: string): ModelSelection { + return { model, ...(variant ? { variant } : {}) }; +} + +function modelSelectionsEqual(a: ModelSelection | null, b: ModelSelection | null): boolean { + if (a === b) return true; + if (!a || !b) return false; + return modelRefsEqual(a.model, b.model) && a.variant === b.variant; +} + // --------------------------------------------------------------------------- // Factory // --------------------------------------------------------------------------- @@ -322,6 +354,10 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { const isLoadingAtom = atom(false); const isReadOnlyAtom = atom(false); const supportsAttachmentsAtom = atom(false); + const activeSessionTypeAtom = atom(null); + const remoteModelStateAtom = atom(EMPTY_REMOTE_MODEL_STATE); + const observedModelAtom = atom(null); + const remoteModelOverrideAtom = atom(null); const canSendAtom = atom(false); const canInterruptAtom = atom(false); const statusIndicatorAtom = atom(null); @@ -407,6 +443,7 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { let switchGeneration = 0; let currentSession: CloudAgentSession | null = null; let activeSessionType: ActiveSessionType | null = null; + let observedModelSource: ObservedModelSource | null = null; let stateUnsub: (() => void) | null = null; let indicatorTimer: ReturnType | null = null; let childSessionHydrationGeneration = 0; @@ -432,6 +469,11 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { store.set(isLoadingAtom, false); store.set(isReadOnlyAtom, false); store.set(supportsAttachmentsAtom, false); + store.set(activeSessionTypeAtom, null); + store.set(remoteModelStateAtom, EMPTY_REMOTE_MODEL_STATE); + store.set(observedModelAtom, null); + observedModelSource = null; + store.set(remoteModelOverrideAtom, null); store.set(canSendAtom, false); store.set(canInterruptAtom, false); store.set(statusIndicatorAtom, null); @@ -527,6 +569,70 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { } } + function updateCapabilityAtoms(session: CloudAgentSession): void { + const cloudStatus = store.get(cloudStatusAtom); + const cloudReady = cloudStatus === null || cloudStatus.type === 'ready'; + store.set(canSendAtom, session.canSend && cloudReady); + store.set(canInterruptAtom, session.canInterrupt); + } + + function updateObservedModel(model: ModelSelection, source: ObservedModelSource): void { + observedModelSource = source; + // Only churn the atom when the selection actually changes: the incoming + // object is freshly built on every message.updated, so a reference check + // never holds and would needlessly rebuild the whole model-options list. + if (!modelSelectionsEqual(store.get(observedModelAtom), model)) { + store.set(observedModelAtom, model); + } + } + + function handleRemoteModelStateChange(state: RemoteModelState): void { + const previousOwnerConnectionId = store.get(remoteModelStateAtom).ownerConnectionId; + store.set(remoteModelStateAtom, state); + + if (previousOwnerConnectionId !== state.ownerConnectionId) { + store.set(remoteModelOverrideAtom, null); + if (observedModelSource === 'catalog') { + observedModelSource = null; + store.set(observedModelAtom, null); + } + } else { + const override = store.get(remoteModelOverrideAtom); + const sourceMatchesProtocol = + (state.protocol === 'v1' && override?.source === 'cli-catalog') || + (state.protocol === 'legacy' && override?.source === 'legacy-gateway'); + const provider = state.catalog?.providers.find( + item => item.id === override?.selection.model.providerID + ); + const catalogModel = provider?.models.find( + item => item.id === override?.selection.model.modelID + ); + const modelMatchesProtocol = + state.protocol === 'v1' + ? catalogModel !== undefined + : state.protocol === 'legacy' && override?.selection.model.providerID === 'kilo'; + if (override && (!sourceMatchesProtocol || !modelMatchesProtocol)) { + store.set(remoteModelOverrideAtom, null); + } else if ( + override?.source === 'cli-catalog' && + override.selection.variant && + catalogModel && + !catalogModel.variants.includes(override.selection.variant) + ) { + store.set(remoteModelOverrideAtom, { + source: 'cli-catalog', + selection: { model: override.selection.model }, + }); + } + } + if ( + (observedModelSource === null || observedModelSource === 'catalog') && + state.catalog?.currentModel + ) { + updateObservedModel(state.catalog.currentModel, 'catalog'); + } + } + function subscribeToServiceState( session: CloudAgentSession, opts?: { onFirstActivity?: () => void } @@ -563,16 +669,16 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { store.set(sessionInfoAtom, session.state.getSessionInfo()); store.set(pendingMessagesAtom, new Map(session.state.getPendingMessages())); - // canSend factors in cloud status: preparing/finalizing blocks input - const cloudReady = cs === null || cs.type === 'ready'; // Only update read-only state after the transport has been resolved. // During the 'connecting' phase the transport is null so canSend is // always false, which would briefly flash a "read-only" banner. if (act.type !== 'connecting') { - store.set(isReadOnlyAtom, !session.canSend); + store.set( + isReadOnlyAtom, + activeSessionType === null ? !session.canSend : activeSessionType === 'read-only' + ); } - store.set(canSendAtom, session.canSend && cloudReady); - store.set(canInterruptAtom, session.canInterrupt); + updateCapabilityAtoms(session); if (previousStatus.type === 'disconnected' && st.type !== 'disconnected') { store.set(errorAtom, null); @@ -689,6 +795,28 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { // cast cloudAgentSessionId (the createAndStart path). store.set(rootSessionIdAtom, info.id); store.set(isLoadingAtom, false); + if (info.model) { + updateObservedModel( + toModelSelection( + { providerID: info.model.providerID, modelID: info.model.id }, + info.model.variant + ), + 'session' + ); + } + } + }, + + onSessionUpdated: info => { + const rootSessionId = store.get(rootSessionIdAtom); + if (rootSessionId === info.id && info.model) { + updateObservedModel( + toModelSelection( + { providerID: info.model.providerID, modelID: info.model.id }, + info.model.variant + ), + 'session' + ); } }, onQuestionAsked: (requestId, questions) => { @@ -724,8 +852,15 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { }, onResolved: resolved => { activeSessionType = resolved.type; + store.set(activeSessionTypeAtom, resolved.type); store.set(supportsAttachmentsAtom, resolved.type === 'cloud-agent'); + updateCapabilityAtoms(session); }, + onRemoteModelStateChange: handleRemoteModelStateChange, + onTransportCapabilityChange: () => { + if (currentSession === session) updateCapabilityAtoms(session); + }, + onBranchChanged: branch => { const currentFetched = store.get(fetchedSessionDataAtom); if (currentFetched) { @@ -749,10 +884,30 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { store.set(availableCommandsAtom, event.commands); return; } - if (event.type === 'message.updated' && event.info.role === 'assistant') { + if (event.type === 'message.updated') { const rootSessionId = store.get(rootSessionIdAtom); if (rootSessionId !== null && event.info.sessionID !== rootSessionId) return; + if (event.info.role === 'user') { + if (observedModelSource !== 'session') { + updateObservedModel( + toModelSelection(event.info.model, event.info.variant), + 'message' + ); + } + return; + } + + if (observedModelSource !== 'session') { + updateObservedModel( + toModelSelection( + { providerID: event.info.providerID, modelID: event.info.modelID }, + event.info.variant + ), + 'message' + ); + } + // `info.agent` is the agent slug (e.g. 'code', 'e-code'); `info.mode` // is the visibility ('primary'|'subagent'|'all') and must not be used // as the picker's selected mode. @@ -793,7 +948,7 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { } async function send(input: { - payload: TransportSendPayload; + payload: SessionManagerSendPayload; attachments?: CloudAgentAttachments; images?: Images; }): Promise { @@ -812,6 +967,35 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { input.payload.type === 'command' ? `/${input.payload.command}${input.payload.arguments ? ` ${input.payload.arguments}` : ''}` : input.payload.prompt; + const remoteModelOverride = store.get(remoteModelOverrideAtom); + let transportPayload: TransportSendPayload; + if (input.payload.type === 'command') { + transportPayload = input.payload; + } else if (sessionType === 'remote') { + transportPayload = { + type: 'prompt', + prompt: input.payload.prompt, + ...(input.payload.mode ? { mode: input.payload.mode } : {}), + ...(remoteModelOverride + ? { + model: remoteModelOverride.selection.model, + ...(remoteModelOverride.selection.variant + ? { variant: remoteModelOverride.selection.variant } + : {}), + } + : {}), + }; + } else { + transportPayload = { + type: 'prompt', + prompt: input.payload.prompt, + ...(input.payload.mode ? { mode: input.payload.mode } : {}), + ...(input.payload.model + ? { model: { providerID: 'kilo', modelID: input.payload.model } } + : {}), + ...(input.payload.model && input.payload.variant ? { variant: input.payload.variant } : {}), + }; + } try { if (!currentSession) throw new Error('No active session'); @@ -819,11 +1003,13 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { throw new Error('Only Cloud Agent sessions support attachments'); } await currentSession.send({ - payload: input.payload, + payload: transportPayload, messageId, ...(input.attachments ? { attachments: input.attachments } : {}), images: input.images, + ...(sessionType === 'remote' && remoteModelOverride ? { remoteModelOverride } : {}), }); + if (sessionType === 'remote' && kiloSessionId) { config.onRemoteSessionMessageSent?.({ kiloSessionId }); } @@ -905,6 +1091,14 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { } } + function setRemoteModelOverride(override: RemoteModelOverride | null): void { + store.set(remoteModelOverrideAtom, override); + } + + function retryRemoteModels(): void { + currentSession?.retryRemoteModels(); + } + function destroy(): void { childSessionHydrationGeneration += 1; childSessionHydrationRequests.clear(); @@ -926,6 +1120,8 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { switchSession, hydrateChildSession, send, + setRemoteModelOverride, + retryRemoteModels, interrupt, answerQuestion, rejectQuestion, @@ -943,6 +1139,10 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { isLoading: isLoadingAtom, isReadOnly: isReadOnlyAtom, supportsAttachments: supportsAttachmentsAtom, + activeSessionType: activeSessionTypeAtom, + remoteModelState: remoteModelStateAtom, + observedModel: observedModelAtom, + remoteModelOverride: remoteModelOverrideAtom, canSend: canSendAtom, canInterrupt: canInterruptAtom, statusIndicator: statusIndicatorAtom, diff --git a/apps/web/src/lib/cloud-agent-sdk/session-transport.test.ts b/apps/web/src/lib/cloud-agent-sdk/session-transport.test.ts index 7ac229af47..6426dfe37f 100644 --- a/apps/web/src/lib/cloud-agent-sdk/session-transport.test.ts +++ b/apps/web/src/lib/cloud-agent-sdk/session-transport.test.ts @@ -1,5 +1,7 @@ import { createCloudAgentSession, type CloudAgentSession } from './session'; import type { CloudAgentApi } from './transport'; +import type { KiloSessionId } from './types'; +import type { UserWebSystemEvent } from './user-web-connection'; import { kiloId, cloudAgentId, makeSnapshot } from './test-helpers'; // --------------------------------------------------------------------------- @@ -88,6 +90,58 @@ async function connectSession(session: CloudAgentSession): Promise { mockWs.onopen?.(new Event('open')); } +function createUserWebConnection() { + let systemListener: ((event: UserWebSystemEvent) => void) | undefined; + return { + connect: jest.fn(), + disconnect: jest.fn(), + destroy: jest.fn(), + subscribeToCliSession: jest.fn(() => jest.fn()), + sendCommand: jest.fn((_sessionId: string, command: string) => + Promise.resolve( + command === 'list_models' + ? { protocolVersion: 1, providers: [], truncated: false } + : { ok: true } + ) + ), + onCliEvent: jest.fn(() => jest.fn()), + onSystemEvent: jest.fn((listener: (event: UserWebSystemEvent) => void) => { + systemListener = listener; + return jest.fn(); + }), + onReconnect: jest.fn(() => jest.fn()), + onSessionEvent: jest.fn(() => jest.fn()), + emitSystem(event: UserWebSystemEvent) { + systemListener?.(event); + }, + }; +} + +function emitSessionsListOwner( + connection: ReturnType, + sessionId: KiloSessionId +): void { + connection.emitSystem({ + event: 'sessions.list', + data: { + sessions: [{ id: sessionId, status: 'active', title: 'Remote', connectionId: 'owner' }], + }, + }); +} + +function emitHeartbeatOwner( + connection: ReturnType, + sessionId: KiloSessionId +): void { + connection.emitSystem({ + event: 'sessions.heartbeat', + data: { + connectionId: 'owner', + sessions: [{ id: sessionId, status: 'active', title: 'Remote' }], + }, + }); +} + // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -98,12 +152,19 @@ describe('session transport delegation (cloud agent)', () => { const session = createCloudAgentResolvedSession(api); await connectSession(session); - await session.send({ payload: { type: 'prompt', prompt: 'hello', mode: 'auto' } }); + await session.send({ + payload: { + type: 'prompt', + prompt: 'hello', + mode: 'auto', + model: { providerID: 'kilo', modelID: 'test/model-1' }, + }, + }); expect(api.send).toHaveBeenCalledTimes(1); expect(api.send).toHaveBeenCalledWith({ sessionId: cloudAgentSessionId, - payload: { type: 'prompt', prompt: 'hello', mode: 'auto' }, + payload: { type: 'prompt', prompt: 'hello', mode: 'auto', model: 'test/model-1' }, }); session.destroy(); @@ -119,13 +180,18 @@ describe('session transport delegation (cloud agent)', () => { await connectSession(session); await session.send({ - payload: { type: 'prompt', prompt: 'hello', mode: 'auto' }, + payload: { + type: 'prompt', + prompt: 'hello', + mode: 'auto', + model: { providerID: 'kilo', modelID: 'test/model-1' }, + }, attachments, }); expect(api.send).toHaveBeenCalledWith({ sessionId: cloudAgentSessionId, - payload: { type: 'prompt', prompt: 'hello', mode: 'auto' }, + payload: { type: 'prompt', prompt: 'hello', mode: 'auto', model: 'test/model-1' }, attachments, }); @@ -288,20 +354,6 @@ describe('session transport missing command methods (read-only session)', () => describe('remote session send via typed transport methods', () => { const cliKiloSessionId = kiloId('ses_cli-live-session'); - function createUserWebConnection() { - return { - connect: jest.fn(), - disconnect: jest.fn(), - destroy: jest.fn(), - subscribeToCliSession: jest.fn(() => jest.fn()), - sendCommand: jest.fn(() => Promise.resolve({ ok: true })), - onCliEvent: jest.fn(() => jest.fn()), - onSystemEvent: jest.fn(() => jest.fn()), - onReconnect: jest.fn(() => jest.fn()), - onSessionEvent: jest.fn(() => jest.fn()), - }; - } - it('uses the required user web connection without constructing a viewer socket', async () => { const userWebConnection = createUserWebConnection(); const session = createCloudAgentSession({ @@ -313,6 +365,8 @@ describe('remote session send via typed transport methods', () => { session.connect(); await new Promise(r => setTimeout(r, 0)); await new Promise(r => setTimeout(r, 0)); + emitSessionsListOwner(userWebConnection, cliKiloSessionId); + await Promise.resolve(); await session.send({ payload: { type: 'prompt', prompt: 'Hello remote' } }); @@ -320,7 +374,8 @@ describe('remote session send via typed transport methods', () => { expect(userWebConnection.sendCommand).toHaveBeenCalledWith( cliKiloSessionId, 'send_message', - expect.objectContaining({ sessionID: cliKiloSessionId }) + expect.objectContaining({ sessionID: cliKiloSessionId }), + 'owner' ); expect(jest.mocked(global.WebSocket)).not.toHaveBeenCalled(); session.destroy(); @@ -339,17 +394,28 @@ describe('remote session send via typed transport methods', () => { session.connect(); await new Promise(r => setTimeout(r, 0)); + emitHeartbeatOwner(userWebConnection, cliKiloSessionId); + await Promise.resolve(); await session.send({ - payload: { type: 'prompt', prompt: 'Hello world', mode: 'code', model: 'test/model-1' }, + payload: { + type: 'prompt', + prompt: 'Hello world', + mode: 'code', + model: { providerID: 'kilo', modelID: 'test/model-1' }, + }, }); - expect(userWebConnection.sendCommand).toHaveBeenCalledWith(cliKiloSessionId, 'send_message', { - sessionID: cliKiloSessionId, - parts: [{ type: 'text', text: 'Hello world' }], - agent: 'code', - model: 'test/model-1', - }); + expect(userWebConnection.sendCommand).toHaveBeenCalledWith( + cliKiloSessionId, + 'send_message', + { + sessionID: cliKiloSessionId, + parts: [{ type: 'text', text: 'Hello world' }], + agent: 'code', + }, + 'owner' + ); session.destroy(); }); }); @@ -370,32 +436,22 @@ describe('session capabilities', () => { session.destroy(); }); - it('canSend is true after connecting a remote session', async () => { + it('canSend is true after a remote session owner is observed', async () => { + const cliKiloSessionId = kiloId('ses_cli-live'); + const userWebConnection = createUserWebConnection(); const session = createCloudAgentSession({ - kiloSessionId: kiloId('ses_cli-live'), + kiloSessionId: cliKiloSessionId, resolveSession: async () => ({ type: 'remote' as const, - kiloSessionId: kiloId('ses_cli-live'), + kiloSessionId: cliKiloSessionId, }), - transport: { - userWebConnection: { - connect: jest.fn(), - disconnect: jest.fn(), - destroy: jest.fn(), - subscribeToCliSession: jest.fn(() => jest.fn()), - sendCommand: jest.fn(() => Promise.resolve()), - onCliEvent: jest.fn(() => jest.fn()), - onSystemEvent: jest.fn(() => jest.fn()), - onReconnect: jest.fn(() => jest.fn()), - onSessionEvent: jest.fn(() => jest.fn()), - }, - }, + transport: { userWebConnection }, }); session.connect(); await new Promise(r => setTimeout(r, 0)); await new Promise(r => setTimeout(r, 0)); - await new Promise(r => setTimeout(r, 0)); + emitSessionsListOwner(userWebConnection, cliKiloSessionId); expect(session.canSend).toBe(true); session.destroy(); diff --git a/apps/web/src/lib/cloud-agent-sdk/session.test.ts b/apps/web/src/lib/cloud-agent-sdk/session.test.ts index c0004e3e8a..eed3982c2a 100644 --- a/apps/web/src/lib/cloud-agent-sdk/session.test.ts +++ b/apps/web/src/lib/cloud-agent-sdk/session.test.ts @@ -4,7 +4,10 @@ * Instead of testing through createCloudAgentSession (which requires a WebSocket), * we wire the same components directly — mirroring session.ts's event routing logic. */ -import { createTestSession } from './test-helpers'; +import { createTestSession, kiloId } from './test-helpers'; +import { createCloudAgentSession } from './session'; +import type { RemoteModelState } from './remote-model-catalog'; +import type { UserWebSystemEvent } from './user-web-connection'; import { createEventHelpers, sessionInfo, @@ -358,3 +361,69 @@ describe('session pipeline integration', () => { }); }); }); + +describe('remote session transport state', () => { + it('publishes owner-scoped model state and dynamic send capability', async () => { + let systemListener: ((event: UserWebSystemEvent) => void) | undefined; + const onRemoteModelStateChange = jest.fn(); + const onTransportCapabilityChange = jest.fn(); + const userWebConnection = { + connect: jest.fn(), + disconnect: jest.fn(), + destroy: jest.fn(), + subscribeToCliSession: jest.fn(() => jest.fn()), + sendCommand: jest.fn(() => + Promise.resolve({ + protocolVersion: 1, + providers: [], + truncated: false, + }) + ), + onCliEvent: jest.fn(() => jest.fn()), + onSystemEvent: jest.fn((listener: (event: UserWebSystemEvent) => void) => { + systemListener = listener; + return jest.fn(); + }), + onReconnect: jest.fn(() => jest.fn()), + onSessionEvent: jest.fn(() => jest.fn()), + }; + const kiloSessionId = kiloId('ses-remote'); + const session = createCloudAgentSession({ + kiloSessionId, + resolveSession: () => Promise.resolve({ type: 'remote', kiloSessionId }), + transport: { userWebConnection }, + onRemoteModelStateChange, + onTransportCapabilityChange, + }); + + session.connect(); + await Promise.resolve(); + expect(session.canSend).toBe(false); + + systemListener?.({ + event: 'sessions.list', + data: { + sessions: [ + { + id: kiloSessionId, + status: 'active', + title: 'Remote', + connectionId: 'owner', + }, + ], + }, + }); + await Promise.resolve(); + await Promise.resolve(); + + expect(session.canSend).toBe(true); + expect(onTransportCapabilityChange).toHaveBeenCalled(); + expect(onRemoteModelStateChange).toHaveBeenLastCalledWith({ + ownerConnectionId: 'owner', + protocol: 'v1', + catalog: { protocolVersion: 1, providers: [], truncated: false }, + refresh: 'idle', + }); + session.destroy(); + }); +}); diff --git a/apps/web/src/lib/cloud-agent-sdk/session.ts b/apps/web/src/lib/cloud-agent-sdk/session.ts index 4c75714f80..0aa28d42d7 100644 --- a/apps/web/src/lib/cloud-agent-sdk/session.ts +++ b/apps/web/src/lib/cloud-agent-sdk/session.ts @@ -10,6 +10,7 @@ import type { CloudAgentAttachments } from '@/lib/cloud-agent/constants'; import type { Images } from '@/lib/images-schema'; import type { NormalizedEvent } from './normalizer'; import type { SuggestionAction } from './types'; +import type { RemoteModelOverride, RemoteModelState } from './remote-model-catalog'; import { createChatProcessor } from './chat-processor'; import { createServiceState } from './service-state'; import type { ServiceState } from './service-state'; @@ -63,6 +64,8 @@ type CloudAgentSessionConfig = { onSuggestionResolved?: (requestId: string) => void; onBranchChanged?: (branch: string) => void; onResolved?: (resolved: ResolvedSession) => void; + onRemoteModelStateChange?: (state: RemoteModelState) => void; + onTransportCapabilityChange?: () => void; onSessionCreated?: (info: SessionInfo) => void; onSessionUpdated?: (info: SessionInfo) => void; onEvent?: (event: NormalizedEvent) => void; @@ -79,6 +82,7 @@ type CloudAgentSessionSendInput = { messageId?: string; attachments?: CloudAgentAttachments; images?: Images; + remoteModelOverride?: RemoteModelOverride; }; type CloudAgentSessionAnswerInput = { @@ -127,7 +131,7 @@ type CloudAgentSession = { state: ServiceState; // Commands - send: (payload: CloudAgentSessionSendInput) => unknown | Promise; + send: (input: CloudAgentSessionSendInput) => unknown | Promise; interrupt: () => unknown | Promise; answer: (payload: CloudAgentSessionAnswerInput) => unknown | Promise; reject: (payload: CloudAgentSessionRejectInput) => unknown | Promise; @@ -138,6 +142,7 @@ type CloudAgentSession = { dismissSuggestion: ( payload: CloudAgentSessionDismissSuggestionInput ) => unknown | Promise; + retryRemoteModels: () => void; // Capability checks canSend: boolean; @@ -208,6 +213,8 @@ function createCloudAgentSession(config: CloudAgentSessionConfig): CloudAgentSes userWebConnection: config.transport.userWebConnection, fetchSnapshot: config.transport.fetchSnapshot, onError: config.onError, + onRemoteModelStateChange: config.onRemoteModelStateChange, + onCapabilityChange: config.onTransportCapabilityChange, }); } case 'cloud-agent': { @@ -293,15 +300,17 @@ function createCloudAgentSession(config: CloudAgentSessionConfig): CloudAgentSes transport.connect(); } + const send = (input: CloudAgentSessionSendInput): unknown | Promise => { + if (!transport?.send) { + throw new Error('CloudAgentSession transport.send is not configured'); + } + return transport.send(input); + }; + return { storage, state: serviceState, - send: payload => { - if (!transport?.send) { - throw new Error('CloudAgentSession transport.send is not configured'); - } - return transport.send(payload); - }, + send, interrupt: () => { if (!transport?.interrupt) { throw new Error('CloudAgentSession transport.interrupt is not configured'); @@ -360,8 +369,11 @@ function createCloudAgentSession(config: CloudAgentSessionConfig): CloudAgentSes } return result; }, + retryRemoteModels() { + transport?.retryRemoteModels?.(); + }, get canSend() { - return transport?.send !== undefined; + return transport?.send !== undefined && (transport.canSend?.() ?? true); }, get canInterrupt() { return transport?.interrupt !== undefined; diff --git a/apps/web/src/lib/cloud-agent-sdk/transport.ts b/apps/web/src/lib/cloud-agent-sdk/transport.ts index 0c16b4b2db..f385c12145 100644 --- a/apps/web/src/lib/cloud-agent-sdk/transport.ts +++ b/apps/web/src/lib/cloud-agent-sdk/transport.ts @@ -8,6 +8,7 @@ import type { ChatEvent, ServiceEvent } from './normalizer'; import type { CloudAgentAttachments } from '@/lib/cloud-agent/constants'; import type { Images } from '@/lib/images-schema'; import type { CloudAgentSessionId } from './types'; +import type { ModelRef, RemoteModelOverride } from './remote-model-catalog'; type CloudAgentStreamTicket = { ticket: string; @@ -33,7 +34,7 @@ type SendPromptPayload = { type: 'prompt'; prompt: string; mode?: string; - model?: string; + model?: ModelRef; variant?: string; }; type SendCommandPayload = { @@ -44,6 +45,23 @@ type SendCommandPayload = { }; type TransportSendPayload = SendPromptPayload | SendCommandPayload; +type CloudAgentPromptPayload = { + type: 'prompt'; + prompt: string; + mode: string; + model: string; + variant?: string; +}; +type CloudAgentSendPayload = CloudAgentPromptPayload | SendCommandPayload; + +type TransportSendInput = { + payload: TransportSendPayload; + messageId?: string; + attachments?: CloudAgentAttachments; + images?: Images; + remoteModelOverride?: RemoteModelOverride; +}; + /** Lifecycle interface for a transport. */ type Transport = { connect(): void; @@ -51,12 +69,9 @@ type Transport = { destroy(): void; // Commands — present only on interactive transports - send?: (payload: { - payload: TransportSendPayload; - messageId?: string; - attachments?: CloudAgentAttachments; - images?: Images; - }) => Promise; + send?: (payload: TransportSendInput) => Promise; + canSend?: () => boolean; + retryRemoteModels?: () => void; interrupt?: () => Promise; answer?: (payload: { requestId: string; answers: string[][] }) => Promise; reject?: (payload: { requestId: string }) => Promise; @@ -81,7 +96,7 @@ type TransportFactory = (sink: TransportSink) => Transport; type CloudAgentApi = { send: (payload: { sessionId: CloudAgentSessionId; - payload: TransportSendPayload; + payload: CloudAgentSendPayload; messageId?: string; attachments?: CloudAgentAttachments; images?: Images; @@ -102,11 +117,14 @@ type CloudAgentApi = { export type { CloudAgentApi, + CloudAgentPromptPayload, + CloudAgentSendPayload, CloudAgentStreamTicket, CloudAgentStreamTicketResult, TransportFactory, TransportSink, Transport, + TransportSendInput, TransportSendPayload, SendPromptPayload, SendCommandPayload, diff --git a/apps/web/src/lib/cloud-agent-sdk/types.ts b/apps/web/src/lib/cloud-agent-sdk/types.ts index cd62ba359a..fa92cba257 100644 --- a/apps/web/src/lib/cloud-agent-sdk/types.ts +++ b/apps/web/src/lib/cloud-agent-sdk/types.ts @@ -40,10 +40,15 @@ export type ProcessedMessage = { parts: Part[]; }; -/** Minimal session identity — only the fields the SDK actually reads. */ +/** Minimal session metadata — only the fields the SDK actually reads. */ export type SessionInfo = { id: string; parentID?: string; + model?: { + providerID: string; + id: string; + variant?: string; + }; }; export type SessionPhase = diff --git a/apps/web/src/lib/cloud-agent-sdk/user-web-connection.test.ts b/apps/web/src/lib/cloud-agent-sdk/user-web-connection.test.ts index 26c4be0d6c..3ab0ec5bad 100644 --- a/apps/web/src/lib/cloud-agent-sdk/user-web-connection.test.ts +++ b/apps/web/src/lib/cloud-agent-sdk/user-web-connection.test.ts @@ -1,6 +1,7 @@ import { configureCloudAgentSdkRuntime, resetCloudAgentSdkRuntime } from './runtime'; import { createUserWebConnection, + UserWebCommandError, VIEWER_PING_INTERVAL_MS, VIEWER_PONG_TIMEOUT_MS, } from './user-web-connection'; @@ -1074,6 +1075,143 @@ describe('createUserWebConnection', () => { client.destroy(); }); + it('sends the expected owner connection id when provided', async () => { + const client = createUserWebConnection({ websocketUrl: WS_URL, getAuthToken: () => 'token' }); + client.connect(); + open(); + + const promise = client.sendCommand( + 'ses-1', + 'list_models', + { protocolVersion: 1 }, + 'cli-owner-1' + ); + await Promise.resolve(); + + expect(sockets[0].send).toHaveBeenCalledWith( + JSON.stringify({ + type: 'command', + id: 'uuid-2', + command: 'list_models', + sessionId: 'ses-1', + connectionId: 'cli-owner-1', + data: { protocolVersion: 1 }, + }) + ); + inbound({ type: 'response', id: 'uuid-2', result: { protocolVersion: 1 } }); + await expect(promise).resolves.toEqual({ protocolVersion: 1 }); + client.destroy(); + }); + + it('preserves strict structured relay errors as typed command errors', async () => { + const client = createUserWebConnection({ websocketUrl: WS_URL, getAuthToken: () => 'token' }); + client.connect(); + open(); + + const promise = client.sendCommand('ses-1', 'send_message', {}); + await Promise.resolve(); + inbound({ + type: 'response', + id: 'uuid-2', + error: { + source: 'relay', + code: 'SESSION_OWNER_CHANGED', + message: 'Session owner changed', + }, + }); + + await expect(promise).rejects.toEqual( + expect.objectContaining({ + name: 'UserWebCommandError', + code: 'SESSION_OWNER_CHANGED', + message: 'Session owner changed', + }) + ); + await expect(promise).rejects.toBeInstanceOf(UserWebCommandError); + client.destroy(); + }); + + it('keeps sanitized CLI error envelopes generic', async () => { + const client = createUserWebConnection({ websocketUrl: WS_URL, getAuthToken: () => 'token' }); + client.connect(); + open(); + + const promise = client.sendCommand('ses-1', 'send_message', {}); + await Promise.resolve(); + inbound({ + type: 'response', + id: 'uuid-2', + error: { source: 'cli', message: 'Command failed' }, + }); + + await expect(promise).rejects.toEqual( + expect.objectContaining({ name: 'Error', message: 'Command failed' }) + ); + await expect(promise).rejects.not.toBeInstanceOf(UserWebCommandError); + client.destroy(); + }); + + it('keeps relay envelopes with extra fields generic', async () => { + const client = createUserWebConnection({ websocketUrl: WS_URL, getAuthToken: () => 'token' }); + client.connect(); + open(); + + const promise = client.sendCommand('ses-1', 'send_message', {}); + await Promise.resolve(); + inbound({ + type: 'response', + id: 'uuid-2', + error: { + source: 'relay', + code: 'SESSION_OWNER_CHANGED', + message: 'Session owner changed', + ownerConnectionId: 'private-owner', + }, + }); + + await expect(promise).rejects.toEqual( + expect.objectContaining({ name: 'Error', message: 'Command failed' }) + ); + await expect(promise).rejects.not.toBeInstanceOf(UserWebCommandError); + client.destroy(); + }); + + it('keeps malformed relay error objects generic', async () => { + const client = createUserWebConnection({ websocketUrl: WS_URL, getAuthToken: () => 'token' }); + client.connect(); + open(); + + const promise = client.sendCommand('ses-1', 'send_message', {}); + await Promise.resolve(); + inbound({ + type: 'response', + id: 'uuid-2', + error: { source: 'relay', code: 'UNTRUSTED_CODE', message: { raw: 'internal details' } }, + }); + + await expect(promise).rejects.toEqual( + expect.objectContaining({ name: 'Error', message: 'Command failed' }) + ); + await expect(promise).rejects.not.toBeInstanceOf(UserWebCommandError); + client.destroy(); + }); + + it('preserves CLI string errors', async () => { + const client = createUserWebConnection({ websocketUrl: WS_URL, getAuthToken: () => 'token' }); + client.connect(); + open(); + + const promise = client.sendCommand('ses-1', 'send_message', {}); + await Promise.resolve(); + inbound({ type: 'response', id: 'uuid-2', error: 'CLI disconnected' }); + + await expect(promise).rejects.toEqual( + expect.objectContaining({ name: 'Error', message: 'CLI disconnected' }) + ); + await expect(promise).rejects.not.toBeInstanceOf(UserWebCommandError); + client.destroy(); + }); + it('routes command responses by request id', async () => { const client = createUserWebConnection({ websocketUrl: WS_URL, getAuthToken: () => 'token' }); client.connect(); diff --git a/apps/web/src/lib/cloud-agent-sdk/user-web-connection.ts b/apps/web/src/lib/cloud-agent-sdk/user-web-connection.ts index 998d4580ae..b57d7ebe53 100644 --- a/apps/web/src/lib/cloud-agent-sdk/user-web-connection.ts +++ b/apps/web/src/lib/cloud-agent-sdk/user-web-connection.ts @@ -6,6 +6,7 @@ import { import { cloudAgentSdkRuntime } from './runtime'; import { sessionEventPayloadSchema, + userWebCommandErrorDataSchema, webInboundMessageSchema, type SessionEventPayload, type WebInboundMessage, @@ -25,6 +26,16 @@ type UserWebSessionEventData = Extract< type CliEvent = Omit, 'type'>; type SystemEvent = Omit, 'type'>; +class UserWebCommandError extends Error { + readonly code: string; + + constructor(error: { code: string; message: string }) { + super(error.message); + this.name = 'UserWebCommandError'; + this.code = error.code; + } +} + type UserWebConnectionConfig = { websocketUrl: string; getAuthToken: () => string | Promise; @@ -42,7 +53,12 @@ type UserWebConnection = { disconnect: () => void; destroy: () => void; subscribeToCliSession: (sessionId: string) => () => void; - sendCommand: (sessionId: string, command: string, data: unknown) => Promise; + sendCommand: ( + sessionId: string, + command: string, + data: unknown, + expectedOwnerConnectionId?: string + ) => Promise; onCliEvent: (sessionId: string, listener: (event: CliEvent) => void) => () => void; onSystemEvent: (listener: (event: SystemEvent) => void) => () => void; onReconnect: (listener: () => void) => () => void; @@ -52,6 +68,14 @@ type UserWebConnection = { ) => () => void; }; +function parseCommandError(error: unknown): Error { + if (typeof error === 'string') return new Error(error); + + const parsed = userWebCommandErrorDataSchema.safeParse(error); + if (parsed.success) return new UserWebCommandError(parsed.data); + return new Error('Command failed'); +} + function createUserWebConnection( config: UserWebConnectionConfig ): UserWebConnection & { retain: () => () => void } { @@ -291,8 +315,7 @@ function createUserWebConnection( if (!pending) return; clearTimeout(pending.timer); pendingCommands.delete(msg.id); - if (msg.error) - pending.reject(new Error(typeof msg.error === 'string' ? msg.error : 'Command failed')); + if (msg.error) pending.reject(parseCommandError(msg.error)); else pending.resolve(msg.result); } @@ -474,7 +497,7 @@ function createUserWebConnection( releaseConnection(); }; }, - sendCommand(sessionId, command, data) { + sendCommand(sessionId, command, data, expectedOwnerConnectionId) { const hasOwnerLifetime = retainCount > commandRetainCount; const releaseCommandLifetime = hasOwnerLifetime ? null : retainConnection(); if (releaseCommandLifetime) commandRetainCount += 1; @@ -512,7 +535,16 @@ function createUserWebConnection( rejectCommand(new Error('Command timed out')); }, COMMAND_TIMEOUT_MS); pendingCommands.set(id, { resolve: resolveCommand, reject: rejectCommand, timer }); - ws.send(JSON.stringify({ type: 'command', id, command, sessionId, data })); + ws.send( + JSON.stringify({ + type: 'command', + id, + command, + sessionId, + connectionId: expectedOwnerConnectionId, + data, + }) + ); }, reason => { rejectCommand( @@ -551,7 +583,7 @@ function createUserWebConnection( }; } -export { createUserWebConnection }; +export { createUserWebConnection, UserWebCommandError }; export type { UserWebConnection, UserWebConnectionConfig, diff --git a/apps/web/src/lib/session-ingest-client.test.ts b/apps/web/src/lib/session-ingest-client.test.ts index 4c5074af85..1ac6df2fe3 100644 --- a/apps/web/src/lib/session-ingest-client.test.ts +++ b/apps/web/src/lib/session-ingest-client.test.ts @@ -38,10 +38,11 @@ function makeSnapshot( messages: Array<{ role: string; parts: Array<{ type: string; text?: string; id?: string }>; - }> + }>, + info: SessionSnapshot['info'] = {} ): SessionSnapshot { return { - info: {}, + info, messages: messages.map((m, i) => ({ info: { id: `msg_${i}`, role: m.role }, parts: m.parts.map((p, j) => ({ @@ -64,10 +65,19 @@ describe('fetchSessionSnapshot', () => { mockGenerateInternalServiceToken.mockReset().mockReturnValue('mock-jwt-token'); }); - it('returns parsed snapshot on 200 response', async () => { - const snapshot = makeSnapshot([ - { role: 'assistant', parts: [{ type: 'text', text: 'hello' }] }, - ]); + it('returns parsed snapshot with session model metadata on 200 response', async () => { + const snapshot = makeSnapshot( + [{ role: 'assistant', parts: [{ type: 'text', text: 'hello' }] }], + { + id: 'ses_abc123', + title: 'Remote session', + model: { + providerID: 'anthropic', + id: 'claude-sonnet-4', + variant: 'thinking', + }, + } + ); mockFetch.mockResolvedValue({ ok: true, @@ -86,6 +96,23 @@ describe('fetchSessionSnapshot', () => { ); }); + it('validates session model metadata', async () => { + mockFetch.mockResolvedValue({ + ok: true, + status: 200, + json: () => + Promise.resolve({ + info: { + id: 'ses_abc123', + model: { providerID: 'anthropic', id: 42, variant: 'thinking' }, + }, + messages: [], + }), + }); + + await expect(fetchSessionSnapshot('ses_abc123', 'user_123')).rejects.toThrow(); + }); + it('returns null on 404', async () => { mockFetch.mockResolvedValue({ ok: false, status: 404, statusText: 'Not Found' }); diff --git a/apps/web/src/lib/session-ingest-client.ts b/apps/web/src/lib/session-ingest-client.ts index c5fe3584a2..4da96aca32 100644 --- a/apps/web/src/lib/session-ingest-client.ts +++ b/apps/web/src/lib/session-ingest-client.ts @@ -12,8 +12,20 @@ import type { User } from '@kilocode/db/schema'; // Mirrors SharedSessionSnapshotSchema from cloudflare-session-ingest/src/util/share-output.ts. // Kept in sync manually (same pattern as cloud-agent-client.ts). +const SessionInfoSchema = z.looseObject({ + id: z.string().optional(), + parentID: z.string().optional(), + model: z + .object({ + providerID: z.string(), + id: z.string(), + variant: z.string().optional(), + }) + .optional(), +}); + const SessionSnapshotSchema = z.object({ - info: z.unknown(), + info: SessionInfoSchema, messages: z.array( z.looseObject({ info: z.looseObject({ diff --git a/apps/web/src/routers/cli-sessions-v2-router.test.ts b/apps/web/src/routers/cli-sessions-v2-router.test.ts index 8fe4f4b312..01c8ef2cbf 100644 --- a/apps/web/src/routers/cli-sessions-v2-router.test.ts +++ b/apps/web/src/routers/cli-sessions-v2-router.test.ts @@ -76,6 +76,67 @@ describe('cli-sessions-v2-router', () => { testOrganization = org; }); + describe('getSessionMessages', () => { + const sessionId = 'ses_snapshot_metadata_test_1234'; + let fetchSpy: jest.SpyInstance; + + beforeEach(async () => { + await db.insert(cli_sessions_v2).values({ + session_id: sessionId, + kilo_user_id: regularUser.id, + created_on_platform: 'cloud-agent', + }); + fetchSpy = jest.spyOn(global, 'fetch').mockResolvedValue( + new Response( + JSON.stringify({ + info: { + id: sessionId, + model: { + providerID: 'anthropic', + id: 'claude-sonnet-4', + variant: 'thinking', + }, + }, + messages: [{ info: { id: 'msg_1', role: 'user' }, parts: [] }], + }), + { status: 200, headers: { 'Content-Type': 'application/json' } } + ) + ); + }); + + afterEach(async () => { + fetchSpy.mockRestore(); + await db.delete(cli_sessions_v2).where(eq(cli_sessions_v2.session_id, sessionId)); + }); + + it('returns validated snapshot info together with messages', async () => { + const caller = await createCallerForUser(regularUser.id); + + const result = await caller.cliSessionsV2.getSessionMessages({ session_id: sessionId }); + + expect(result).toEqual({ + info: { + id: sessionId, + model: { + providerID: 'anthropic', + id: 'claude-sonnet-4', + variant: 'thinking', + }, + }, + messages: [{ info: { id: 'msg_1', role: 'user' }, parts: [] }], + }); + }); + + it('does not fetch a snapshot for a session owned by another user', async () => { + const caller = await createCallerForUser(otherUser.id); + + await expect( + caller.cliSessionsV2.getSessionMessages({ session_id: sessionId }) + ).rejects.toThrow('Session not found'); + expect(fetchSpy).not.toHaveBeenCalled(); + }); + }); + describe('shareForWebhookTrigger', () => { let triggerId: string; let profileId: string; diff --git a/apps/web/src/routers/cli-sessions-v2-router.ts b/apps/web/src/routers/cli-sessions-v2-router.ts index ad0ed236bf..b25e870ec5 100644 --- a/apps/web/src/routers/cli-sessions-v2-router.ts +++ b/apps/web/src/routers/cli-sessions-v2-router.ts @@ -27,7 +27,7 @@ import { import { createCloudAgentNextClient } from '@/lib/cloud-agent-next/cloud-agent-client'; import { generateApiToken, generateInternalServiceToken } from '@/lib/tokens'; import { - fetchSessionMessages, + fetchSessionSnapshot, deleteSession as deleteSessionIngest, shareSession as shareSessionIngest, } from '@/lib/session-ingest-client'; @@ -689,7 +689,7 @@ export const cliSessionsV2Router = createTRPCRouter({ }), /** - * Get messages for a V2 session from the session ingest worker. + * Get snapshot metadata and messages for a V2 session from the session ingest worker. */ getSessionMessages: baseProcedure .input(z.object({ session_id: sessionIdField })) @@ -697,8 +697,8 @@ export const cliSessionsV2Router = createTRPCRouter({ await getSessionWithAccessCheck(input.session_id, ctx); try { - const messages = await fetchSessionMessages(input.session_id, ctx.user); - return { messages: messages ?? [] }; + const snapshot = await fetchSessionSnapshot(input.session_id, ctx.user.id); + return snapshot ?? { info: {}, messages: [] }; } catch (error) { console.error( `Failed to fetch messages for session ${input.session_id}:`, diff --git a/services/session-ingest/src/dos/UserConnectionDO.test.ts b/services/session-ingest/src/dos/UserConnectionDO.test.ts index 82d868e45f..4dcc795b43 100644 --- a/services/session-ingest/src/dos/UserConnectionDO.test.ts +++ b/services/session-ingest/src/dos/UserConnectionDO.test.ts @@ -12,7 +12,7 @@ vi.mock('cloudflare:workers', () => ({ }, })); -import { UserConnectionDO } from './UserConnectionDO'; +import { MAX_CATALOG_RESULT_BYTES, UserConnectionDO } from './UserConnectionDO'; // --------------------------------------------------------------------------- // Mock WebSocket @@ -152,6 +152,31 @@ function connectWebSocket(doInstance: UserConnectionDO, connectionId: string): M return server; } +function connectCliSocket(doInstance: UserConnectionDO, connectionId: string): MockWS { + const client = createMockWs(); + const server = createMockWs(); + vi.stubGlobal( + 'WebSocketPair', + class { + 0 = client; + 1 = server; + } + ); + vi.stubGlobal( + 'Response', + class { + constructor(_body?: BodyInit | null, _init?: ResponseInit) {} + } + ); + + doInstance.fetch( + new Request(`http://local/cli?connectionId=${connectionId}`, { + headers: { Upgrade: 'websocket' }, + }) + ); + return server; +} + /** Create a CLI WebSocket and add it to the context with proper attachment. */ function addCliSocket( mockCtx: ReturnType, @@ -224,6 +249,29 @@ function sendCliResponse( doInstance.webSocketMessage(cliWs as never, msg); } +function createResultWithSerializedBytes(targetBytes: number): { padding: string } { + const framingBytes = new TextEncoder().encode(JSON.stringify({ padding: '' })).byteLength; + const result = { padding: 'x'.repeat(targetBytes - framingBytes) }; + if (new TextEncoder().encode(JSON.stringify(result)).byteLength !== targetBytes) { + throw new Error(`Result fixture does not serialize to ${targetBytes} bytes`); + } + return result; +} + +function createUtf8OversizedResult(): { padding: string } { + const framingBytes = JSON.stringify({ padding: '' }).length; + const result = { + padding: 'é'.repeat(Math.floor((MAX_CATALOG_RESULT_BYTES - framingBytes) / 2) + 1), + }; + if ( + JSON.stringify(result).length >= MAX_CATALOG_RESULT_BYTES || + new TextEncoder().encode(JSON.stringify(result)).byteLength <= MAX_CATALOG_RESULT_BYTES + ) { + throw new Error('UTF-8 catalog fixture does not cross the byte-only boundary'); + } + return result; +} + /** Trigger CLI disconnect */ function disconnectCli(doInstance: UserConnectionDO, cliWs: MockWS) { doInstance.webSocketClose(cliWs as never, 0, '', false); @@ -332,6 +380,39 @@ describe('UserConnectionDO', () => { }); }); + it('fails an in-flight command when the session owner changes', () => { + const { doInstance, mockCtx } = setup(); + const firstOwner = addCliSocket(mockCtx, 'cli-1'); + const nextOwner = addCliSocket(mockCtx, 'cli-2'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, firstOwner, [makeSession('s1')]); + sendHeartbeat(doInstance, nextOwner, []); + firstOwner.send.mockClear(); + sendCommand(doInstance, webWs, { + id: 'cmd-1', + command: 'list_models', + sessionId: 's1', + connectionId: 'cli-1', + }); + const correlationId = getCorrelationId(firstOwner); + webWs.send.mockClear(); + + sendHeartbeat(doInstance, nextOwner, [makeSession('s1')]); + + expect(parseSent(webWs)).toEqual({ + type: 'response', + id: 'cmd-1', + error: { + source: 'relay', + code: 'SESSION_OWNER_CHANGED', + message: 'Session owner changed', + }, + }); + sendCliResponse(doInstance, firstOwner, { id: correlationId, result: 'late' }); + expect(webWs.send).toHaveBeenCalledTimes(1); + }); + it('replays existing web subscriptions when a session gets a new CLI owner', () => { const { doInstance, mockCtx } = setup(); const cli1 = addCliSocket(mockCtx, 'cli-1'); @@ -799,6 +880,63 @@ describe('UserConnectionDO', () => { expect(errorResp).toMatchObject({ type: 'response', id: 'cmd-1', error: 'CLI disconnected' }); }); + it('reports owner change when an owner-fenced command target disconnects', () => { + const { doInstance, mockCtx } = setup(); + const cliWs = addCliSocket(mockCtx, 'cli-1'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, cliWs, [makeSession('s1')]); + sendCommand(doInstance, webWs, { + id: 'cmd-1', + command: 'list_models', + sessionId: 's1', + connectionId: 'cli-1', + }); + webWs.send.mockClear(); + + mockCtx.removeSocket(cliWs); + disconnectCli(doInstance, cliWs); + + expect(parseSent(webWs)).toEqual({ + type: 'response', + id: 'cmd-1', + error: { + source: 'relay', + code: 'SESSION_OWNER_CHANGED', + message: 'Session owner changed', + }, + }); + }); + + it('fails pending commands as soon as their target socket is replaced', () => { + const { doInstance, mockCtx } = setup(); + const firstCli = connectCliSocket(doInstance, 'cli-1'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, firstCli, [makeSession('s1')]); + firstCli.send.mockClear(); + sendCommand(doInstance, webWs, { + id: 'cmd-1', + command: 'list_models', + sessionId: 's1', + connectionId: 'cli-1', + }); + webWs.send.mockClear(); + + connectCliSocket(doInstance, 'cli-1'); + + expect(firstCli.close).toHaveBeenCalledWith(1000, 'replaced by reconnect'); + expect(parseSent(webWs)).toEqual({ + type: 'response', + id: 'cmd-1', + error: { + source: 'relay', + code: 'SESSION_OWNER_CHANGED', + message: 'Session owner changed', + }, + }); + }); + it('sends error for connection-routed pending commands on CLI disconnect', () => { const { doInstance, mockCtx } = setup(); const cliWs = addCliSocket(mockCtx, 'cli-1'); @@ -1070,6 +1208,315 @@ describe('UserConnectionDO', () => { }); }); + it('sanitizes a relay-shaped CLI error before forwarding it to web', () => { + const { doInstance, mockCtx } = setup(); + const cliWs = addCliSocket(mockCtx, 'cli-1'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, cliWs, [makeSession('s1')]); + cliWs.send.mockClear(); + sendCommand(doInstance, webWs, { id: 'cmd-1', command: 'test', sessionId: 's1' }); + const correlationId = getCorrelationId(cliWs); + webWs.send.mockClear(); + + sendCliResponse(doInstance, cliWs, { + id: correlationId, + error: { + source: 'relay', + code: 'SESSION_OWNER_CHANGED', + message: 'Session owner changed', + }, + }); + + expect(parseSent(webWs)).toEqual({ + type: 'response', + id: 'cmd-1', + error: { + source: 'cli', + message: 'Command failed', + }, + }); + }); + + it('preserves CLI string errors for old-CLI compatibility', () => { + const { doInstance, mockCtx } = setup(); + const cliWs = addCliSocket(mockCtx, 'cli-1'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, cliWs, [makeSession('s1')]); + cliWs.send.mockClear(); + sendCommand(doInstance, webWs, { id: 'cmd-1', command: 'list_models', sessionId: 's1' }); + const correlationId = getCorrelationId(cliWs); + webWs.send.mockClear(); + + sendCliResponse(doInstance, cliWs, { + id: correlationId, + error: 'unknown command: list_models', + }); + + expect(parseSent(webWs)).toEqual({ + type: 'response', + id: 'cmd-1', + error: 'unknown command: list_models', + }); + }); + + it('accepts a pending response only from the targeted CLI socket', () => { + const { doInstance, mockCtx } = setup(); + const targetCli = addCliSocket(mockCtx, 'cli-1'); + const otherCli = addCliSocket(mockCtx, 'cli-2'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, targetCli, [makeSession('s1')]); + sendHeartbeat(doInstance, otherCli, []); + targetCli.send.mockClear(); + sendCommand(doInstance, webWs, { id: 'cmd-1', command: 'test', sessionId: 's1' }); + const correlationId = getCorrelationId(targetCli); + webWs.send.mockClear(); + + sendCliResponse(doInstance, otherCli, { id: correlationId, result: 'wrong-owner' }); + expect(webWs.send).not.toHaveBeenCalled(); + + sendCliResponse(doInstance, targetCli, { id: correlationId, result: 'ok' }); + expect(parseSent(webWs)).toEqual({ type: 'response', id: 'cmd-1', result: 'ok' }); + }); + + it('rejects a duplicate in-flight list_models request for the same viewer session and owner', () => { + const { doInstance, mockCtx } = setup(); + const cliWs = addCliSocket(mockCtx, 'cli-1'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, cliWs, [makeSession('s1')]); + cliWs.send.mockClear(); + sendCommand(doInstance, webWs, { + id: 'cmd-1', + command: 'list_models', + sessionId: 's1', + connectionId: 'cli-1', + }); + sendCommand(doInstance, webWs, { + id: 'cmd-2', + command: 'list_models', + sessionId: 's1', + connectionId: 'cli-1', + }); + + expect(parseSent(webWs)).toEqual({ + type: 'response', + id: 'cmd-2', + error: { + source: 'relay', + code: 'CATALOG_REQUEST_PENDING', + message: 'Model catalog request already pending', + }, + }); + expect(allSent(cliWs).filter(message => message.type === 'command')).toHaveLength(1); + }); + + it('expires pending commands before handling another command', () => { + const now = 1_000_000; + vi.spyOn(Date, 'now').mockReturnValue(now); + const { doInstance, mockCtx } = setup(); + const cliWs = addCliSocket(mockCtx, 'cli-1'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, cliWs, [makeSession('s1')]); + cliWs.send.mockClear(); + sendCommand(doInstance, webWs, { + id: 'cmd-1', + command: 'list_models', + sessionId: 's1', + connectionId: 'cli-1', + }); + + vi.mocked(Date.now).mockReturnValue(now + 35_001); + sendCommand(doInstance, webWs, { + id: 'cmd-2', + command: 'list_models', + sessionId: 's1', + connectionId: 'cli-1', + }); + + expect(parseSent(webWs)).toEqual({ + type: 'response', + id: 'cmd-1', + error: { + source: 'relay', + code: 'COMMAND_EXPIRED', + message: 'Command expired', + }, + }); + expect(allSent(cliWs).filter(message => message.type === 'command')).toHaveLength(2); + }); + + it('does not postpone pending-command expiry when heartbeats reschedule the alarm', () => { + const now = 1_000_000; + vi.spyOn(Date, 'now').mockReturnValue(now); + const { doInstance, mockCtx, ctx } = setup(); + const cliWs = addCliSocket(mockCtx, 'cli-1'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, cliWs, [makeSession('s1')]); + sendCommand(doInstance, webWs, { id: 'cmd-1', command: 'test', sessionId: 's1' }); + + ctx.storage.setAlarm.mockClear(); + vi.mocked(Date.now).mockReturnValue(now + 20_000); + sendHeartbeat(doInstance, cliWs, [makeSession('s1')]); + + expect(ctx.storage.setAlarm).toHaveBeenCalledWith(now + 35_000); + }); + + it('expires pending commands during alarm processing', async () => { + const now = 1_000_000; + vi.spyOn(Date, 'now').mockReturnValue(now); + const { doInstance, mockCtx } = setup(); + const cliWs = addCliSocket(mockCtx, 'cli-1'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, cliWs, [makeSession('s1')]); + cliWs.send.mockClear(); + sendCommand(doInstance, webWs, { id: 'cmd-1', command: 'test', sessionId: 's1' }); + const correlationId = getCorrelationId(cliWs); + + vi.mocked(Date.now).mockReturnValue(now + 34_000); + sendHeartbeat(doInstance, cliWs, [makeSession('s1')]); + webWs.send.mockClear(); + vi.mocked(Date.now).mockReturnValue(now + 35_001); + + await doInstance.alarm(); + + expect(parseSent(webWs)).toEqual({ + type: 'response', + id: 'cmd-1', + error: { + source: 'relay', + code: 'COMMAND_EXPIRED', + message: 'Command expired', + }, + }); + sendCliResponse(doInstance, cliWs, { id: correlationId, result: 'late' }); + expect(webWs.send).toHaveBeenCalledTimes(1); + }); + + it('rejects commands after reaching the global pending-command cap', () => { + const { doInstance, mockCtx } = setup(); + const cliWs = addCliSocket(mockCtx, 'cli-1'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, cliWs, [makeSession('s1')]); + cliWs.send.mockClear(); + for (let index = 0; index < 128; index++) { + sendCommand(doInstance, webWs, { + id: `cmd-${index}`, + command: 'test', + sessionId: 's1', + }); + } + + sendCommand(doInstance, webWs, { + id: 'cmd-over-cap', + command: 'test', + sessionId: 's1', + }); + + expect(parseSent(webWs)).toEqual({ + type: 'response', + id: 'cmd-over-cap', + error: { + source: 'relay', + code: 'PENDING_COMMAND_LIMIT', + message: 'Too many pending commands', + }, + }); + expect(allSent(cliWs).filter(message => message.type === 'command')).toHaveLength(128); + }); + + it('accepts a list_models result at exactly 512 KiB', () => { + const { doInstance, mockCtx } = setup(); + const cliWs = addCliSocket(mockCtx, 'cli-1'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, cliWs, [makeSession('s1')]); + cliWs.send.mockClear(); + sendCommand(doInstance, webWs, { + id: 'cmd-1', + command: 'list_models', + sessionId: 's1', + connectionId: 'cli-1', + }); + const correlationId = getCorrelationId(cliWs); + webWs.send.mockClear(); + const result = createResultWithSerializedBytes(MAX_CATALOG_RESULT_BYTES); + + sendCliResponse(doInstance, cliWs, { id: correlationId, result }); + + expect(parseSent(webWs)).toEqual({ type: 'response', id: 'cmd-1', result }); + }); + + it('rejects a list_models result one byte over 512 KiB', () => { + const { doInstance, mockCtx } = setup(); + const cliWs = addCliSocket(mockCtx, 'cli-1'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, cliWs, [makeSession('s1')]); + cliWs.send.mockClear(); + sendCommand(doInstance, webWs, { + id: 'cmd-1', + command: 'list_models', + sessionId: 's1', + connectionId: 'cli-1', + }); + const correlationId = getCorrelationId(cliWs); + webWs.send.mockClear(); + + sendCliResponse(doInstance, cliWs, { + id: correlationId, + result: createResultWithSerializedBytes(MAX_CATALOG_RESULT_BYTES + 1), + }); + + expect(parseSent(webWs)).toEqual({ + type: 'response', + id: 'cmd-1', + error: { + source: 'relay', + code: 'CATALOG_TOO_LARGE', + message: 'Model catalog response is too large', + }, + }); + }); + + it('rejects a multibyte list_models result over 512 KiB', () => { + const { doInstance, mockCtx } = setup(); + const cliWs = addCliSocket(mockCtx, 'cli-1'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, cliWs, [makeSession('s1')]); + cliWs.send.mockClear(); + sendCommand(doInstance, webWs, { + id: 'cmd-1', + command: 'list_models', + sessionId: 's1', + connectionId: 'cli-1', + }); + const correlationId = getCorrelationId(cliWs); + webWs.send.mockClear(); + + sendCliResponse(doInstance, cliWs, { + id: correlationId, + result: createUtf8OversizedResult(), + }); + + expect(parseSent(webWs)).toEqual({ + type: 'response', + id: 'cmd-1', + error: { + source: 'relay', + code: 'CATALOG_TOO_LARGE', + message: 'Model catalog response is too large', + }, + }); + }); + it('returns error when CLI not found for session', () => { const { doInstance, mockCtx } = setup(); const webWs = addWebSocket(mockCtx, 'web-1'); @@ -1088,6 +1535,37 @@ describe('UserConnectionDO', () => { }); }); + it('rejects a stale expected session owner without forwarding', () => { + const { doInstance, mockCtx } = setup(); + const currentOwner = addCliSocket(mockCtx, 'cli-1'); + const staleOwner = addCliSocket(mockCtx, 'cli-2'); + const webWs = addWebSocket(mockCtx, 'web-1'); + + sendHeartbeat(doInstance, currentOwner, [makeSession('s1')]); + sendHeartbeat(doInstance, staleOwner, []); + currentOwner.send.mockClear(); + staleOwner.send.mockClear(); + + sendCommand(doInstance, webWs, { + id: 'cmd-1', + command: 'send_message', + sessionId: 's1', + connectionId: 'cli-2', + }); + + expect(parseSent(webWs)).toEqual({ + type: 'response', + id: 'cmd-1', + error: { + source: 'relay', + code: 'SESSION_OWNER_CHANGED', + message: 'Session owner changed', + }, + }); + expect(currentOwner.send).not.toHaveBeenCalled(); + expect(staleOwner.send).not.toHaveBeenCalled(); + }); + it('routes command by connectionId to specific CLI', () => { const { doInstance, mockCtx } = setup(); const cli1 = addCliSocket(mockCtx, 'cli-1'); @@ -1528,6 +2006,22 @@ describe('UserConnectionDO', () => { doInstance.webSocketMessage(cliWs as never, 'not-json'); }); + it('logs invalid CLI JSON metadata without raw payload content', () => { + const { doInstance, mockCtx } = setup(); + const cliWs = addCliSocket(mockCtx, 'cli-1'); + const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined); + const malformed = '{"secret":"raw-secret-must-not-be-logged"'; + + doInstance.webSocketMessage(cliWs as never, malformed); + + expect(warn).toHaveBeenCalledWith('Failed to parse WebSocket message as JSON', { + role: 'cli', + connectionId: 'cli-1', + byteCount: new TextEncoder().encode(malformed).byteLength, + }); + expect(JSON.stringify(warn.mock.calls)).not.toContain('raw-secret-must-not-be-logged'); + }); + it('ignores messages from socket with no attachment', () => { const { doInstance, mockCtx } = setup(); const ws = createMockWs(['cli'], null); @@ -1554,6 +2048,28 @@ describe('UserConnectionDO', () => { // Should not throw }); + it('logs malformed CLI message metadata without raw payload content', () => { + const { doInstance, mockCtx } = setup(); + const cliWs = addCliSocket(mockCtx, 'cli-1'); + const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined); + const secret = 'raw-secret-must-not-be-logged'; + const malformed = JSON.stringify({ + type: 'response', + id: 123, + result: { secret }, + }); + + doInstance.webSocketMessage(cliWs as never, malformed); + + expect(warn).toHaveBeenCalledWith('CLI message parse failed', { + role: 'cli', + connectionId: 'cli-1', + byteCount: new TextEncoder().encode(malformed).byteLength, + issues: [{ path: ['id'], code: 'invalid_type' }], + }); + expect(JSON.stringify(warn.mock.calls)).not.toContain(secret); + }); + it('webSocketError triggers webSocketClose', () => { const { doInstance, mockCtx } = setup(); const cliWs = addCliSocket(mockCtx, 'cli-1'); diff --git a/services/session-ingest/src/dos/UserConnectionDO.ts b/services/session-ingest/src/dos/UserConnectionDO.ts index 565e912415..e4badbc73a 100644 --- a/services/session-ingest/src/dos/UserConnectionDO.ts +++ b/services/session-ingest/src/dos/UserConnectionDO.ts @@ -23,8 +23,47 @@ type WSAttachment = | { role: 'cli'; connectionId: string; sessions: HeartbeatSession[] } | { role: 'web'; connectionId: string; subscribedSessions: string[]; replaced?: true }; +export const MAX_CATALOG_RESULT_BYTES = 512 * 1024; + +const SESSION_OWNER_CHANGED_ERROR = { + source: 'relay', + code: 'SESSION_OWNER_CHANGED', + message: 'Session owner changed', +}; + +const CATALOG_TOO_LARGE_ERROR = { + source: 'relay', + code: 'CATALOG_TOO_LARGE', + message: 'Model catalog response is too large', +}; + +const CATALOG_REQUEST_PENDING_ERROR = { + source: 'relay', + code: 'CATALOG_REQUEST_PENDING', + message: 'Model catalog request already pending', +}; + +const PENDING_COMMAND_LIMIT_ERROR = { + source: 'relay', + code: 'PENDING_COMMAND_LIMIT', + message: 'Too many pending commands', +}; + +const COMMAND_EXPIRED_ERROR = { + source: 'relay', + code: 'COMMAND_EXPIRED', + message: 'Command expired', +}; + +const CLI_COMMAND_ERROR = { + source: 'cli', + message: 'Command failed', +}; + export class UserConnectionDO extends DurableObject { private static readonly HEARTBEAT_TIMEOUT_MS = 30_000; + private static readonly PENDING_COMMAND_TTL_MS = 35_000; + private static readonly MAX_PENDING_COMMANDS = 128; // Which CLI connection owns each session private sessionOwners = new Map(); @@ -35,7 +74,16 @@ export class UserConnectionDO extends DurableObject { // Pending command responses: correlationId → originating web socket private pendingCommands = new Map< string, - { ws: WebSocket; sessionId?: string; originalId: string; targetCliWs: WebSocket } + { + ws: WebSocket; + sessionId?: string; + originalId: string; + command: string; + expectedOwnerConnectionId?: string; + targetConnectionId: string; + expiresAt: number; + targetCliWs: WebSocket; + } >(); // Last heartbeat timestamp per CLI connectionId (for staleness eviction) private lastHeartbeatAt = new Map(); @@ -108,8 +156,9 @@ export class UserConnectionDO extends DurableObject { const attachment: WSAttachment = { role: 'cli', connectionId, sessions: [] }; this.ctx.acceptWebSocket(server, ['cli']); server.serializeAttachment(attachment); - this.lastHeartbeatAt.set(connectionId, Date.now()); - this.scheduleStaleCheck(); + const now = Date.now(); + this.lastHeartbeatAt.set(connectionId, now); + this.scheduleNextAlarm(now); console.log('CLI socket connected', { connectionId, @@ -152,23 +201,28 @@ export class UserConnectionDO extends DurableObject { webSocketMessage(ws: WebSocket, message: string | ArrayBuffer): void { this.ensureState(); + const attachment = ws.deserializeAttachment() as WSAttachment | null; + if (!attachment) { + console.warn('WebSocket message from socket with no attachment'); + return; + } + const raw = typeof message === 'string' ? message : new TextDecoder().decode(message); + const binaryByteCount = typeof message === 'string' ? undefined : message.byteLength; let parsed: unknown; try { parsed = JSON.parse(raw); } catch { - console.warn('Failed to parse WebSocket message as JSON'); - return; - } - - const attachment = ws.deserializeAttachment() as WSAttachment | null; - if (!attachment) { - console.warn('WebSocket message from socket with no attachment'); + console.warn('Failed to parse WebSocket message as JSON', { + role: attachment.role, + connectionId: attachment.connectionId, + byteCount: binaryByteCount ?? new TextEncoder().encode(raw).byteLength, + }); return; } if (attachment.role === 'cli') { - this.handleCliMessage(ws, attachment, parsed); + this.handleCliMessage(ws, attachment, parsed, raw, binaryByteCount); } else if (!attachment.replaced) { this.handleWebMessage(ws, attachment, parsed); } @@ -200,10 +254,11 @@ export class UserConnectionDO extends DurableObject { this.ensureState(); const now = Date.now(); + this.expirePendingCommands(now); const staleConnectionIds: string[] = []; for (const [connectionId, lastSeen] of this.lastHeartbeatAt) { - if (now - lastSeen > UserConnectionDO.HEARTBEAT_TIMEOUT_MS) { + if (now - lastSeen >= UserConnectionDO.HEARTBEAT_TIMEOUT_MS) { staleConnectionIds.push(connectionId); } } @@ -222,10 +277,7 @@ export class UserConnectionDO extends DurableObject { // via the webSocketClose callback } - // If there are still active CLI connections, schedule another check - if (this.lastHeartbeatAt.size > staleConnectionIds.length) { - this.scheduleStaleCheck(); - } + this.scheduleNextAlarm(now); } // --------------------------------------------------------------------------- @@ -235,14 +287,17 @@ export class UserConnectionDO extends DurableObject { private handleCliMessage( ws: WebSocket, attachment: WSAttachment & { role: 'cli' }, - parsed: unknown + parsed: unknown, + raw: string, + binaryByteCount: number | undefined ): void { const result = CLIOutboundMessageSchema.safeParse(parsed); if (!result.success) { console.warn('CLI message parse failed', { + role: 'cli', connectionId: attachment.connectionId, - errors: result.error.issues.map(i => i.message), - raw: JSON.stringify(parsed).slice(0, 500), + byteCount: binaryByteCount ?? new TextEncoder().encode(raw).byteLength, + issues: result.error.issues.map(issue => ({ path: issue.path, code: issue.code })), }); return; } @@ -256,7 +311,7 @@ export class UserConnectionDO extends DurableObject { this.handleCliEvent(msg.sessionId, msg.parentSessionId, msg.event, msg.data); break; case 'response': - this.handleCliResponse(msg.id, msg.result, msg.error); + this.handleCliResponse(ws, msg.id, msg.result, msg.error); break; } } @@ -267,8 +322,9 @@ export class UserConnectionDO extends DurableObject { sessions: HeartbeatSession[] ): void { const { connectionId } = attachment; - this.lastHeartbeatAt.set(connectionId, Date.now()); - this.scheduleStaleCheck(); + const now = Date.now(); + this.lastHeartbeatAt.set(connectionId, now); + this.scheduleNextAlarm(now); // Remove sessions this connection previously owned but no longer reports const previousSessions = this.connectionSessions.get(connectionId) ?? []; @@ -276,12 +332,17 @@ export class UserConnectionDO extends DurableObject { for (const prev of previousSessions) { if (!currentIds.has(prev.id) && this.sessionOwners.get(prev.id) === connectionId) { this.sessionOwners.delete(prev.id); + this.failPendingCommandsForOwnerChange(prev.id, undefined); } } // Update ownership this.connectionSessions.set(connectionId, sessions); for (const session of sessions) { + const previousOwner = this.sessionOwners.get(session.id); + if (previousOwner && previousOwner !== connectionId) { + this.failPendingCommandsForOwnerChange(session.id, connectionId); + } this.sessionOwners.set(session.id, connectionId); } @@ -351,16 +412,36 @@ export class UserConnectionDO extends DurableObject { } } - private handleCliResponse(id: string, result: unknown, error: unknown): void { + private handleCliResponse( + respondingWs: WebSocket, + id: string, + result: unknown, + error: unknown + ): void { const entry = this.pendingCommands.get(id); - if (!entry) return; + if (!entry || entry.targetCliWs !== respondingWs) return; this.pendingCommands.delete(id); + if (entry.command === 'list_models' && result !== undefined) { + const serializedResult = JSON.stringify(result); + const resultBytes = new TextEncoder().encode(serializedResult).byteLength; + if (resultBytes > MAX_CATALOG_RESULT_BYTES) { + this.sendToWeb(entry.ws, { + type: 'response', + id: entry.originalId, + error: CATALOG_TOO_LARGE_ERROR, + }); + return; + } + } + this.sendToWeb(entry.ws, { type: 'response', id: entry.originalId, ...(result !== undefined ? { result } : {}), - ...(error !== undefined ? { error } : {}), + ...(error !== undefined + ? { error: typeof error === 'string' ? error : CLI_COMMAND_ERROR } + : {}), }); } @@ -468,18 +549,24 @@ export class UserConnectionDO extends DurableObject { ws: WebSocket, msg: { id: string; command: string; sessionId?: string; connectionId?: string; data?: unknown } ): void { + const now = Date.now(); + this.expirePendingCommands(now); + // Find target CLI let targetCli: WebSocket | undefined; - if (msg.connectionId) { - // Route to specific CLI by connectionId - for (const cliWs of this.ctx.getWebSockets('cli')) { - const att = cliWs.deserializeAttachment() as WSAttachment | null; - if (att?.role === 'cli' && att.connectionId === msg.connectionId) { - targetCli = cliWs; - break; - } + if (msg.sessionId && msg.connectionId) { + targetCli = this.findCliByConnectionId(msg.connectionId); + if (this.sessionOwners.get(msg.sessionId) !== msg.connectionId || !targetCli) { + this.sendToWeb(ws, { + type: 'response', + id: msg.id, + error: SESSION_OWNER_CHANGED_ERROR, + }); + return; } + } else if (msg.connectionId) { + targetCli = this.findCliByConnectionId(msg.connectionId); } else if (msg.sessionId) { targetCli = this.findCliForSession(msg.sessionId); } else { @@ -493,13 +580,51 @@ export class UserConnectionDO extends DurableObject { return; } + const targetAttachment = targetCli.deserializeAttachment() as WSAttachment | null; + if (targetAttachment?.role !== 'cli') return; + const expectedOwnerConnectionId = + msg.sessionId && msg.connectionId ? msg.connectionId : undefined; + const targetConnectionId = targetAttachment.connectionId; + + if ( + msg.command === 'list_models' && + [...this.pendingCommands.values()].some( + entry => + entry.ws === ws && + entry.command === 'list_models' && + entry.sessionId === msg.sessionId && + entry.targetConnectionId === targetConnectionId + ) + ) { + this.sendToWeb(ws, { + type: 'response', + id: msg.id, + error: CATALOG_REQUEST_PENDING_ERROR, + }); + return; + } + + if (this.pendingCommands.size >= UserConnectionDO.MAX_PENDING_COMMANDS) { + this.sendToWeb(ws, { + type: 'response', + id: msg.id, + error: PENDING_COMMAND_LIMIT_ERROR, + }); + return; + } + const correlationId = crypto.randomUUID(); this.pendingCommands.set(correlationId, { ws, sessionId: msg.sessionId, originalId: msg.id, + command: msg.command, + expectedOwnerConnectionId, + targetConnectionId, + expiresAt: now + UserConnectionDO.PENDING_COMMAND_TTL_MS, targetCliWs: targetCli, }); + this.scheduleNextAlarm(now); this.sendToCli(targetCli, { type: 'command', @@ -670,6 +795,7 @@ export class UserConnectionDO extends DurableObject { const att = ws.deserializeAttachment() as WSAttachment | null; if (att?.role === 'cli' && att.connectionId === connectionId) { console.log('Closing stale CLI socket for reconnect', { connectionId }); + this.failPendingCommandsForSocket(ws); // Preserve session ownership — the reconnecting CLI still owns these sessions ws.close(1000, 'replaced by reconnect'); return true; @@ -705,10 +831,13 @@ export class UserConnectionDO extends DurableObject { private findCliForSession(sessionId: string): WebSocket | undefined { const ownerConnectionId = this.sessionOwners.get(sessionId); if (!ownerConnectionId) return undefined; + return this.findCliByConnectionId(ownerConnectionId); + } + private findCliByConnectionId(connectionId: string): WebSocket | undefined { for (const ws of this.ctx.getWebSockets('cli')) { const attachment = ws.deserializeAttachment() as WSAttachment | null; - if (attachment?.role === 'cli' && attachment.connectionId === ownerConnectionId) { + if (attachment?.role === 'cli' && attachment.connectionId === connectionId) { return ws; } } @@ -721,17 +850,61 @@ export class UserConnectionDO extends DurableObject { this.sendToWeb(entry.ws, { type: 'response', id: entry.originalId, - error: 'CLI disconnected', + error: entry.expectedOwnerConnectionId ? SESSION_OWNER_CHANGED_ERROR : 'CLI disconnected', }); this.pendingCommands.delete(id); } } } - private scheduleStaleCheck(): void { - // Schedule an alarm to run after the timeout period. - // setAlarm is idempotent if one is already scheduled sooner. - void this.ctx.storage.setAlarm(Date.now() + UserConnectionDO.HEARTBEAT_TIMEOUT_MS); + private failPendingCommandsForOwnerChange( + sessionId: string, + nextOwnerConnectionId: string | undefined + ): void { + for (const [id, entry] of this.pendingCommands) { + if (entry.sessionId !== sessionId || entry.targetConnectionId === nextOwnerConnectionId) { + continue; + } + this.pendingCommands.delete(id); + this.sendToWeb(entry.ws, { + type: 'response', + id: entry.originalId, + error: SESSION_OWNER_CHANGED_ERROR, + }); + } + } + + private expirePendingCommands(now: number): void { + for (const [id, entry] of this.pendingCommands) { + if (entry.expiresAt > now) continue; + this.pendingCommands.delete(id); + this.sendToWeb(entry.ws, { + type: 'response', + id: entry.originalId, + error: COMMAND_EXPIRED_ERROR, + }); + } + } + + private scheduleNextAlarm(now: number): void { + let nextAlarmAt: number | undefined; + + for (const lastSeen of this.lastHeartbeatAt.values()) { + const staleAt = lastSeen + UserConnectionDO.HEARTBEAT_TIMEOUT_MS; + if (staleAt > now && (nextAlarmAt === undefined || staleAt < nextAlarmAt)) { + nextAlarmAt = staleAt; + } + } + + for (const entry of this.pendingCommands.values()) { + if (entry.expiresAt > now && (nextAlarmAt === undefined || entry.expiresAt < nextAlarmAt)) { + nextAlarmAt = entry.expiresAt; + } + } + + if (nextAlarmAt !== undefined) { + void this.ctx.storage.setAlarm(nextAlarmAt); + } } private aggregateSessions(): Array { From 22a1c88dd8ea8f5c9d44223b98bbe2d949ea4caa Mon Sep 17 00:00:00 2001 From: Evgeny Shurakov Date: Tue, 30 Jun 2026 15:48:26 +0200 Subject: [PATCH 2/4] Fix circular dependency between ModelCombobox and model-combobox-options --- .../src/components/shared/ModelCombobox.tsx | 23 +++---------------- .../shared/model-combobox-options.ts | 21 ++++++++++++++++- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/apps/web/src/components/shared/ModelCombobox.tsx b/apps/web/src/components/shared/ModelCombobox.tsx index 4245d17746..349d71486c 100644 --- a/apps/web/src/components/shared/ModelCombobox.tsx +++ b/apps/web/src/components/shared/ModelCombobox.tsx @@ -18,8 +18,11 @@ import { cn } from '@/lib/utils'; import { buildModelOptionGroups, getModelOptionKeywords, + type ModelOption, type ModelOptionGroup, } from './model-combobox-options'; + +export type { ModelOption }; import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip'; import { formatShortModelDisplayName } from '@/lib/format-model-name'; import { @@ -32,26 +35,6 @@ import { mayTrainOnYourPrompts, } from '@/components/shared/free-model-data-disclosure'; -export type ModelOption = { - id: string; // e.g., "anthropic/claude-sonnet-4.5" - name: string; // e.g., "Claude Sonnet 4.5" - /** Exact user-facing ID when `id` is an opaque selection value. */ - displayId?: string; - /** Optional provider group for provider-aware catalogs. */ - providerGroup?: { id: string; label: string }; - /** Additional user-facing search terms. Opaque selection values stay excluded. */ - searchTerms?: string[]; - supportsVision?: boolean; - supportsReasoning?: boolean; - isFree?: boolean; - mayTrainOnYourPrompts?: boolean; - hasUserByokAvailable?: boolean; - showGatewayMetadata?: boolean; - unavailable?: boolean; - /** Ordered list of variant key names (e.g., ["none","low","medium","high","max"]) */ - variants?: string[]; -}; - export type ModelComboboxProps = { label?: string; helperText?: string; diff --git a/apps/web/src/components/shared/model-combobox-options.ts b/apps/web/src/components/shared/model-combobox-options.ts index 9e9b2bbad5..751c32d210 100644 --- a/apps/web/src/components/shared/model-combobox-options.ts +++ b/apps/web/src/components/shared/model-combobox-options.ts @@ -1,5 +1,24 @@ import { preferredModels } from '@/lib/ai-gateway/models'; -import type { ModelOption } from './ModelCombobox'; + +export type ModelOption = { + id: string; // e.g., "anthropic/claude-sonnet-4.5" + name: string; // e.g., "Claude Sonnet 4.5" + /** Exact user-facing ID when `id` is an opaque selection value. */ + displayId?: string; + /** Optional provider group for provider-aware catalogs. */ + providerGroup?: { id: string; label: string }; + /** Additional user-facing search terms. Opaque selection values stay excluded. */ + searchTerms?: string[]; + supportsVision?: boolean; + supportsReasoning?: boolean; + isFree?: boolean; + mayTrainOnYourPrompts?: boolean; + hasUserByokAvailable?: boolean; + showGatewayMetadata?: boolean; + unavailable?: boolean; + /** Ordered list of variant key names (e.g., ["none","low","medium","high","max"]) */ + variants?: string[]; +}; export type ModelOptionGroup = { id: string; From d33a510471af3c1cac2c4fdea536ba8c85a5b219 Mon Sep 17 00:00:00 2001 From: Evgeny Shurakov Date: Tue, 30 Jun 2026 15:59:54 +0200 Subject: [PATCH 3/4] Fix session test remote model catalog wire shape --- apps/web/src/lib/cloud-agent-sdk/session.test.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/apps/web/src/lib/cloud-agent-sdk/session.test.ts b/apps/web/src/lib/cloud-agent-sdk/session.test.ts index eed3982c2a..a8fd86615e 100644 --- a/apps/web/src/lib/cloud-agent-sdk/session.test.ts +++ b/apps/web/src/lib/cloud-agent-sdk/session.test.ts @@ -374,8 +374,11 @@ describe('remote session transport state', () => { subscribeToCliSession: jest.fn(() => jest.fn()), sendCommand: jest.fn(() => Promise.resolve({ + all: [], + default: {}, + connected: [], + failed: [], protocolVersion: 1, - providers: [], truncated: false, }) ), From 2fedc735fbd2d42637007bfd0e83fa90c49d090c Mon Sep 17 00:00:00 2001 From: Evgeny Shurakov Date: Thu, 2 Jul 2026 21:51:04 +0200 Subject: [PATCH 4/4] feat(cloud-agent): order and flag CLI-configured remote models Sort remote model catalog providers/models by recommendation rank and surface isFree/mayTrainOnYourPrompts/hasUserByokAvailable flags through the wire schema so the model picker can render TUI-parity ordering and footers. Add onReplayComplete to the transport/session layer so consumers can tell a replayed message.updated from a live one. --- .../agents/session-detail-content.tsx | 5 - .../lib/hooks/use-session-model-options.ts | 26 +-- apps/mobile/src/lib/picker-bridge.test.ts | 35 ---- .../src/lib/use-session-model-options.test.ts | 174 +++++++++++++++++- apps/mobile/vitest.config.ts | 3 + .../hooks/useSessionModels.test.ts | 174 +++++++++++++++++- .../hooks/useSessionModels.ts | 33 ++-- .../src/components/shared/ModelCombobox.tsx | 39 ++-- .../cli-live-transport.test.ts | 33 +++- .../lib/cloud-agent-sdk/cli-live-transport.ts | 1 + apps/web/src/lib/cloud-agent-sdk/index.ts | 3 + .../remote-model-catalog.test.ts | 74 +++++++- .../cloud-agent-sdk/remote-model-catalog.ts | 5 + .../lib/cloud-agent-sdk/remote-model-order.ts | 74 ++++++++ apps/web/src/lib/cloud-agent-sdk/schemas.ts | 55 ++++-- .../cloud-agent-sdk/session-manager.test.ts | 141 +++++++++++++- .../lib/cloud-agent-sdk/session-manager.ts | 57 ++++-- apps/web/src/lib/cloud-agent-sdk/session.ts | 2 + apps/web/src/lib/cloud-agent-sdk/transport.ts | 7 + 19 files changed, 800 insertions(+), 141 deletions(-) create mode 100644 apps/web/src/lib/cloud-agent-sdk/remote-model-order.ts diff --git a/apps/mobile/src/components/agents/session-detail-content.tsx b/apps/mobile/src/components/agents/session-detail-content.tsx index d0365f704f..7820925d39 100644 --- a/apps/mobile/src/components/agents/session-detail-content.tsx +++ b/apps/mobile/src/components/agents/session-detail-content.tsx @@ -218,11 +218,6 @@ export function SessionDetailContent({ sessionId }: Readonly { if (activeSessionType === 'remote') { - if (pickerSelection?.option.action === 'use-session-model') { - manager.setRemoteModelOverride(null); - return; - } - const selectedRef = pickerSelection?.option.modelRef; const option = selectedRef ? modelOptions.find( diff --git a/apps/mobile/src/lib/hooks/use-session-model-options.ts b/apps/mobile/src/lib/hooks/use-session-model-options.ts index 5a5ec6760e..44ddda0bf5 100644 --- a/apps/mobile/src/lib/hooks/use-session-model-options.ts +++ b/apps/mobile/src/lib/hooks/use-session-model-options.ts @@ -7,6 +7,7 @@ import { type RemoteModelState, type ResolvedSession, } from 'cloud-agent-sdk'; +import { sortRemoteModelCatalogProviders } from 'cloud-agent-sdk/remote-model-order'; import { type ModelOption } from '@/lib/hooks/use-available-models'; @@ -36,7 +37,6 @@ export type SessionModelOption = { overrideSource?: RemoteModelOverride['source']; showGatewayMetadata: boolean; unavailable?: boolean; - action?: 'use-session-model'; }; type BuildSessionModelOptionsInput = { @@ -213,9 +213,6 @@ function buildLegacyGatewayOptions(input: BuildSessionModelOptionsInput): Sessio options.unshift(selectedOption); notices.push(createUnavailableNotice(currentSelection.model)); } - if (remoteModelOverride) { - options.unshift(createUseSessionModelOption()); - } const selectedVariant = currentSelection?.variant && selectedOption?.variants.includes(currentSelection.variant) @@ -240,7 +237,7 @@ function buildCliCatalogOptions(input: BuildSessionModelOptionsInput): SessionMo } let opaqueIndex = 0; - const options = catalog.providers.flatMap(provider => + const options = sortRemoteModelCatalogProviders(catalog.providers).flatMap(provider => provider.models.map(model => { const option: SessionModelOption = { id: `remote-model-${opaqueIndex}`, @@ -248,6 +245,9 @@ function buildCliCatalogOptions(input: BuildSessionModelOptionsInput): SessionMo displayId: model.id, variants: model.variants, isPreferred: false, + isFree: model.isFree, + mayTrainOnYourPrompts: model.mayTrainOnYourPrompts, + hasUserByokAvailable: model.hasUserByokAvailable, provider: { id: provider.id, name: provider.name ?? provider.id }, modelRef: { providerID: provider.id, modelID: model.id }, overrideSource: 'cli-catalog', @@ -270,9 +270,6 @@ function buildCliCatalogOptions(input: BuildSessionModelOptionsInput): SessionMo options.unshift(selectedOption); notices.push(createUnavailableNotice(currentSelection.model)); } - if (input.remoteModelOverride) { - options.unshift(createUseSessionModelOption()); - } if (input.remoteModelState.refresh === 'error') { notices.unshift({ id: 'stale', @@ -313,19 +310,6 @@ function buildCliCatalogOptions(input: BuildSessionModelOptionsInput): SessionMo }; } -function createUseSessionModelOption(): SessionModelOption { - return { - id: 'remote-use-session-model', - name: 'Use session model', - displayId: 'Stop overriding the model selected by the CLI', - variants: [], - isPreferred: false, - provider: { id: 'session', name: 'Session' }, - showGatewayMetadata: false, - action: 'use-session-model', - }; -} - function createUnavailableOption(modelRef: ModelRef): SessionModelOption { return { id: 'remote-unavailable-model', diff --git a/apps/mobile/src/lib/picker-bridge.test.ts b/apps/mobile/src/lib/picker-bridge.test.ts index 9189500e16..5d6494ea1a 100644 --- a/apps/mobile/src/lib/picker-bridge.test.ts +++ b/apps/mobile/src/lib/picker-bridge.test.ts @@ -132,39 +132,4 @@ describe('model picker bridge', () => { expect(commitModelPickerSelection(bridge, remoteOption.id, 'high')).toBe(true); expect(onSelect).toHaveBeenCalledWith({ option: remoteOption, variant: 'high' }); }); - - it('passes the Use session model action through without inventing an override', () => { - const resetOption = { - id: 'remote-use-session-model', - name: 'Use session model', - displayId: 'Stop overriding the model selected by the CLI', - variants: [], - isPreferred: false, - provider: { id: 'session', name: 'Session' }, - showGatewayMetadata: false, - action: 'use-session-model' as const, - }; - const onSelect = vi.fn(); - setModelPickerBridge({ - ...currentSelectionScope, - options: [resetOption], - currentValue: remoteOption.id, - currentVariant: 'high', - onSelect: selection => { - onSelect(selection); - }, - }); - - const bridge = getModelPickerBridge(); - if (!bridge) { - throw new Error('Expected model picker bridge'); - } - const selection = resolveModelPickerSelection(bridge, resetOption.id, 'high'); - if (!selection) { - throw new Error('Expected model picker selection'); - } - bridge.onSelect(selection); - - expect(onSelect).toHaveBeenCalledWith({ option: resetOption, variant: '' }); - }); }); diff --git a/apps/mobile/src/lib/use-session-model-options.test.ts b/apps/mobile/src/lib/use-session-model-options.test.ts index d5cc935e12..8228f49c4d 100644 --- a/apps/mobile/src/lib/use-session-model-options.test.ts +++ b/apps/mobile/src/lib/use-session-model-options.test.ts @@ -1,3 +1,4 @@ +/* eslint-disable max-lines -- Model option tests mirror the SDK/web suite. */ import { describe, expect, it } from 'vitest'; import { @@ -136,6 +137,172 @@ describe('buildSessionModelOptions', () => { expect(result.selectedVariant).toBe('high'); }); + it('sorts provider-aware CLI rows like the CLI TUI picker', () => { + const result = buildSessionModelOptions({ + activeSessionType: 'remote', + remoteModelState: { + ownerConnectionId: 'cli-owner', + protocol: 'v1', + refresh: 'idle', + catalog: { + protocolVersion: 1, + truncated: false, + providers: [ + { + id: 'zeta-provider', + name: 'Zeta Provider', + models: [ + { + id: 'zeta-model', + name: 'Zeta Model', + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4096 }, + }, + ], + }, + { + id: 'kilo', + name: 'Kilo Gateway', + models: [ + { + id: 'kilo-later', + name: 'Kilo Later', + recommendedIndex: 2, + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4096 }, + }, + { + id: 'kilo-first', + name: 'Kilo First', + recommendedIndex: 0, + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4096 }, + }, + ], + }, + { + id: 'alpha-provider', + name: 'Alpha Provider', + models: [ + { + id: 'beta-model', + name: 'Beta Model', + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4096 }, + }, + { + id: 'alpha-model', + name: 'Alpha Model', + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4096 }, + }, + ], + }, + { + id: 'opencode', + name: 'OpenCode', + models: [ + { + id: 'z-model', + name: 'Z Model', + isFree: true, + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4096 }, + }, + { + id: 'a-model', + name: 'A Model', + isFree: true, + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4096 }, + }, + ], + }, + ], + }, + }, + observedModel: null, + remoteModelOverride: null, + gatewayModels, + gatewayModelsLoading: false, + organizationId: 'org-persisted', + }); + + expect(result.options.map(option => option.modelRef)).toEqual([ + { providerID: 'opencode', modelID: 'a-model' }, + { providerID: 'opencode', modelID: 'z-model' }, + { providerID: 'alpha-provider', modelID: 'alpha-model' }, + { providerID: 'alpha-provider', modelID: 'beta-model' }, + { providerID: 'kilo', modelID: 'kilo-first' }, + { providerID: 'kilo', modelID: 'kilo-later' }, + { providerID: 'zeta-provider', modelID: 'zeta-model' }, + ]); + }); + + it('uses Kilo Gateway recommendedIndex ranks from CLI metadata', () => { + const result = buildSessionModelOptions({ + activeSessionType: 'remote', + remoteModelState: { + ownerConnectionId: 'cli-owner', + protocol: 'v1', + refresh: 'idle', + catalog: { + protocolVersion: 1, + truncated: false, + providers: [ + { + id: 'kilo', + name: 'Kilo Gateway', + models: [ + { + id: 'zzz-unranked-model', + name: 'AAA Unranked Model', + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4096 }, + }, + { + id: 'anthropic/claude-sonnet-4.6', + name: 'ZZZ Claude Sonnet', + recommendedIndex: 1, + variants: [], + capabilities: { attachment: true, reasoning: true }, + limits: { context: 1_000_000, output: 128_000 }, + }, + { + id: 'kilo-auto/efficient', + name: 'MMM Auto Efficient', + recommendedIndex: 0, + variants: [], + capabilities: { attachment: true, reasoning: true }, + limits: { context: 1_000_000, output: 65_536 }, + }, + ], + }, + ], + }, + }, + observedModel: null, + remoteModelOverride: null, + gatewayModels, + gatewayModelsLoading: false, + organizationId: 'org-persisted', + }); + + expect(result.options.map(option => option.modelRef)).toEqual([ + { providerID: 'kilo', modelID: 'kilo-auto/efficient' }, + { providerID: 'kilo', modelID: 'anthropic/claude-sonnet-4.6' }, + { providerID: 'kilo', modelID: 'zzz-unranked-model' }, + ]); + }); + it('uses Gateway rows only for an exact legacy CLI and preserves override provenance', () => { const result = buildSessionModelOptions({ activeSessionType: 'remote', @@ -201,7 +368,7 @@ describe('buildSessionModelOptions', () => { expect(result.selectedValue).toBe('gateway/model'); expect(result.selectedVariant).toBe('high'); expect(result.options.some(option => option.unavailable)).toBe(false); - expect(result.options.some(option => option.action === 'use-session-model')).toBe(false); + expect(result.options.some(option => option.name === 'Use session model')).toBe(false); }); it('disables model changes when remote discovery fails without exposing Gateway rows', () => { @@ -229,7 +396,7 @@ describe('buildSessionModelOptions', () => { expect(result.notices).toEqual([expect.objectContaining({ id: 'error', retry: true })]); }); - it('retains a stale v1 catalog with reset, truncation, and local-provider notices', () => { + it('retains a stale v1 catalog with truncation and local-provider notices', () => { const result = buildSessionModelOptions({ activeSessionType: 'remote', remoteModelState: { @@ -269,7 +436,6 @@ describe('buildSessionModelOptions', () => { organizationId: 'org-persisted', }); - const resetOption = result.options.find(option => option.action === 'use-session-model'); const cliOption = result.options.find(option => option.overrideSource === 'cli-catalog'); expect(result.notices.map(notice => notice.id)).toEqual([ @@ -278,8 +444,6 @@ describe('buildSessionModelOptions', () => { 'local-provider', ]); expect(result.notices[2]?.message).toContain("organization's model restrictions"); - expect(resetOption).toMatchObject({ name: 'Use session model', action: 'use-session-model' }); - expect(createRemoteModelOverride(resetOption, 'high')).toBeNull(); expect(createRemoteModelOverride(cliOption, 'removed')).toEqual({ source: 'cli-catalog', selection: { model: { providerID: 'local-provider', modelID: 'private-model' } }, diff --git a/apps/mobile/vitest.config.ts b/apps/mobile/vitest.config.ts index 4f0e22d42d..c0605c173d 100644 --- a/apps/mobile/vitest.config.ts +++ b/apps/mobile/vitest.config.ts @@ -6,6 +6,9 @@ export default defineConfig({ resolve: { alias: { '@': fileURLToPath(new URL('./src', import.meta.url)), + 'cloud-agent-sdk/remote-model-order': fileURLToPath( + new URL('../web/src/lib/cloud-agent-sdk/remote-model-order.ts', import.meta.url) + ), }, }, test: { diff --git a/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.test.ts b/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.test.ts index be70abb985..e0df9c052a 100644 --- a/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.test.ts +++ b/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.test.ts @@ -155,6 +155,175 @@ describe('buildSessionModels', () => { expect(result.availableVariants).toEqual(['low', 'high']); }); + it('sorts v1 CLI catalog options like the CLI TUI picker', () => { + const result = buildSessionModels({ + activeSessionType: 'remote', + remoteModelState: { + ownerConnectionId: 'cli-owner', + protocol: 'v1', + refresh: 'idle', + catalog: { + protocolVersion: 1, + truncated: false, + providers: [ + { + id: 'zeta-provider', + name: 'Zeta Provider', + models: [ + { + id: 'zeta-model', + name: 'Zeta Model', + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4_096 }, + }, + ], + }, + { + id: 'kilo', + name: 'Kilo Gateway', + models: [ + { + id: 'kilo-later', + name: 'Kilo Later', + recommendedIndex: 2, + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4_096 }, + }, + { + id: 'kilo-first', + name: 'Kilo First', + recommendedIndex: 0, + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4_096 }, + }, + ], + }, + { + id: 'alpha-provider', + name: 'Alpha Provider', + models: [ + { + id: 'beta-model', + name: 'Beta Model', + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4_096 }, + }, + { + id: 'alpha-model', + name: 'Alpha Model', + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4_096 }, + }, + ], + }, + { + id: 'opencode', + name: 'OpenCode', + models: [ + { + id: 'z-model', + name: 'Z Model', + isFree: true, + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4_096 }, + }, + { + id: 'a-model', + name: 'A Model', + isFree: true, + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4_096 }, + }, + ], + }, + ], + }, + }, + observedModel: null, + remoteModelOverride: null, + gatewayModels, + gatewayModelsLoading: false, + }); + + expect(result.modelOptions.map(option => option.modelRef)).toEqual([ + { providerID: 'opencode', modelID: 'a-model' }, + { providerID: 'opencode', modelID: 'z-model' }, + { providerID: 'alpha-provider', modelID: 'alpha-model' }, + { providerID: 'alpha-provider', modelID: 'beta-model' }, + { providerID: 'kilo', modelID: 'kilo-first' }, + { providerID: 'kilo', modelID: 'kilo-later' }, + { providerID: 'zeta-provider', modelID: 'zeta-model' }, + ]); + }); + + it('uses Kilo Gateway recommendedIndex ranks and a Recommended group from CLI metadata', () => { + const result = buildSessionModels({ + activeSessionType: 'remote', + remoteModelState: { + ownerConnectionId: 'cli-owner', + protocol: 'v1', + refresh: 'idle', + catalog: { + protocolVersion: 1, + truncated: false, + providers: [ + { + id: 'kilo', + name: 'Kilo Gateway', + models: [ + { + id: 'zzz-unranked-model', + name: 'AAA Unranked Model', + variants: [], + capabilities: { attachment: false, reasoning: false }, + limits: { context: 32_000, output: 4_096 }, + }, + { + id: 'anthropic/claude-sonnet-4.6', + name: 'ZZZ Claude Sonnet', + recommendedIndex: 1, + variants: [], + capabilities: { attachment: true, reasoning: true }, + limits: { context: 1_000_000, output: 128_000 }, + }, + { + id: 'kilo-auto/efficient', + name: 'MMM Auto Efficient', + recommendedIndex: 0, + variants: [], + capabilities: { attachment: true, reasoning: true }, + limits: { context: 1_000_000, output: 65_536 }, + }, + ], + }, + ], + }, + }, + observedModel: null, + remoteModelOverride: null, + gatewayModels, + gatewayModelsLoading: false, + }); + + expect(result.modelOptions.map(option => option.modelRef)).toEqual([ + { providerID: 'kilo', modelID: 'kilo-auto/efficient' }, + { providerID: 'kilo', modelID: 'anthropic/claude-sonnet-4.6' }, + { providerID: 'kilo', modelID: 'zzz-unranked-model' }, + ]); + expect(result.modelOptions.map(option => option.providerGroup)).toEqual([ + { id: 'kilo-recommended', label: 'Recommended' }, + { id: 'kilo-recommended', label: 'Recommended' }, + { id: 'kilo', label: 'Kilo Gateway' }, + ]); + }); + it('uses persisted-organization Gateway fallback only for an exact legacy CLI', () => { const result = buildSessionModels({ activeSessionType: 'remote', @@ -233,7 +402,7 @@ describe('buildSessionModels', () => { expect(result.notices).toEqual([expect.objectContaining({ id: 'error', retry: true })]); }); - it('keeps a stale v1 catalog with reset, truncation, and local-provider disclosure', () => { + it('keeps a stale v1 catalog with truncation and local-provider disclosure', () => { const result = buildSessionModels({ activeSessionType: 'remote', remoteModelState: { @@ -273,7 +442,6 @@ describe('buildSessionModels', () => { gatewayOrganizationId: 'org-persisted', }); - const resetOption = result.modelOptions.find(option => option.action === 'use-session-model'); const cliOption = result.modelOptions.find(option => option.overrideSource === 'cli-catalog'); expect(result.source).toBe('remote-cli-catalog'); @@ -283,8 +451,6 @@ describe('buildSessionModels', () => { 'local-provider', ]); expect(result.notices[2].message).toContain("organization's model restrictions"); - expect(resetOption).toMatchObject({ name: 'Use session model', action: 'use-session-model' }); - expect(createRemoteModelOverride(resetOption, 'high')).toBeNull(); expect(createRemoteModelOverride(cliOption, 'removed-variant')).toEqual({ source: 'cli-catalog', selection: { diff --git a/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.ts b/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.ts index a943aa3309..429c32c928 100644 --- a/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.ts +++ b/apps/web/src/components/cloud-agent-next/hooks/useSessionModels.ts @@ -1,7 +1,9 @@ import { useMemo } from 'react'; import { createModelRefKeyMap, + isRemoteModelRecommended, modelRefsEqual, + sortRemoteModelCatalogProviders, type FetchedSessionData, type ModelRef, type ModelSelection, @@ -39,7 +41,6 @@ type SessionModelOption = ModelOption & { modelRef?: ModelRef; overrideSource?: RemoteModelOverride['source']; unavailable?: boolean; - action?: 'use-session-model'; }; type BuildSessionModelsInput = { @@ -230,9 +231,6 @@ function buildLegacyGatewayModels(input: BuildSessionModelsInput): SessionModels modelOptions.unshift(unavailableOption); selectedOption = unavailableOption; } - if (input.remoteModelOverride) { - modelOptions.unshift(createUseSessionModelOption()); - } const selectedVariant = currentSelection?.variant && selectedOption?.variants?.includes(currentSelection.variant) @@ -264,19 +262,27 @@ function buildCliCatalogModels(input: BuildSessionModelsInput): SessionModels { if (!catalog) throw new Error('CLI catalog is required for v1 model options'); const keyMap = createModelRefKeyMap(); - const modelOptions: SessionModelOption[] = catalog.providers.flatMap(provider => + const modelOptions: SessionModelOption[] = sortRemoteModelCatalogProviders( + catalog.providers + ).flatMap(provider => provider.models.map(model => { const modelRef = { providerID: provider.id, modelID: model.id }; return { id: keyMap.getOrCreateKey(modelRef), name: model.name ?? model.id, displayId: model.id, - providerGroup: { id: provider.id, label: provider.name ?? provider.id }, + providerGroup: + provider.id === 'kilo' && isRemoteModelRecommended(provider.id, model) + ? { id: 'kilo-recommended', label: 'Recommended' } + : { id: provider.id, label: provider.name ?? provider.id }, searchTerms: [provider.id, provider.name, model.id, model.name].filter( (term): term is string => term !== undefined ), supportsVision: model.capabilities.attachment, supportsReasoning: model.capabilities.reasoning, + isFree: model.isFree, + mayTrainOnYourPrompts: model.mayTrainOnYourPrompts, + hasUserByokAvailable: model.hasUserByokAvailable, showGatewayMetadata: false, variants: model.variants, modelRef, @@ -296,9 +302,6 @@ function buildCliCatalogModels(input: BuildSessionModelsInput): SessionModels { modelOptions.unshift(unavailableOption); selectedOption = unavailableOption; } - if (input.remoteModelOverride) { - modelOptions.unshift(createUseSessionModelOption()); - } const selectedVariant = currentSelection?.variant && selectedOption?.variants?.includes(currentSelection.variant) @@ -351,18 +354,6 @@ function currentRemoteSelection(input: BuildSessionModelsInput): ModelSelection ); } -function createUseSessionModelOption(): SessionModelOption { - return { - id: 'remote-use-session-model', - name: 'Use session model', - displayId: 'Stop overriding the model selected by the CLI', - providerGroup: { id: 'session', label: 'Session' }, - searchTerms: ['session', 'default', 'observed'], - showGatewayMetadata: false, - action: 'use-session-model', - }; -} - function createUnavailableOption(modelRef: ModelRef): SessionModelOption { const keyMap = createModelRefKeyMap(); return { diff --git a/apps/web/src/components/shared/ModelCombobox.tsx b/apps/web/src/components/shared/ModelCombobox.tsx index 349d71486c..99488e8940 100644 --- a/apps/web/src/components/shared/ModelCombobox.tsx +++ b/apps/web/src/components/shared/ModelCombobox.tsx @@ -309,20 +309,10 @@ function ModelOptionGroups({
{model.name} {model.supportsVision === true && ( - - - - - Supports vision - + )} {model.supportsReasoning === true && ( - - - - - Supports reasoning - + )} {model.showGatewayMetadata !== false && } {model.unavailable && ( @@ -348,6 +338,21 @@ function ModelOptionGroups({ )); } +/** + * Row-level icon hint used in the model list. Options can number in the + * hundreds (a CLI's full connected-provider catalog), and Radix's Tooltip + * mounts a Portal-backed component tree per instance, so using it here makes + * opening the list itself slow. A native `title` gives the same hover text + * for a single lightweight span. + */ +function RowIconHint({ icon: Icon, label }: { icon: typeof Image; label: string }) { + return ( + + + + ); +} + function FreeModelDataIcon({ compact = false }: { compact?: boolean }) { return ( @@ -385,7 +390,15 @@ function ModelMetadataBadges({ model }: { model: ModelOption }) { {BYOK_MODEL_LABEL} )} - {collectsData && } + {collectsData && ( + + + + )} ); } diff --git a/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.test.ts b/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.test.ts index 324169efb5..51a566e626 100644 --- a/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.test.ts +++ b/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.test.ts @@ -134,6 +134,7 @@ function createTransportWithSinks(opts?: { const userWebConnection = opts?.connection ?? createConnection(); const chatEvents: ChatEvent[] = []; const serviceEvents: ServiceEvent[] = []; + let replayCompleteCount = 0; const transport = createCliLiveTransport({ kiloSessionId: KILO_SESSION_ID, userWebConnection, @@ -144,8 +145,17 @@ function createTransportWithSinks(opts?: { })({ onChatEvent: event => chatEvents.push(event), onServiceEvent: event => serviceEvents.push(event), + onReplayComplete: () => { + replayCompleteCount += 1; + }, }); - return { userWebConnection, transport, chatEvents, serviceEvents }; + return { + userWebConnection, + transport, + chatEvents, + serviceEvents, + getReplayCompleteCount: () => replayCompleteCount, + }; } function emitOwner(connection: FakeUserWebConnection, connectionId = 'owner'): void { @@ -604,6 +614,27 @@ describe('CliLiveTransport unified user web connection', () => { transport.destroy(); }); + it('signals replay completion only after the snapshot has been drained', async () => { + let resolveSnapshot: ((snapshot: SessionSnapshot) => void) | undefined; + const fetchSnapshot = jest.fn( + () => + new Promise(resolve => { + resolveSnapshot = resolve; + }) + ); + const { transport, getReplayCompleteCount } = createTransportWithSinks({ fetchSnapshot }); + transport.connect(); + + expect(getReplayCompleteCount()).toBe(0); + + resolveSnapshot?.(makeSnapshot({ id: KILO_SESSION_ID }, [])); + await Promise.resolve(); + await Promise.resolve(); + + expect(getReplayCompleteCount()).toBe(1); + transport.destroy(); + }); + it('applies a live session.updated after stale snapshot metadata', async () => { let resolveSnapshot: ((snapshot: SessionSnapshot) => void) | undefined; const fetchSnapshot = jest.fn( diff --git a/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.ts b/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.ts index c5cbe1d5aa..b95eedbb7f 100644 --- a/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.ts +++ b/apps/web/src/lib/cloud-agent-sdk/cli-live-transport.ts @@ -343,6 +343,7 @@ function createCliLiveTransport(config: CliLiveTransportConfig): TransportFactor const drainBufferedCliEvents = (): void => { const events = bufferedCliEvents; bufferedCliEvents = null; + sink.onReplayComplete?.(); for (const msg of events ?? []) { handleEventMessage(msg.sessionId, msg.parentSessionId, msg.event, msg.data); } diff --git a/apps/web/src/lib/cloud-agent-sdk/index.ts b/apps/web/src/lib/cloud-agent-sdk/index.ts index 632cbdda82..5065115bf0 100644 --- a/apps/web/src/lib/cloud-agent-sdk/index.ts +++ b/apps/web/src/lib/cloud-agent-sdk/index.ts @@ -78,11 +78,14 @@ export { REMOTE_MODEL_MAX_VARIANTS_PER_MODEL, REMOTE_MODEL_MAX_VARIANTS_TOTAL, createModelRefKeyMap, + getRemoteModelRecommendedRank, + isRemoteModelRecommended, modelRefSchema, modelRefsEqual, modelSelectionSchema, remoteModelCatalogV1Schema, remoteModelCatalogWireV1Schema, + sortRemoteModelCatalogProviders, } from './remote-model-catalog'; export type { ModelRef, diff --git a/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.test.ts b/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.test.ts index 558c9f3cf9..42093f29a2 100644 --- a/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.test.ts +++ b/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.test.ts @@ -36,9 +36,16 @@ function createSdkModel(providerID: string, id: string, variants: string[] = [], }; } +type SdkModelFixture = ReturnType & { + recommendedIndex?: number; + isFree?: boolean; + mayTrainOnYourPrompts?: boolean; + hasUserByokAvailable?: boolean; +}; + function createSdkProvider( id: string, - models: ReturnType[] = [createSdkModel(id, `model-${id}`)] + models: SdkModelFixture[] = [createSdkModel(id, `model-${id}`)] ) { return { id, @@ -140,9 +147,17 @@ function createUtf8OversizedCatalog() { describe('remoteModelCatalogV1Schema', () => { it('normalizes the SDK ProviderListResponse shape without rewriting model identities', () => { - const connected = createSdkProvider('custom/provider:v1', [ - createSdkModel('custom/provider:v1', 'team/model.v2-beta', ['reasoning/high'], 'Team model'), - ]); + const model: SdkModelFixture = createSdkModel( + 'custom/provider:v1', + 'team/model.v2-beta', + ['reasoning/high'], + 'Team model' + ); + model.recommendedIndex = 4; + model.isFree = true; + model.mayTrainOnYourPrompts = false; + model.hasUserByokAvailable = true; + const connected = createSdkProvider('custom/provider:v1', [model]); connected.name = 'Private deployment'; const disconnected = createSdkProvider('disconnected'); const wire = { @@ -166,6 +181,10 @@ describe('remoteModelCatalogV1Schema', () => { id: 'team/model.v2-beta', name: 'Team model', variants: ['reasoning/high'], + recommendedIndex: 4, + isFree: true, + mayTrainOnYourPrompts: false, + hasUserByokAvailable: true, capabilities: { attachment: true, reasoning: true }, limits: { context: 128_000, output: 16_000 }, }, @@ -178,6 +197,53 @@ describe('remoteModelCatalogV1Schema', () => { }); }); + it('orders providers and models to match the CLI TUI picker', () => { + const zeta = createSdkProvider('zeta-provider'); + zeta.name = 'Zeta Provider'; + const alpha = createSdkProvider('alpha-provider', [ + createSdkModel('alpha-provider', 'beta', [], 'Beta'), + createSdkModel('alpha-provider', 'alpha', [], 'Alpha'), + ]); + alpha.name = 'Alpha Provider'; + const kiloLater = { + ...createSdkModel('kilo', 'kilo-later', [], 'Kilo Later'), + recommendedIndex: 2, + }; + const kiloFirst = { + ...createSdkModel('kilo', 'kilo-first', [], 'Kilo First'), + recommendedIndex: 0, + }; + const kiloByok = { + ...createSdkModel('kilo', 'kilo-byok', [], 'Kilo BYOK'), + hasUserByokAvailable: true, + }; + const kilo = createSdkProvider('kilo', [kiloLater, kiloByok, kiloFirst]); + kilo.name = 'Kilo Gateway'; + const opencode = createSdkProvider('opencode', [ + { ...createSdkModel('opencode', 'z-model', [], 'Z Model'), isFree: true }, + { ...createSdkModel('opencode', 'a-model', [], 'A Model'), isFree: true }, + ]); + opencode.name = 'OpenCode'; + + const parsed = remoteModelCatalogV1Schema.parse( + createWireCatalog([zeta, kilo, alpha, opencode]) + ); + + expect(parsed.providers.map(provider => provider.id)).toEqual([ + 'opencode', + 'alpha-provider', + 'kilo', + 'zeta-provider', + ]); + expect(parsed.providers[0]?.models.map(model => model.id)).toEqual(['a-model', 'z-model']); + expect(parsed.providers[1]?.models.map(model => model.id)).toEqual(['alpha', 'beta']); + expect(parsed.providers[2]?.models.map(model => model.id)).toEqual([ + 'kilo-first', + 'kilo-later', + 'kilo-byok', + ]); + }); + it('rejects duplicate provider IDs and inconsistent model identities', () => { const duplicate = createSdkProvider('provider'); expect( diff --git a/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.ts b/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.ts index 05981c522f..e79c19bee1 100644 --- a/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.ts +++ b/apps/web/src/lib/cloud-agent-sdk/remote-model-catalog.ts @@ -17,6 +17,11 @@ export { type RemoteModelCatalogV1, type RemoteModelCatalogWireV1, } from './schemas'; +export { + getRemoteModelRecommendedRank, + isRemoteModelRecommended, + sortRemoteModelCatalogProviders, +} from './remote-model-order'; // Catalog strings are user/plugin-controlled metadata and may be private. // Treat them as display data, never executable config or independent telemetry. diff --git a/apps/web/src/lib/cloud-agent-sdk/remote-model-order.ts b/apps/web/src/lib/cloud-agent-sdk/remote-model-order.ts new file mode 100644 index 0000000000..045fdd6450 --- /dev/null +++ b/apps/web/src/lib/cloud-agent-sdk/remote-model-order.ts @@ -0,0 +1,74 @@ +type SortableRemoteModel = { + id: string; + name?: string; + recommendedIndex?: number; + isFree?: boolean; + mayTrainOnYourPrompts?: boolean; + hasUserByokAvailable?: boolean; +}; + +type SortableRemoteProvider = { + id: string; + name?: string; + models: readonly SortableRemoteModel[]; +}; + +function providerName(provider: { id: string; name?: string }): string { + return provider.name ?? provider.id; +} + +function modelName(model: SortableRemoteModel): string { + return model.name ?? model.id; +} + +export function getRemoteModelRecommendedRank( + providerId: string, + model: SortableRemoteModel +): number { + if (providerId !== 'kilo') return 0; + return model.recommendedIndex ?? Number.POSITIVE_INFINITY; +} + +export function isRemoteModelRecommended(providerId: string, model: SortableRemoteModel): boolean { + return Number.isFinite(getRemoteModelRecommendedRank(providerId, model)); +} + +function hasTuiFooter(providerId: string, model: SortableRemoteModel): boolean { + return ( + (providerId === 'kilo' && + (model.hasUserByokAvailable === true || model.mayTrainOnYourPrompts === true)) || + (providerId === 'opencode' && model.isFree === true) + ); +} + +function compareProvider(left: SortableRemoteProvider, right: SortableRemoteProvider): number { + const leftOpenCode = left.id === 'opencode' ? 0 : 1; + const rightOpenCode = right.id === 'opencode' ? 0 : 1; + return leftOpenCode - rightOpenCode || providerName(left).localeCompare(providerName(right)); +} + +function compareModel( + providerId: string, + left: SortableRemoteModel, + right: SortableRemoteModel +): number { + return ( + getRemoteModelRecommendedRank(providerId, left) - + getRemoteModelRecommendedRank(providerId, right) || + Number(!hasTuiFooter(providerId, left)) - Number(!hasTuiFooter(providerId, right)) || + modelName(left).localeCompare(modelName(right)) + ); +} + +export function sortRemoteModelCatalogProviders< + TProvider extends { id: string; name?: string; models: readonly SortableRemoteModel[] }, +>( + providers: readonly TProvider[] +): Array & { models: Array }> { + return providers + .map(provider => ({ + ...provider, + models: [...provider.models].sort((left, right) => compareModel(provider.id, left, right)), + })) + .sort(compareProvider); +} diff --git a/apps/web/src/lib/cloud-agent-sdk/schemas.ts b/apps/web/src/lib/cloud-agent-sdk/schemas.ts index 313614dbd4..db4fbf2630 100644 --- a/apps/web/src/lib/cloud-agent-sdk/schemas.ts +++ b/apps/web/src/lib/cloud-agent-sdk/schemas.ts @@ -1,4 +1,5 @@ import * as z from 'zod'; +import { sortRemoteModelCatalogProviders } from './remote-model-order'; // --------------------------------------------------------------------------- // Wire-level envelope @@ -170,6 +171,10 @@ const remoteSdkModelSchema = z options: emptyRemoteModelRecordSchema, headers: emptyRemoteModelRecordSchema, release_date: z.literal(''), + recommendedIndex: z.number().finite().optional(), + isFree: z.boolean().optional(), + mayTrainOnYourPrompts: z.boolean().optional(), + hasUserByokAvailable: z.boolean().optional(), variants: z.record(remoteModelIdentitySchema, emptyRemoteModelRecordSchema).optional(), }) .strict(); @@ -290,25 +295,37 @@ export const remoteModelCatalogV1Schema = remoteModelCatalogWireV1Schema.transfo const connected = new Set(catalog.connected); return { protocolVersion: 1 as const, - providers: catalog.all - .filter(provider => connected.has(provider.id)) - .map(provider => ({ - id: provider.id, - ...(provider.name ? { name: provider.name } : {}), - models: Object.values(provider.models).map(model => ({ - id: model.id, - ...(model.name ? { name: model.name } : {}), - variants: Object.keys(model.variants ?? {}), - capabilities: { - attachment: model.capabilities.attachment, - reasoning: model.capabilities.reasoning, - }, - limits: { - context: model.limit.context, - output: model.limit.output, - }, - })), - })), + providers: sortRemoteModelCatalogProviders( + catalog.all + .filter(provider => connected.has(provider.id)) + .map(provider => ({ + id: provider.id, + ...(provider.name ? { name: provider.name } : {}), + models: Object.values(provider.models).map(model => ({ + id: model.id, + ...(model.name ? { name: model.name } : {}), + ...(model.recommendedIndex !== undefined + ? { recommendedIndex: model.recommendedIndex } + : {}), + ...(model.isFree !== undefined ? { isFree: model.isFree } : {}), + ...(model.mayTrainOnYourPrompts !== undefined + ? { mayTrainOnYourPrompts: model.mayTrainOnYourPrompts } + : {}), + ...(model.hasUserByokAvailable !== undefined + ? { hasUserByokAvailable: model.hasUserByokAvailable } + : {}), + variants: Object.keys(model.variants ?? {}), + capabilities: { + attachment: model.capabilities.attachment, + reasoning: model.capabilities.reasoning, + }, + limits: { + context: model.limit.context, + output: model.limit.output, + }, + })), + })) + ), ...(catalog.currentModel ? { currentModel: catalog.currentModel } : {}), ...(catalog.defaultModel ? { defaultModel: catalog.defaultModel } : {}), truncated: catalog.truncated, diff --git a/apps/web/src/lib/cloud-agent-sdk/session-manager.test.ts b/apps/web/src/lib/cloud-agent-sdk/session-manager.test.ts index 3676f798cb..f3d182676a 100644 --- a/apps/web/src/lib/cloud-agent-sdk/session-manager.test.ts +++ b/apps/web/src/lib/cloud-agent-sdk/session-manager.test.ts @@ -58,6 +58,7 @@ const mockSession = { const mockSessionCallbacks: { onSessionCreated?: (info: SessionInfo) => void; onSessionUpdated?: (info: SessionInfo) => void; + onReplayComplete?: () => void; onQuestionAsked?: (...args: unknown[]) => void; onQuestionResolved?: (...args: unknown[]) => void; @@ -87,6 +88,7 @@ jest.mock('./session', () => ({ storage: JotaiSessionStorage; onSessionCreated?: (info: SessionInfo) => void; onSessionUpdated?: (info: SessionInfo) => void; + onReplayComplete?: () => void; onQuestionAsked?: (...args: unknown[]) => void; onQuestionResolved?: (...args: unknown[]) => void; @@ -121,6 +123,7 @@ jest.mock('./session', () => ({ }); mockSessionCallbacks.onSessionCreated = sessionConfig.onSessionCreated; mockSessionCallbacks.onSessionUpdated = sessionConfig.onSessionUpdated; + mockSessionCallbacks.onReplayComplete = sessionConfig.onReplayComplete; mockSessionCallbacks.onQuestionAsked = sessionConfig.onQuestionAsked; mockSessionCallbacks.onQuestionResolved = sessionConfig.onQuestionResolved; mockSessionCallbacks.onPermissionAsked = sessionConfig.onPermissionAsked; @@ -714,6 +717,35 @@ describe('createSessionManager', () => { }); }); + it('lets a live message override a session-set model once replay has finished', async () => { + // Regression test: `session.updated` can report a stale/default model + // that never changes for a per-request override (the wrapper sends the + // override straight through without persisting it to the session), so + // once we're live, a message's own reported model must win. + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onSessionCreated?.({ + id: 'ses-1', + model: { providerID: 'openai', id: 'gpt-5', variant: 'high' }, + }); + mockSessionCallbacks.onReplayComplete?.(); + + mockSessionCallbacks.onEvent?.({ + type: 'message.updated', + info: stubUserMessage({ + id: 'msg-live', + sessionID: 'ses-1', + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + }), + }); + + expect(atomValue(config.store, mgr.atoms.observedModel)).toEqual({ + model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' }, + }); + }); + it('uses a replayed root message model when session metadata has no model', async () => { const config = createMockConfig(); const mgr = createSessionManager(config); @@ -765,6 +797,40 @@ describe('createSessionManager', () => { }); }); + it('keeps a message-observed model when the catalog current model arrives afterward', async () => { + // Snapshot replay and catalog discovery are two independent async + // round-trips racing on first load. A session with history should + // land on the model its last message actually used, not whichever of + // the two requests happened to finish last. + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onResolved?.({ type: 'remote', kiloSessionId: kiloId('ses-1') }); + + mockSessionCallbacks.onEvent?.({ + type: 'message.updated', + info: createStoredAssistantMessage('msg-history', 'ses-1', { + providerID: 'kilo', + modelID: 'anthropic/claude-sonnet-4', + }).info, + }); + + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'v1', + catalog: { + ...remoteCatalog, + currentModel: { model: { providerID: 'openai', modelID: 'gpt-5' } }, + }, + refresh: 'idle', + }); + + expect(atomValue(config.store, mgr.atoms.observedModel)).toEqual({ + model: { providerID: 'kilo', modelID: 'anthropic/claude-sonnet-4' }, + }); + }); + it('applies a live session.updated model while retaining the explicit override', async () => { const config = createMockConfig(); const mgr = createSessionManager(config); @@ -809,7 +875,7 @@ describe('createSessionManager', () => { expect(atomValue(config.store, mgr.atoms.remoteModelOverride)).toEqual(override); }); - it('keeps explicit override separate from observations and clears it on owner change', async () => { + it('keeps an explicit override through a still-replaying observation, but clears it on owner change', async () => { const config = createMockConfig(); const mgr = createSessionManager(config); const catalog = { @@ -835,6 +901,9 @@ describe('createSessionManager', () => { } as const; mgr.setRemoteModelOverride(override); + // onReplayComplete hasn't fired yet, so this message is still treated + // as replayed history and must not clear the override (see the + // dedicated "live" divergence test below for the post-replay case). mockSessionCallbacks.onEvent?.({ type: 'message.updated', info: createStoredAssistantMessage('msg-assistant', 'ses-1', { @@ -861,6 +930,76 @@ describe('createSessionManager', () => { }); }); + it('clears a stale override once a live message shows the CLI actually used a different model', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onResolved?.({ type: 'remote', kiloSessionId: kiloId('ses-1') }); + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'v1', + catalog: remoteCatalog, + refresh: 'idle', + }); + mgr.setRemoteModelOverride({ + source: 'cli-catalog', + selection: { model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' } }, + }); + + // Initial connect has finished replaying whatever history existed. + mockSessionCallbacks.onReplayComplete?.(); + + mockSessionCallbacks.onEvent?.({ + type: 'message.updated', + info: createStoredAssistantMessage('msg-live', 'ses-1', { + providerID: 'openai', + modelID: 'gpt-5', + }).info, + }); + + expect(atomValue(config.store, mgr.atoms.remoteModelOverride)).toBeNull(); + expect(atomValue(config.store, mgr.atoms.observedModel)).toEqual({ + model: { providerID: 'openai', modelID: 'gpt-5' }, + }); + }); + + it('keeps a fresh override intact through a reconnect that replays pre-override history', async () => { + const config = createMockConfig(); + const mgr = createSessionManager(config); + + await mgr.switchSession(kiloId('ses-1')); + mockSessionCallbacks.onResolved?.({ type: 'remote', kiloSessionId: kiloId('ses-1') }); + mockSessionCallbacks.onRemoteModelStateChange?.({ + ownerConnectionId: 'owner-a', + protocol: 'v1', + catalog: remoteCatalog, + refresh: 'idle', + }); + mockSessionCallbacks.onReplayComplete?.(); + + const override = { + source: 'cli-catalog', + selection: { model: { providerID: 'anthropic', modelID: 'claude-sonnet-4' } }, + } as const; + mgr.setRemoteModelOverride(override); + + // A reconnect starts a fresh replay before the override was ever used. + mockSessionCallbacks.onSessionCreated?.({ id: 'ses-1' }); + mockSessionCallbacks.onEvent?.({ + type: 'message.updated', + info: createStoredAssistantMessage('msg-old', 'ses-1', { + providerID: 'openai', + modelID: 'gpt-5', + }).info, + }); + + expect(atomValue(config.store, mgr.atoms.remoteModelOverride)).toEqual(override); + + mockSessionCallbacks.onReplayComplete?.(); + expect(atomValue(config.store, mgr.atoms.remoteModelOverride)).toEqual(override); + }); + it('clears an explicit override when the same owner changes to an incompatible protocol', async () => { const config = createMockConfig(); const mgr = createSessionManager(config); diff --git a/apps/web/src/lib/cloud-agent-sdk/session-manager.ts b/apps/web/src/lib/cloud-agent-sdk/session-manager.ts index 8d0047d55b..3c8b388ef2 100644 --- a/apps/web/src/lib/cloud-agent-sdk/session-manager.ts +++ b/apps/web/src/lib/cloud-agent-sdk/session-manager.ts @@ -444,6 +444,9 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { let currentSession: CloudAgentSession | null = null; let activeSessionType: ActiveSessionType | null = null; let observedModelSource: ObservedModelSource | null = null; + // True while a connect/reconnect cycle is still replaying its message + // history; false once live events are flowing. See clearOverrideIfDiverged. + let remoteHistoryReplaying = true; let stateUnsub: (() => void) | null = null; let indicatorTimer: ReturnType | null = null; let childSessionHydrationGeneration = 0; @@ -473,6 +476,7 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { store.set(remoteModelStateAtom, EMPTY_REMOTE_MODEL_STATE); store.set(observedModelAtom, null); observedModelSource = null; + remoteHistoryReplaying = true; store.set(remoteModelOverrideAtom, null); store.set(canSendAtom, false); store.set(canInterruptAtom, false); @@ -586,6 +590,20 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { } } + // A web-picked override should stop applying once we see live proof the + // CLI actually ran a message on a different model — otherwise the picker + // gets stuck showing a choice that's no longer what's being sent. Gated on + // `remoteHistoryReplaying` so a reconnect's replayed history (which can + // predate the override) can't wipe a selection that just hasn't been used + // yet. + function clearOverrideIfDiverged(model: ModelSelection): void { + if (remoteHistoryReplaying) return; + const override = store.get(remoteModelOverrideAtom); + if (override && !modelRefsEqual(override.selection.model, model.model)) { + store.set(remoteModelOverrideAtom, null); + } + } + function handleRemoteModelStateChange(state: RemoteModelState): void { const previousOwnerConnectionId = store.get(remoteModelStateAtom).ownerConnectionId; store.set(remoteModelStateAtom, state); @@ -795,6 +813,9 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { // cast cloudAgentSessionId (the createAndStart path). store.set(rootSessionIdAtom, info.id); store.set(isLoadingAtom, false); + // A fresh replay is starting (initial connect or a reconnect); + // onReplayComplete flips this back off once it's done. + remoteHistoryReplaying = true; if (info.model) { updateObservedModel( toModelSelection( @@ -860,6 +881,9 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { onTransportCapabilityChange: () => { if (currentSession === session) updateCapabilityAtoms(session); }, + onReplayComplete: () => { + remoteHistoryReplaying = false; + }, onBranchChanged: branch => { const currentFetched = store.get(fetchedSessionDataAtom); @@ -888,24 +912,33 @@ function createSessionManager(config: SessionManagerConfig): SessionManager { const rootSessionId = store.get(rootSessionIdAtom); if (rootSessionId !== null && event.info.sessionID !== rootSessionId) return; + // A live message always wins: it's the freshest, most specific proof + // of what model actually ran for this turn, more reliable than + // `session.updated` (which can lag behind or never fire for a + // per-request override that doesn't change the session's persisted + // default). During the initial replay, only suppress this when + // `session.created` already claimed a value for this connect cycle + // — its snapshot-time value is fresher than an older replayed + // message, but if it never had a model to begin with there's + // nothing fresher to protect. + const canApplyMessageObservation = + !remoteHistoryReplaying || observedModelSource !== 'session'; if (event.info.role === 'user') { - if (observedModelSource !== 'session') { - updateObservedModel( - toModelSelection(event.info.model, event.info.variant), - 'message' - ); + if (canApplyMessageObservation) { + const selection = toModelSelection(event.info.model, event.info.variant); + updateObservedModel(selection, 'message'); + clearOverrideIfDiverged(selection); } return; } - if (observedModelSource !== 'session') { - updateObservedModel( - toModelSelection( - { providerID: event.info.providerID, modelID: event.info.modelID }, - event.info.variant - ), - 'message' + if (canApplyMessageObservation) { + const selection = toModelSelection( + { providerID: event.info.providerID, modelID: event.info.modelID }, + event.info.variant ); + updateObservedModel(selection, 'message'); + clearOverrideIfDiverged(selection); } // `info.agent` is the agent slug (e.g. 'code', 'e-code'); `info.mode` diff --git a/apps/web/src/lib/cloud-agent-sdk/session.ts b/apps/web/src/lib/cloud-agent-sdk/session.ts index 0aa28d42d7..b0a1d4f12e 100644 --- a/apps/web/src/lib/cloud-agent-sdk/session.ts +++ b/apps/web/src/lib/cloud-agent-sdk/session.ts @@ -68,6 +68,7 @@ type CloudAgentSessionConfig = { onTransportCapabilityChange?: () => void; onSessionCreated?: (info: SessionInfo) => void; onSessionUpdated?: (info: SessionInfo) => void; + onReplayComplete?: () => void; onEvent?: (event: NormalizedEvent) => void; onMessageQueued?: (messageId: string) => void; onMessageCompleted?: (messageId: string) => void; @@ -198,6 +199,7 @@ function createCloudAgentSession(config: CloudAgentSessionConfig): CloudAgentSes } config.onEvent?.(event); }, + onReplayComplete: () => config.onReplayComplete?.(), }; function pickTransportFactory(resolved: ResolvedSession): TransportFactory { diff --git a/apps/web/src/lib/cloud-agent-sdk/transport.ts b/apps/web/src/lib/cloud-agent-sdk/transport.ts index f385c12145..0ed8a21454 100644 --- a/apps/web/src/lib/cloud-agent-sdk/transport.ts +++ b/apps/web/src/lib/cloud-agent-sdk/transport.ts @@ -22,6 +22,13 @@ type CloudAgentStreamTicketResult = string | CloudAgentStreamTicket; type TransportSink = { onChatEvent: (event: ChatEvent) => void; onServiceEvent: (event: ServiceEvent) => void; + /** + * Fired once a connect/reconnect cycle has finished replaying history and + * has switched to delivering live events. Lets consumers distinguish a + * historical `message.updated` (replayed from a snapshot, may be stale) + * from a live one (happened just now, authoritative). + */ + onReplayComplete?: () => void; }; /**