Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions apps/mobile/src/app/(app)/agent-chat/model-picker.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable max-lines */
import * as Haptics from 'expo-haptics';
import { useFocusEffect, useRouter } from 'expo-router';
import { BookOpenCheck, Check, Search } from 'lucide-react-native';
import { BookOpenCheck, Check, Search, Star } from 'lucide-react-native';
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { FlatList, Pressable, ScrollView, TextInput, View } from 'react-native';
import { useSafeAreaInsets } from 'react-native-safe-area-context';
Expand All @@ -17,6 +18,7 @@ import {
} from '@/lib/free-model-data-disclosure';
import { useThemeColors } from '@/lib/hooks/use-theme-colors';
import { type ModelOption, thinkingEffortLabel } from '@/lib/hooks/use-available-models';
import { useModelPreferences } from '@/lib/hooks/use-model-preferences';
import { buildModelPickerRows, type ModelPickerRow } from '@/lib/model-picker-rows';
import { clearModelPickerBridge, getModelPickerBridge } from '@/lib/picker-bridge';

Expand All @@ -33,6 +35,7 @@ export default function ModelPickerScreen() {
const { bottom } = useSafeAreaInsets();
const [search, setSearch] = useState('');
const [bridge, setBridge] = useState(() => getModelPickerBridge());
const { favorites, toggleFavorite } = useModelPreferences(undefined);

const [selectedModel, setSelectedModel] = useState(bridge?.currentValue ?? '');
const [selectedVariant, setSelectedVariant] = useState(bridge?.currentVariant ?? '');
Expand All @@ -46,6 +49,8 @@ export default function ModelPickerScreen() {
router.back();
}, [router]);

const favoriteIds = useMemo(() => new Set(favorites), [favorites]);

useFocusEffect(
useCallback(() => {
const nextBridge = getModelPickerBridge();
Expand Down Expand Up @@ -96,8 +101,8 @@ export default function ModelPickerScreen() {
}, [currentModelOption]);

const rows = useMemo<ModelPickerRow[]>(
() => buildModelPickerRows({ models: bridge?.options ?? [], search }),
[bridge, search]
() => buildModelPickerRows({ models: bridge?.options ?? [], search, favoriteIds }),
[bridge, search, favoriteIds]
);

const handleSelectVariant = useCallback(
Expand Down Expand Up @@ -203,6 +208,7 @@ export default function ModelPickerScreen() {
}

const modelOption = item.model;
const isFavorite = item.isFavorite;
const selected = modelOption.id === selectedModel;
const free = isFreeModelOption(modelOption);
const byok = hasUserByokAvailable(modelOption);
Expand Down Expand Up @@ -263,6 +269,29 @@ export default function ModelPickerScreen() {
</View>
) : null}
</View>
{!isCliModel && (
<Pressable
onPress={() => {
void Haptics.selectionAsync();
toggleFavorite(modelOption.id);
}}
hitSlop={12}
className="min-h-[44px] min-w-[44px] items-center justify-center"
accessibilityRole="button"
accessibilityLabel={
isFavorite
? `Remove ${modelOption.name} from favorites`
: `Add ${modelOption.name} to favorites`
}
accessibilityState={{ selected: isFavorite }}
>
<Star
size={20}
color={isFavorite ? colors.primary : colors.mutedForeground}
fill={isFavorite ? colors.primary : 'transparent'}
/>
</Pressable>
)}
{selected && <Check size={18} color={colors.primary} />}
</Pressable>

Expand Down
51 changes: 15 additions & 36 deletions apps/mobile/src/app/(app)/agent-chat/new.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable max-lines -- New-session screen bundles closely related prompt/toolbar/repository concerns in a single component to keep navigation props colocated. */
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useCallback, useMemo, useRef, useState } from 'react';
import {
ActivityIndicator,
type LayoutChangeEvent,
Expand Down Expand Up @@ -39,7 +39,8 @@ import {
} from '@/lib/agent-attachments/use-agent-attachment-upload';
import { WEB_BASE_URL } from '@/lib/config';
import { useAvailableModels } from '@/lib/hooks/use-available-models';
import { contextKey, resolveModelForContext } from '@/lib/hooks/agent-model-preference';
import { useAutoSelectModel } from '@/lib/hooks/use-auto-select-model';
import { useModelPreferences } from '@/lib/hooks/use-model-preferences';
import { usePersistedAgentModel } from '@/lib/hooks/use-persisted-agent-model';
import { useThemeColors } from '@/lib/hooks/use-theme-colors';
import { trpcClient, useTRPC } from '@/lib/trpc';
Expand Down Expand Up @@ -92,41 +93,18 @@ export default function NewSessionScreen() {

// ── Models ───────────────────────────────────────────────────────
const { models } = useAvailableModels(organizationId);
const {
hasLoaded: modelPrefLoaded,
stored: storedModelPref,
saveModel,
} = usePersistedAgentModel();
const { setLastSelected: persistServerLastSelected } = useModelPreferences(organizationId);
const { saveModel } = usePersistedAgentModel();
const autoSelected = useAutoSelectModel(models, organizationId);
const attachments = useAgentAttachmentUpload({ organizationId });

// Auto-select first model when models load, preferring the persisted preference
const hasAutoSelectedModel = useRef(false);
useEffect(() => {
if (hasAutoSelectedModel.current) {
return;
}
// Never overwrite a model the user already picked manually.
if (model) {
hasAutoSelectedModel.current = true;
return;
}
if (models.length === 0 || !modelPrefLoaded) {
return;
}

const persisted = resolveModelForContext(storedModelPref, contextKey(organizationId), models);
if (persisted) {
setModel(persisted.model);
setVariant(persisted.variant);
} else {
const firstModel = models[0];
if (firstModel) {
setModel(firstModel.id);
setVariant(firstModel.variants[0] ?? '');
}
}
hasAutoSelectedModel.current = true;
}, [models, modelPrefLoaded, storedModelPref, organizationId, model]);
// Apply auto-selected model when the user hasn't picked one yet.
const hasAppliedAutoSelection = useRef(false);
if (!hasAppliedAutoSelection.current && autoSelected.model && !model) {
hasAppliedAutoSelection.current = true;
setModel(autoSelected.model);
setVariant(autoSelected.variant);
}

// ── Repositories ─────────────────────────────────────────────────
const trpc = useTRPC();
Expand Down Expand Up @@ -168,8 +146,9 @@ export default function NewSessionScreen() {
setModel(modelId);
setVariant(newVariant);
saveModel(organizationId, { model: modelId, variant: newVariant });
persistServerLastSelected({ model: modelId, ...(newVariant ? { variant: newVariant } : {}) });
},
[organizationId, saveModel]
[organizationId, saveModel, persistServerLastSelected]
);

const handleOpenGitHubIntegration = useCallback(async () => {
Expand Down
6 changes: 6 additions & 0 deletions apps/mobile/src/components/agents/session-detail-content.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import { Text } from '@/components/ui/text';
import { type AgentAttachmentWire } from '@/lib/agent-attachments/use-agent-attachment-upload';
import { useAppLifecycle } from '@/lib/hooks/use-app-lifecycle';
import { useAvailableModels } from '@/lib/hooks/use-available-models';
import { useModelPreferences } from '@/lib/hooks/use-model-preferences';
import { usePersistedAgentModel } from '@/lib/hooks/use-persisted-agent-model';
import { useReasoningPreference } from '@/lib/hooks/use-reasoning-preference';

Expand Down Expand Up @@ -77,6 +78,7 @@ export function SessionDetailContent({ sessionId }: Readonly<SessionDetailConten

const { models: modelOptions } = useAvailableModels(organizationId);
const { saveModel: savePersistedModel } = usePersistedAgentModel();
const { setLastSelected: persistServerLastSelected } = useModelPreferences(organizationId);
const { defaultExpanded: reasoningDefaultExpanded } = useReasoningPreference();
const isRemote = sessionType === 'remote';
const composerModelOptions = useMemo(
Expand Down Expand Up @@ -294,6 +296,10 @@ export function SessionDetailContent({ sessionId }: Readonly<SessionDetailConten
// persisting it would clobber the real preference.
if (modelId !== CLI_MODEL_ID) {
savePersistedModel(organizationId, { model: modelId, variant: newVariant });
persistServerLastSelected({
model: modelId,
...(newVariant ? { variant: newVariant } : {}),
});
}
}}
organizationId={organizationId}
Expand Down
52 changes: 52 additions & 0 deletions apps/mobile/src/lib/hooks/use-auto-select-model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import { useRef } from 'react';

import { contextKey, resolveModelForContext } from '@/lib/hooks/agent-model-preference';
import { type ModelOption, useOrgDefaultModel } from '@/lib/hooks/use-available-models';
import { useModelPreferences } from '@/lib/hooks/use-model-preferences';
import { usePersistedAgentModel } from '@/lib/hooks/use-persisted-agent-model';

function pickVariant(model: ModelOption, preferredVariant: string | undefined): string {
if (preferredVariant && model.variants.includes(preferredVariant)) {
return preferredVariant;
}
return model.variants[0] ?? '';
}

const NO_SELECTION = { model: '', variant: '' };

export function useAutoSelectModel(
models: ModelOption[],
organizationId: string | undefined
): { model: string; variant: string } {
const { lastSelected, isLoading } = useModelPreferences(organizationId);
const { defaultModel: orgDefaultModel, isLoading: orgDefaultIsLoading } =
useOrgDefaultModel(organizationId);
const { stored, hasLoaded } = usePersistedAgentModel();
const chosenRef = useRef<{ model: string; variant: string } | null>(null);

if (chosenRef.current) {
return chosenRef.current;
}
// Wait for the server preference and org default too, or the shared value
// loses the race against the local cache on cold start and is never applied.
if (isLoading || orgDefaultIsLoading || !hasLoaded || models.length === 0) {
return NO_SELECTION;
}
const serverMatch = lastSelected ? models.find(m => m.id === lastSelected.model) : undefined;
const localEntry = resolveModelForContext(stored, contextKey(organizationId), models);
const orgDefaultMatch = orgDefaultModel ? models.find(m => m.id === orgDefaultModel) : undefined;
const fallback = orgDefaultMatch ?? models[0];
if (serverMatch) {
chosenRef.current = {
model: serverMatch.id,
variant: pickVariant(serverMatch, lastSelected?.variant),
};
} else if (localEntry) {
chosenRef.current = localEntry;
} else if (fallback) {
chosenRef.current = { model: fallback.id, variant: pickVariant(fallback, undefined) };
} else {
return NO_SELECTION;
}
return chosenRef.current;
}
32 changes: 31 additions & 1 deletion apps/mobile/src/lib/hooks/use-available-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,37 @@ async function fetchModels(organizationId: string | undefined): Promise<ModelRes
}
}

// ── Hook ─────────────────────────────────────────────────────────────
async function fetchOrgDefaults(organizationId: string): Promise<{ defaultModel: string }> {
const token = await SecureStore.getItemAsync(AUTH_TOKEN_KEY);
const response = await fetch(`${API_BASE_URL}/api/organizations/${organizationId}/defaults`, {
headers: {
Accept: 'application/json',
...(token ? { Authorization: `Bearer ${token}` } : {}),
},
});
if (!response.ok) {
throw new Error(`Failed to fetch org defaults: ${response.status} ${response.statusText}`);
}
return (await response.json()) as { defaultModel: string };
}

// ── Hooks ────────────────────────────────────────────────────────────

export function useOrgDefaultModel(organizationId: string | undefined) {
const { data, isLoading } = useQuery({
queryKey: ['org-default-model', organizationId] as const,
queryFn: async () => {
if (!organizationId) {
throw new Error('Missing organizationId');
}
const defaults = await fetchOrgDefaults(organizationId);
return defaults;
},
enabled: Boolean(organizationId),
staleTime: 60_000,
});
return { defaultModel: data?.defaultModel, isLoading };
}

export function useAvailableModels(organizationId: string | undefined) {
const { data, isLoading } = useQuery({
Expand Down
Loading