Explorar el Código

fix(app): model selection persist by session (#17348)

Adam hace 1 mes
padre
commit
4ad8116ce3

+ 3 - 0
packages/app/e2e/fixtures.ts

@@ -95,6 +95,9 @@ async function seedStorage(page: Page, input: { directory: string; extra?: strin
     const win = window as E2EWindow
     win.__opencode_e2e = {
       ...win.__opencode_e2e,
+      model: {
+        enabled: true,
+      },
       terminal: {
         enabled: true,
         terminals: {},

+ 3 - 0
packages/app/e2e/selectors.ts

@@ -13,6 +13,9 @@ export const sessionTodoToggleButtonSelector = '[data-action="session-todo-toggl
 export const sessionTodoListSelector = '[data-slot="session-todo-list"]'
 
 export const modelVariantCycleSelector = '[data-action="model-variant-cycle"]'
+export const promptAgentSelector = '[data-component="prompt-agent-control"]'
+export const promptModelSelector = '[data-component="prompt-model-control"]'
+export const promptVariantSelector = '[data-component="prompt-variant-control"]'
 export const settingsLanguageSelectSelector = '[data-action="settings-language"]'
 export const settingsColorSchemeSelector = '[data-action="settings-color-scheme"]'
 export const settingsThemeSelector = '[data-action="settings-theme"]'

+ 351 - 0
packages/app/e2e/session/session-model-persistence.spec.ts

@@ -0,0 +1,351 @@
+import { base64Decode } from "@opencode-ai/util/encode"
+import type { Locator, Page } from "@playwright/test"
+import { test, expect } from "../fixtures"
+import { openSidebar, sessionIDFromUrl, setWorkspacesEnabled, waitSessionIdle, waitSlug } from "../actions"
+import {
+  promptAgentSelector,
+  promptModelSelector,
+  promptSelector,
+  promptVariantSelector,
+  workspaceItemSelector,
+  workspaceNewSessionSelector,
+} from "../selectors"
+import { createSdk, sessionPath } from "../utils"
+
+type Footer = {
+  agent: string
+  model: string
+  variant: string
+}
+
+type Probe = {
+  dir?: string
+  sessionID?: string
+  model?: { providerID: string; modelID: string }
+}
+
+const escape = (value: string) => value.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")
+
+const text = async (locator: Locator) => ((await locator.textContent()) ?? "").trim()
+
+const modelKey = (state: Probe | null) => (state?.model ? `${state.model.providerID}:${state.model.modelID}` : null)
+
+const dirKey = (state: Probe | null) => state?.dir ?? ""
+
+async function probe(page: Page): Promise<Probe | null> {
+  return page.evaluate(() => {
+    const win = window as Window & {
+      __opencode_e2e?: {
+        model?: {
+          current?: Probe
+        }
+      }
+    }
+    return win.__opencode_e2e?.model?.current ?? null
+  })
+}
+
+async function currentDir(page: Page) {
+  let hit = ""
+  await expect
+    .poll(
+      async () => {
+        const next = dirKey(await probe(page))
+        if (next) hit = next
+        return next
+      },
+      { timeout: 30_000 },
+    )
+    .not.toBe("")
+  return hit
+}
+
+async function read(page: Page): Promise<Footer> {
+  return {
+    agent: await text(page.locator(`${promptAgentSelector} [data-slot="select-select-trigger-value"]`).first()),
+    model: await text(page.locator(`${promptModelSelector} [data-action="prompt-model"] span`).first()),
+    variant: await text(page.locator(`${promptVariantSelector} [data-slot="select-select-trigger-value"]`).first()),
+  }
+}
+
+async function waitFooter(page: Page, expected: Partial<Footer>) {
+  let hit: Footer | null = null
+  await expect
+    .poll(
+      async () => {
+        const state = await read(page)
+        const ok = Object.entries(expected).every(([key, value]) => state[key as keyof Footer] === value)
+        if (ok) hit = state
+        return ok
+      },
+      { timeout: 30_000 },
+    )
+    .toBe(true)
+  if (!hit) throw new Error("Failed to resolve prompt footer state")
+  return hit
+}
+
+async function waitModel(page: Page, value: string) {
+  await expect.poll(() => probe(page).then(modelKey), { timeout: 30_000 }).toBe(value)
+}
+
+async function choose(page: Page, root: string, value: string) {
+  const select = page.locator(root)
+  await expect(select).toBeVisible()
+  await select.locator('[data-action], [data-slot="select-select-trigger"]').first().click()
+  const item = page
+    .locator('[data-slot="select-select-item"]')
+    .filter({ hasText: new RegExp(`^\\s*${escape(value)}\\s*$`) })
+    .first()
+  await expect(item).toBeVisible()
+  await item.click()
+}
+
+async function variantCount(page: Page) {
+  const select = page.locator(promptVariantSelector)
+  await expect(select).toBeVisible()
+  await select.locator('[data-slot="select-select-trigger"]').click()
+  const count = await page.locator('[data-slot="select-select-item"]').count()
+  await page.keyboard.press("Escape")
+  return count
+}
+
+async function agents(page: Page) {
+  const select = page.locator(promptAgentSelector)
+  await expect(select).toBeVisible()
+  await select.locator('[data-action], [data-slot="select-select-trigger"]').first().click()
+  const labels = await page.locator('[data-slot="select-select-item-label"]').allTextContents()
+  await page.keyboard.press("Escape")
+  return labels.map((item) => item.trim()).filter(Boolean)
+}
+
+async function ensureVariant(page: Page, directory: string): Promise<Footer> {
+  const current = await read(page)
+  if ((await variantCount(page)) >= 2) return current
+
+  const cfg = await createSdk(directory)
+    .config.get()
+    .then((x) => x.data)
+  const visible = new Set(await agents(page))
+  const entry = Object.entries(cfg?.agent ?? {}).find((item) => {
+    const value = item[1]
+    return !!value && typeof value === "object" && "variant" in value && "model" in value && visible.has(item[0])
+  })
+  const name = entry?.[0]
+  test.skip(!name, "no agent with alternate variants available")
+  if (!name) return current
+
+  await choose(page, promptAgentSelector, name)
+  await expect.poll(() => variantCount(page), { timeout: 30_000 }).toBeGreaterThanOrEqual(2)
+  return waitFooter(page, { agent: name })
+}
+
+async function chooseDifferentVariant(page: Page): Promise<Footer> {
+  const current = await read(page)
+  const select = page.locator(promptVariantSelector)
+  await expect(select).toBeVisible()
+  await select.locator('[data-slot="select-select-trigger"]').click()
+
+  const items = page.locator('[data-slot="select-select-item"]')
+  const count = await items.count()
+  if (count < 2) throw new Error("Current model has no alternate variant to select")
+
+  for (let i = 0; i < count; i++) {
+    const item = items.nth(i)
+    const next = await text(item.locator('[data-slot="select-select-item-label"]').first())
+    if (!next || next === current.variant) continue
+    await item.click()
+    return waitFooter(page, { agent: current.agent, model: current.model, variant: next })
+  }
+
+  throw new Error("Failed to choose a different variant")
+}
+
+async function chooseOtherModel(page: Page): Promise<Footer> {
+  const current = await read(page)
+  const button = page.locator(`${promptModelSelector} [data-action="prompt-model"]`)
+  await expect(button).toBeVisible()
+  await button.click()
+
+  const dialog = page.getByRole("dialog")
+  await expect(dialog).toBeVisible()
+  const items = dialog.locator('[data-slot="list-item"]')
+  const count = await items.count()
+  expect(count).toBeGreaterThan(1)
+
+  for (let i = 0; i < count; i++) {
+    const item = items.nth(i)
+    const selected = (await item.getAttribute("data-selected")) === "true"
+    if (selected) continue
+    await item.click()
+    await expect(dialog).toHaveCount(0)
+    await expect.poll(async () => (await read(page)).model !== current.model, { timeout: 30_000 }).toBe(true)
+    return read(page)
+  }
+
+  throw new Error("Failed to choose a different model")
+}
+
+async function goto(page: Page, directory: string, sessionID?: string) {
+  await page.goto(sessionPath(directory, sessionID))
+  await expect(page.locator(promptSelector)).toBeVisible()
+  await expect.poll(async () => dirKey(await probe(page)), { timeout: 30_000 }).toBe(directory)
+}
+
+async function submit(page: Page, value: string) {
+  const prompt = page.locator(promptSelector)
+  await expect(prompt).toBeVisible()
+  await prompt.click()
+  await prompt.fill(value)
+  await prompt.press("Enter")
+
+  await expect.poll(() => sessionIDFromUrl(page.url()) ?? "", { timeout: 30_000 }).not.toBe("")
+  const id = sessionIDFromUrl(page.url())
+  if (!id) throw new Error(`Failed to resolve session id from ${page.url()}`)
+  return id
+}
+
+async function waitUser(directory: string, sessionID: string) {
+  const sdk = createSdk(directory)
+  await expect
+    .poll(
+      async () => {
+        const items = await sdk.session.messages({ sessionID, limit: 20 }).then((x) => x.data ?? [])
+        return items.some((item) => item.info.role === "user")
+      },
+      { timeout: 30_000 },
+    )
+    .toBe(true)
+  await sdk.session.abort({ sessionID }).catch(() => undefined)
+  await waitSessionIdle(sdk, sessionID, 30_000).catch(() => undefined)
+}
+
+async function createWorkspace(page: Page, root: string, seen: string[]) {
+  await openSidebar(page)
+  await page.getByRole("button", { name: "New workspace" }).first().click()
+
+  const slug = await waitSlug(page, [root, ...seen])
+  const directory = base64Decode(slug)
+  if (!directory) throw new Error(`Failed to decode workspace slug: ${slug}`)
+  return { slug, directory }
+}
+
+async function waitWorkspace(page: Page, slug: string) {
+  await openSidebar(page)
+  await expect
+    .poll(
+      async () => {
+        const item = page.locator(workspaceItemSelector(slug)).first()
+        try {
+          await item.hover({ timeout: 500 })
+          return true
+        } catch {
+          return false
+        }
+      },
+      { timeout: 60_000 },
+    )
+    .toBe(true)
+}
+
+async function newWorkspaceSession(page: Page, slug: string) {
+  await waitWorkspace(page, slug)
+  const item = page.locator(workspaceItemSelector(slug)).first()
+  await item.hover()
+
+  const button = page.locator(workspaceNewSessionSelector(slug)).first()
+  await expect(button).toBeVisible()
+  await button.click({ force: true })
+
+  const next = await waitSlug(page)
+  await expect(page).toHaveURL(new RegExp(`/${next}/session(?:[/?#]|$)`))
+  await expect(page.locator(promptSelector)).toBeVisible()
+  return currentDir(page)
+}
+
+test("session model and variant restore per session without leaking into new sessions", async ({
+  page,
+  withProject,
+}) => {
+  await page.setViewportSize({ width: 1440, height: 900 })
+
+  await withProject(async ({ directory, gotoSession, trackSession }) => {
+    await gotoSession()
+
+    await ensureVariant(page, directory)
+    const firstState = await chooseDifferentVariant(page)
+    const first = await submit(page, `session variant ${Date.now()}`)
+    trackSession(first)
+    await waitUser(directory, first)
+
+    await page.reload()
+    await expect(page.locator(promptSelector)).toBeVisible()
+    await waitFooter(page, firstState)
+
+    await gotoSession()
+    const fresh = await ensureVariant(page, directory)
+    expect(fresh.variant).not.toBe(firstState.variant)
+
+    const secondState = await chooseOtherModel(page)
+    const second = await submit(page, `session model ${Date.now()}`)
+    trackSession(second)
+    await waitUser(directory, second)
+
+    await goto(page, directory, first)
+    await waitFooter(page, firstState)
+
+    await goto(page, directory, second)
+    await waitFooter(page, secondState)
+
+    await gotoSession()
+    await waitFooter(page, fresh)
+  })
+})
+
+test("session model restore across workspaces", async ({ page, withProject }) => {
+  await page.setViewportSize({ width: 1440, height: 900 })
+
+  await withProject(async ({ directory: root, slug, gotoSession, trackDirectory, trackSession }) => {
+    await gotoSession()
+
+    await ensureVariant(page, root)
+    const firstState = await chooseDifferentVariant(page)
+    const first = await submit(page, `root session ${Date.now()}`)
+    trackSession(first, root)
+    await waitUser(root, first)
+
+    await openSidebar(page)
+    await setWorkspacesEnabled(page, slug, true)
+
+    const one = await createWorkspace(page, slug, [])
+    const oneDir = await newWorkspaceSession(page, one.slug)
+    trackDirectory(oneDir)
+
+    const secondState = await chooseOtherModel(page)
+    const second = await submit(page, `workspace one ${Date.now()}`)
+    trackSession(second, oneDir)
+    await waitUser(oneDir, second)
+
+    const two = await createWorkspace(page, slug, [one.slug])
+    const twoDir = await newWorkspaceSession(page, two.slug)
+    trackDirectory(twoDir)
+
+    await ensureVariant(page, twoDir)
+    const thirdState = await chooseDifferentVariant(page)
+    const third = await submit(page, `workspace two ${Date.now()}`)
+    trackSession(third, twoDir)
+    await waitUser(twoDir, third)
+
+    await goto(page, root, first)
+    await waitFooter(page, firstState)
+
+    await goto(page, oneDir, second)
+    await waitFooter(page, secondState)
+
+    await goto(page, twoDir, third)
+    await waitFooter(page, thirdState)
+
+    await goto(page, root, first)
+    await waitFooter(page, firstState)
+  })
+})

+ 7 - 5
packages/app/src/components/dialog-select-model-unpaid.tsx

@@ -13,8 +13,10 @@ import { DialogSelectProvider } from "./dialog-select-provider"
 import { ModelTooltip } from "./model-tooltip"
 import { useLanguage } from "@/context/language"
 
-export const DialogSelectModelUnpaid: Component = () => {
-  const local = useLocal()
+type ModelState = ReturnType<typeof useLocal>["model"]
+
+export const DialogSelectModelUnpaid: Component<{ model?: ModelState }> = (props) => {
+  const model = props.model ?? useLocal().model
   const dialog = useDialog()
   const providers = useProviders()
   const language = useLanguage()
@@ -35,8 +37,8 @@ export const DialogSelectModelUnpaid: Component = () => {
         <List
           class="[&_[data-slot=list-scroll]]:overflow-visible"
           ref={(ref) => (listRef = ref)}
-          items={local.model.list}
-          current={local.model.current()}
+          items={model.list}
+          current={model.current()}
           key={(x) => `${x.provider.id}:${x.id}`}
           itemWrapper={(item, node) => (
             <Tooltip
@@ -55,7 +57,7 @@ export const DialogSelectModelUnpaid: Component = () => {
             </Tooltip>
           )}
           onSelect={(x) => {
-            local.model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, {
+            model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, {
               recent: true,
             })
             dialog.close()

+ 12 - 7
packages/app/src/components/dialog-select-model.tsx

@@ -18,19 +18,22 @@ import { useLanguage } from "@/context/language"
 const isFree = (provider: string, cost: { input: number } | undefined) =>
   provider === "opencode" && (!cost || cost.input === 0)
 
+type ModelState = ReturnType<typeof useLocal>["model"]
+
 const ModelList: Component<{
   provider?: string
   class?: string
   onSelect: () => void
   action?: JSX.Element
+  model?: ModelState
 }> = (props) => {
-  const local = useLocal()
+  const model = props.model ?? useLocal().model
   const language = useLanguage()
 
   const models = createMemo(() =>
-    local.model
+    model
       .list()
-      .filter((m) => local.model.visible({ modelID: m.id, providerID: m.provider.id }))
+      .filter((m) => model.visible({ modelID: m.id, providerID: m.provider.id }))
       .filter((m) => (props.provider ? m.provider.id === props.provider : true)),
   )
 
@@ -41,7 +44,7 @@ const ModelList: Component<{
       emptyMessage={language.t("dialog.model.empty")}
       key={(x) => `${x.provider.id}:${x.id}`}
       items={models}
-      current={local.model.current()}
+      current={model.current()}
       filterKeys={["provider.name", "name", "id"]}
       sortBy={(a, b) => a.name.localeCompare(b.name)}
       groupBy={(x) => x.provider.name}
@@ -63,7 +66,7 @@ const ModelList: Component<{
         </Tooltip>
       )}
       onSelect={(x) => {
-        local.model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, {
+        model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, {
           recent: true,
         })
         props.onSelect()
@@ -88,6 +91,7 @@ type ModelSelectorTriggerProps = Omit<ComponentProps<typeof Kobalte.Trigger>, "a
 
 export function ModelSelectorPopover(props: {
   provider?: string
+  model?: ModelState
   children?: JSX.Element
   triggerAs?: ValidComponent
   triggerProps?: ModelSelectorTriggerProps
@@ -151,6 +155,7 @@ export function ModelSelectorPopover(props: {
           <Kobalte.Title class="sr-only">{language.t("dialog.model.select.title")}</Kobalte.Title>
           <ModelList
             provider={props.provider}
+            model={props.model}
             onSelect={() => setStore("open", false)}
             class="p-1"
             action={
@@ -184,7 +189,7 @@ export function ModelSelectorPopover(props: {
   )
 }
 
-export const DialogSelectModel: Component<{ provider?: string }> = (props) => {
+export const DialogSelectModel: Component<{ provider?: string; model?: ModelState }> = (props) => {
   const dialog = useDialog()
   const language = useLanguage()
 
@@ -202,7 +207,7 @@ export const DialogSelectModel: Component<{ provider?: string }> = (props) => {
         </Button>
       }
     >
-      <ModelList provider={props.provider} onSelect={() => dialog.close()} />
+      <ModelList provider={props.provider} model={props.model} onSelect={() => dialog.close()} />
       <Button
         variant="ghost"
         class="ml-3 mt-5 mb-6 text-text-base self-start"

+ 83 - 72
packages/app/src/components/prompt-input.tsx

@@ -1430,39 +1430,76 @@ export const PromptInput: Component<PromptInputProps> = (props) => {
                 <div class="size-4 shrink-0" />
               </div>
               <div class="flex items-center gap-1.5 min-w-0 flex-1">
-                <TooltipKeybind
-                  placement="top"
-                  gutter={4}
-                  title={language.t("command.agent.cycle")}
-                  keybind={command.keybind("agent.cycle")}
-                >
-                  <Select
-                    size="normal"
-                    options={agentNames()}
-                    current={local.agent.current()?.name ?? ""}
-                    onSelect={local.agent.set}
-                    class="capitalize max-w-[160px] text-text-base"
-                    valueClass="truncate text-13-regular text-text-base"
-                    triggerStyle={control()}
-                    variant="ghost"
-                  />
-                </TooltipKeybind>
-                <Show
-                  when={providers.paid().length > 0}
-                  fallback={
+                <div data-component="prompt-agent-control">
+                  <TooltipKeybind
+                    placement="top"
+                    gutter={4}
+                    title={language.t("command.agent.cycle")}
+                    keybind={command.keybind("agent.cycle")}
+                  >
+                    <Select
+                      size="normal"
+                      options={agentNames()}
+                      current={local.agent.current()?.name ?? ""}
+                      onSelect={local.agent.set}
+                      class="capitalize max-w-[160px] text-text-base"
+                      valueClass="truncate text-13-regular text-text-base"
+                      triggerStyle={control()}
+                      triggerProps={{ "data-action": "prompt-agent" }}
+                      variant="ghost"
+                    />
+                  </TooltipKeybind>
+                </div>
+                <div data-component="prompt-model-control">
+                  <Show
+                    when={providers.paid().length > 0}
+                    fallback={
+                      <TooltipKeybind
+                        placement="top"
+                        gutter={4}
+                        title={language.t("command.model.choose")}
+                        keybind={command.keybind("model.choose")}
+                      >
+                        <Button
+                          data-action="prompt-model"
+                          as="div"
+                          variant="ghost"
+                          size="normal"
+                          class="min-w-0 max-w-[320px] text-13-regular text-text-base group"
+                          style={control()}
+                          onClick={() => dialog.show(() => <DialogSelectModelUnpaid model={local.model} />)}
+                        >
+                          <Show when={local.model.current()?.provider?.id}>
+                            <ProviderIcon
+                              id={local.model.current()!.provider.id}
+                              class="size-4 shrink-0 opacity-40 group-hover:opacity-100 transition-opacity duration-150"
+                              style={{ "will-change": "opacity", transform: "translateZ(0)" }}
+                            />
+                          </Show>
+                          <span class="truncate">
+                            {local.model.current()?.name ?? language.t("dialog.model.select.title")}
+                          </span>
+                          <Icon name="chevron-down" size="small" class="shrink-0" />
+                        </Button>
+                      </TooltipKeybind>
+                    }
+                  >
                     <TooltipKeybind
                       placement="top"
                       gutter={4}
                       title={language.t("command.model.choose")}
                       keybind={command.keybind("model.choose")}
                     >
-                      <Button
-                        as="div"
-                        variant="ghost"
-                        size="normal"
-                        class="min-w-0 max-w-[320px] text-13-regular text-text-base group"
-                        style={control()}
-                        onClick={() => dialog.show(() => <DialogSelectModelUnpaid />)}
+                      <ModelSelectorPopover
+                        model={local.model}
+                        triggerAs={Button}
+                        triggerProps={{
+                          variant: "ghost",
+                          size: "normal",
+                          style: control(),
+                          class: "min-w-0 max-w-[320px] text-13-regular text-text-base group",
+                          "data-action": "prompt-model",
+                        }}
                       >
                         <Show when={local.model.current()?.provider?.id}>
                           <ProviderIcon
@@ -1475,57 +1512,31 @@ export const PromptInput: Component<PromptInputProps> = (props) => {
                           {local.model.current()?.name ?? language.t("dialog.model.select.title")}
                         </span>
                         <Icon name="chevron-down" size="small" class="shrink-0" />
-                      </Button>
+                      </ModelSelectorPopover>
                     </TooltipKeybind>
-                  }
-                >
+                  </Show>
+                </div>
+                <div data-component="prompt-variant-control">
                   <TooltipKeybind
                     placement="top"
                     gutter={4}
-                    title={language.t("command.model.choose")}
-                    keybind={command.keybind("model.choose")}
+                    title={language.t("command.model.variant.cycle")}
+                    keybind={command.keybind("model.variant.cycle")}
                   >
-                    <ModelSelectorPopover
-                      triggerAs={Button}
-                      triggerProps={{
-                        variant: "ghost",
-                        size: "normal",
-                        style: control(),
-                        class: "min-w-0 max-w-[320px] text-13-regular text-text-base group",
-                      }}
-                    >
-                      <Show when={local.model.current()?.provider?.id}>
-                        <ProviderIcon
-                          id={local.model.current()!.provider.id}
-                          class="size-4 shrink-0 opacity-40 group-hover:opacity-100 transition-opacity duration-150"
-                          style={{ "will-change": "opacity", transform: "translateZ(0)" }}
-                        />
-                      </Show>
-                      <span class="truncate">
-                        {local.model.current()?.name ?? language.t("dialog.model.select.title")}
-                      </span>
-                      <Icon name="chevron-down" size="small" class="shrink-0" />
-                    </ModelSelectorPopover>
+                    <Select
+                      size="normal"
+                      options={variants()}
+                      current={local.model.variant.current() ?? "default"}
+                      label={(x) => (x === "default" ? language.t("common.default") : x)}
+                      onSelect={(x) => local.model.variant.set(x === "default" ? undefined : x)}
+                      class="capitalize max-w-[160px] text-text-base"
+                      valueClass="truncate text-13-regular text-text-base"
+                      triggerStyle={control()}
+                      triggerProps={{ "data-action": "prompt-model-variant" }}
+                      variant="ghost"
+                    />
                   </TooltipKeybind>
-                </Show>
-                <TooltipKeybind
-                  placement="top"
-                  gutter={4}
-                  title={language.t("command.model.variant.cycle")}
-                  keybind={command.keybind("model.variant.cycle")}
-                >
-                  <Select
-                    size="normal"
-                    options={variants()}
-                    current={local.model.variant.current() ?? "default"}
-                    label={(x) => (x === "default" ? language.t("common.default") : x)}
-                    onSelect={(x) => local.model.variant.set(x === "default" ? undefined : x)}
-                    class="capitalize max-w-[160px] text-text-base"
-                    valueClass="truncate text-13-regular text-text-base"
-                    triggerStyle={control()}
-                    variant="ghost"
-                  />
-                </TooltipKeybind>
+                </div>
                 <TooltipKeybind
                   placement="top"
                   gutter={8}

+ 12 - 0
packages/app/src/components/prompt-input/submit.test.ts

@@ -17,6 +17,7 @@ const optimistic: Array<{
 }> = []
 const optimisticSeeded: boolean[] = []
 const storedSessions: Record<string, Array<{ id: string; title?: string }>> = {}
+const promoted: Array<{ directory: string; sessionID: string }> = []
 const sentShell: string[] = []
 const syncedDirectories: string[] = []
 
@@ -86,6 +87,11 @@ beforeAll(async () => {
       agent: {
         current: () => ({ name: "agent" }),
       },
+      session: {
+        promote(directory: string, sessionID: string) {
+          promoted.push({ directory, sessionID })
+        },
+      },
     }),
   }))
 
@@ -201,6 +207,7 @@ beforeEach(() => {
   enabledAutoAccept.length = 0
   optimistic.length = 0
   optimisticSeeded.length = 0
+  promoted.length = 0
   params = {}
   sentShell.length = 0
   syncedDirectories.length = 0
@@ -240,6 +247,11 @@ describe("prompt submit worktree selection", () => {
     expect(createdSessions).toEqual(["/repo/worktree-a", "/repo/worktree-b"])
     expect(sentShell).toEqual(["/repo/worktree-a", "/repo/worktree-b"])
     expect(syncedDirectories).toEqual(["/repo/worktree-a", "/repo/worktree-a", "/repo/worktree-b", "/repo/worktree-b"])
+    expect(promoted).toEqual([
+      { directory: "/repo/worktree-a", sessionID: "session-1" },
+      { directory: "/repo/worktree-b", sessionID: "session-2" },
+    ])
+    expect(syncedDirectories).toEqual(["/repo/worktree-a", "/repo/worktree-a", "/repo/worktree-b", "/repo/worktree-b"])
   })
 
   test("applies auto-accept to newly created sessions", async () => {

+ 2 - 1
packages/app/src/components/prompt-input/submit.ts

@@ -296,6 +296,7 @@ export function createPromptSubmit(input: PromptSubmitInput) {
 
     const currentModel = local.model.current()
     const currentAgent = local.agent.current()
+    const variant = local.model.variant.current()
     if (!currentModel || !currentAgent) {
       showToast({
         title: language.t("prompt.toast.modelAgentRequired.title"),
@@ -370,6 +371,7 @@ export function createPromptSubmit(input: PromptSubmitInput) {
         seed(sessionDirectory, created)
         session = created
         if (shouldAutoAccept) permission.enableAutoAccept(session.id, sessionDirectory)
+        local.session.promote(sessionDirectory, session.id)
         layout.handoff.setTabs(base64Encode(sessionDirectory), session.id)
         navigate(`/${base64Encode(sessionDirectory)}/session/${session.id}`)
       }
@@ -387,7 +389,6 @@ export function createPromptSubmit(input: PromptSubmitInput) {
       providerID: currentModel.provider.id,
     }
     const agent = currentAgent.name
-    const variant = local.model.variant.current()
     const context = prompt.context.items().slice()
     const draft: FollowupDraft = {
       sessionID: session.id,

+ 360 - 191
packages/app/src/context/local.tsx

@@ -1,252 +1,421 @@
-import { createStore } from "solid-js/store"
-import { batch, createMemo } from "solid-js"
 import { createSimpleContext } from "@opencode-ai/ui/context"
-import { useSDK } from "./sdk"
-import { useSync } from "./sync"
 import { base64Encode } from "@opencode-ai/util/encode"
-import { useProviders } from "@/hooks/use-providers"
+import { useParams } from "@solidjs/router"
+import { batch, createEffect, createMemo, onCleanup } from "solid-js"
+import { createStore } from "solid-js/store"
 import { useModels } from "@/context/models"
+import { useProviders } from "@/hooks/use-providers"
+import { modelEnabled, modelProbe } from "@/testing/model-selection"
+import { Persist, persisted } from "@/utils/persist"
 import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } from "./model-variant"
+import { useSDK } from "./sdk"
+import { useSync } from "./sync"
 
 export type ModelKey = { providerID: string; modelID: string }
 
+type State = {
+  agent?: string
+  model?: ModelKey
+  variant?: string | null
+}
+
+type Saved = {
+  session: Record<string, State | undefined>
+}
+
+const WORKSPACE_KEY = "__workspace__"
+const handoff = new Map<string, State>()
+
+const handoffKey = (dir: string, id: string) => `${dir}\n${id}`
+
+const migrate = (value: unknown) => {
+  if (!value || typeof value !== "object") return { session: {} }
+
+  const item = value as {
+    session?: Record<string, State | undefined>
+    pick?: Record<string, State | undefined>
+  }
+
+  if (item.session && typeof item.session === "object") return { session: item.session }
+  if (!item.pick || typeof item.pick !== "object") return { session: {} }
+
+  return {
+    session: Object.fromEntries(Object.entries(item.pick).filter(([key]) => key !== WORKSPACE_KEY)),
+  }
+}
+
+const clone = (value: State | undefined) => {
+  if (!value) return undefined
+  return {
+    ...value,
+    model: value.model ? { ...value.model } : undefined,
+  } satisfies State
+}
+
 export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
   name: "Local",
   init: () => {
+    const params = useParams()
     const sdk = useSDK()
     const sync = useSync()
     const providers = useProviders()
-    const connected = createMemo(() => new Set(providers.connected().map((provider) => provider.id)))
+    const models = useModels()
+
+    const id = createMemo(() => params.id || undefined)
+    const list = createMemo(() => sync.data.agent.filter((item) => item.mode !== "subagent" && !item.hidden))
+    const connected = createMemo(() => new Set(providers.connected().map((item) => item.id)))
 
-    function isModelValid(model: ModelKey) {
-      const provider = providers.all().find((x) => x.id === model.providerID)
+    const [saved, setSaved] = persisted(
+      {
+        ...Persist.workspace(sdk.directory, "model-selection", ["model-selection.v1"]),
+        migrate,
+      },
+      createStore<Saved>({
+        session: {},
+      }),
+    )
+
+    const [store, setStore] = createStore<{
+      current?: string
+      draft?: State
+      last?: {
+        type: "agent" | "model" | "variant"
+        agent?: string
+        model?: ModelKey | null
+        variant?: string | null
+      }
+    }>({
+      current: list()[0]?.name,
+      draft: undefined,
+      last: undefined,
+    })
+
+    const validModel = (model: ModelKey) => {
+      const provider = providers.all().find((item) => item.id === model.providerID)
       return !!provider?.models[model.modelID] && connected().has(model.providerID)
     }
 
-    function getFirstValidModel(...modelFns: (() => ModelKey | undefined)[]) {
-      for (const modelFn of modelFns) {
-        const model = modelFn()
+    const firstModel = (...items: Array<() => ModelKey | undefined>) => {
+      for (const item of items) {
+        const model = item()
         if (!model) continue
-        if (isModelValid(model)) return model
+        if (validModel(model)) return model
       }
     }
 
-    let setModel: (model: ModelKey | undefined, options?: { recent?: boolean }) => void = () => undefined
+    const pickAgent = (name: string | undefined) => {
+      const items = list()
+      if (items.length === 0) return undefined
+      return items.find((item) => item.name === name) ?? items[0]
+    }
 
-    const agent = (() => {
-      const list = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent" && !x.hidden))
-      const models = useModels()
+    createEffect(() => {
+      const items = list()
+      if (items.length === 0) {
+        if (store.current !== undefined) setStore("current", undefined)
+        return
+      }
+      if (items.some((item) => item.name === store.current)) return
+      setStore("current", items[0]?.name)
+    })
 
-      const [store, setStore] = createStore<{
-        current?: string
-      }>({
-        current: list()[0]?.name,
-      })
-      return {
-        list,
-        current() {
-          const available = list()
-          if (available.length === 0) return undefined
-          return available.find((x) => x.name === store.current) ?? available[0]
-        },
-        set(name: string | undefined) {
-          const available = list()
-          if (available.length === 0) {
-            setStore("current", undefined)
-            return
-          }
-          const match = name ? available.find((x) => x.name === name) : undefined
-          const value = match ?? available[0]
-          if (!value) return
-          setStore("current", value.name)
-          if (!value.model) return
-          setModel({
-            providerID: value.model.providerID,
-            modelID: value.model.modelID,
-          })
-          if (value.variant)
-            models.variant.set({ providerID: value.model.providerID, modelID: value.model.modelID }, value.variant)
-        },
-        move(direction: 1 | -1) {
-          const available = list()
-          if (available.length === 0) {
-            setStore("current", undefined)
-            return
-          }
-          let next = available.findIndex((x) => x.name === store.current) + direction
-          if (next < 0) next = available.length - 1
-          if (next >= available.length) next = 0
-          const value = available[next]
-          if (!value) return
-          setStore("current", value.name)
-          if (!value.model) return
-          setModel({
-            providerID: value.model.providerID,
-            modelID: value.model.modelID,
-          })
-          if (value.variant)
-            models.variant.set({ providerID: value.model.providerID, modelID: value.model.modelID }, value.variant)
-        },
+    const scope = createMemo<State | undefined>(() => {
+      const session = id()
+      if (!session) return store.draft
+      return saved.session[session] ?? handoff.get(handoffKey(sdk.directory, session))
+    })
+
+    createEffect(() => {
+      const session = id()
+      if (!session) return
+
+      const key = handoffKey(sdk.directory, session)
+      const next = handoff.get(key)
+      if (!next) return
+      if (saved.session[session] !== undefined) {
+        handoff.delete(key)
+        return
       }
-    })()
 
-    const model = (() => {
-      const models = useModels()
+      setSaved("session", session, clone(next))
+      handoff.delete(key)
+    })
 
-      const [ephemeral, setEphemeral] = createStore<{
-        model: Record<string, ModelKey | undefined>
-      }>({
-        model: {},
-      })
+    const configuredModel = () => {
+      if (!sync.data.config.model) return
+      const [providerID, modelID] = sync.data.config.model.split("/")
+      const model = { providerID, modelID }
+      if (validModel(model)) return model
+    }
 
-      const resolveConfigured = () => {
-        if (!sync.data.config.model) return
-        const [providerID, modelID] = sync.data.config.model.split("/")
-        const key = { providerID, modelID }
-        if (isModelValid(key)) return key
+    const recentModel = () => {
+      for (const item of models.recent.list()) {
+        if (validModel(item)) return item
       }
+    }
 
-      const resolveRecent = () => {
-        for (const item of models.recent.list()) {
-          if (isModelValid(item)) return item
+    const defaultModel = () => {
+      const defaults = providers.default()
+      for (const provider of providers.connected()) {
+        const configured = defaults[provider.id]
+        if (configured) {
+          const model = { providerID: provider.id, modelID: configured }
+          if (validModel(model)) return model
         }
+
+        const first = Object.values(provider.models)[0]
+        if (!first) continue
+        const model = { providerID: provider.id, modelID: first.id }
+        if (validModel(model)) return model
       }
+    }
 
-      const resolveDefault = () => {
-        const defaults = providers.default()
-        for (const provider of providers.connected()) {
-          const configured = defaults[provider.id]
-          if (configured) {
-            const key = { providerID: provider.id, modelID: configured }
-            if (isModelValid(key)) return key
-          }
+    const fallback = createMemo<ModelKey | undefined>(() => configuredModel() ?? recentModel() ?? defaultModel())
 
-          const first = Object.values(provider.models)[0]
-          if (!first) continue
-          const key = { providerID: provider.id, modelID: first.id }
-          if (isModelValid(key)) return key
+    const agent = {
+      list,
+      current() {
+        return pickAgent(scope()?.agent ?? store.current)
+      },
+      set(name: string | undefined) {
+        const item = pickAgent(name)
+        if (!item) {
+          setStore("current", undefined)
+          return
         }
-      }
 
-      const fallbackModel = createMemo<ModelKey | undefined>(() => {
-        return resolveConfigured() ?? resolveRecent() ?? resolveDefault()
-      })
+        batch(() => {
+          setStore("current", item.name)
+          setStore("last", {
+            type: "agent",
+            agent: item.name,
+            model: item.model,
+            variant: item.variant ?? null,
+          })
+          const next = {
+            agent: item.name,
+            model: item.model,
+            variant: item.variant,
+          } satisfies State
+          const session = id()
+          if (session) {
+            setSaved("session", session, next)
+            return
+          }
+          setStore("draft", next)
+        })
+      },
+      move(direction: 1 | -1) {
+        const items = list()
+        if (items.length === 0) {
+          setStore("current", undefined)
+          return
+        }
 
-      const current = createMemo(() => {
-        const a = agent.current()
-        if (!a) return undefined
-        const key = getFirstValidModel(
-          () => ephemeral.model[a.name],
-          () => a.model,
-          fallbackModel,
-        )
-        if (!key) return undefined
-        return models.find(key)
-      })
+        let next = items.findIndex((item) => item.name === agent.current()?.name) + direction
+        if (next < 0) next = items.length - 1
+        if (next >= items.length) next = 0
+        const item = items[next]
+        if (!item) return
+        agent.set(item.name)
+      },
+    }
 
-      const recent = createMemo(() => models.recent.list().map(models.find).filter(Boolean))
+    const current = () => {
+      const item = firstModel(
+        () => scope()?.model,
+        () => agent.current()?.model,
+        fallback,
+      )
+      if (!item) return undefined
+      return models.find(item)
+    }
 
-      const cycle = (direction: 1 | -1) => {
-        const recentList = recent()
-        const currentModel = current()
-        if (!currentModel) return
+    const configured = () => {
+      const item = agent.current()
+      const model = current()
+      if (!item || !model) return undefined
+      return getConfiguredAgentVariant({
+        agent: { model: item.model, variant: item.variant },
+        model: { providerID: model.provider.id, modelID: model.id, variants: model.variants },
+      })
+    }
 
-        const index = recentList.findIndex(
-          (x) => x?.provider.id === currentModel.provider.id && x?.id === currentModel.id,
-        )
-        if (index === -1) return
+    const selected = () => scope()?.variant
 
-        let next = index + direction
-        if (next < 0) next = recentList.length - 1
-        if (next >= recentList.length) next = 0
+    const snapshot = () => {
+      const model = current()
+      return {
+        agent: agent.current()?.name,
+        model: model ? { providerID: model.provider.id, modelID: model.id } : undefined,
+        variant: selected(),
+      } satisfies State
+    }
 
-        const val = recentList[next]
-        if (!val) return
+    const write = (next: Partial<State>) => {
+      const state = {
+        ...(scope() ?? { agent: agent.current()?.name }),
+        ...next,
+      } satisfies State
 
-        model.set({
-          providerID: val.provider.id,
-          modelID: val.id,
-        })
+      const session = id()
+      if (session) {
+        setSaved("session", session, state)
+        return
       }
+      setStore("draft", state)
+    }
 
-      const set = (model: ModelKey | undefined, options?: { recent?: boolean }) => {
-        batch(() => {
-          const currentAgent = agent.current()
-          const next = model ?? fallbackModel()
-          if (currentAgent) setEphemeral("model", currentAgent.name, next)
-          if (model) models.setVisibility(model, true)
-          if (options?.recent && model) models.recent.push(model)
-        })
-      }
+    const recent = createMemo(() => models.recent.list().map(models.find).filter(Boolean))
 
-      setModel = set
+    const model = {
+      ready: models.ready,
+      current,
+      recent,
+      list: models.list,
+      cycle(direction: 1 | -1) {
+        const items = recent()
+        const item = current()
+        if (!item) return
 
-      return {
-        ready: models.ready,
-        current,
-        recent,
-        list: models.list,
-        cycle,
-        set,
-        visible(model: ModelKey) {
-          return models.visible(model)
+        const index = items.findIndex((entry) => entry?.provider.id === item.provider.id && entry?.id === item.id)
+        if (index === -1) return
+
+        let next = index + direction
+        if (next < 0) next = items.length - 1
+        if (next >= items.length) next = 0
+
+        const entry = items[next]
+        if (!entry) return
+        model.set({ providerID: entry.provider.id, modelID: entry.id })
+      },
+      set(item: ModelKey | undefined, options?: { recent?: boolean }) {
+        batch(() => {
+          setStore("last", {
+            type: "model",
+            agent: agent.current()?.name,
+            model: item ?? null,
+            variant: selected(),
+          })
+          write({ model: item })
+          if (!item) return
+          models.setVisibility(item, true)
+          if (!options?.recent) return
+          models.recent.push(item)
+        })
+      },
+      visible(item: ModelKey) {
+        return models.visible(item)
+      },
+      setVisibility(item: ModelKey, visible: boolean) {
+        models.setVisibility(item, visible)
+      },
+      variant: {
+        configured,
+        selected,
+        current() {
+          return resolveModelVariant({
+            variants: this.list(),
+            selected: this.selected(),
+            configured: this.configured(),
+          })
         },
-        setVisibility(model: ModelKey, visible: boolean) {
-          models.setVisibility(model, visible)
+        list() {
+          const item = current()
+          if (!item?.variants) return []
+          return Object.keys(item.variants)
         },
-        variant: {
-          configured() {
-            const a = agent.current()
-            const m = current()
-            if (!a || !m) return undefined
-            return getConfiguredAgentVariant({
-              agent: { model: a.model, variant: a.variant },
-              model: { providerID: m.provider.id, modelID: m.id, variants: m.variants },
+        set(value: string | undefined) {
+          batch(() => {
+            const model = current()
+            setStore("last", {
+              type: "variant",
+              agent: agent.current()?.name,
+              model: model ? { providerID: model.provider.id, modelID: model.id } : null,
+              variant: value ?? null,
             })
-          },
-          selected() {
-            const m = current()
-            if (!m) return undefined
-            return models.variant.get({ providerID: m.provider.id, modelID: m.id })
-          },
-          current() {
-            return resolveModelVariant({
-              variants: this.list(),
+            write({ variant: value ?? null })
+          })
+        },
+        cycle() {
+          const items = this.list()
+          if (items.length === 0) return
+          this.set(
+            cycleModelVariant({
+              variants: items,
               selected: this.selected(),
               configured: this.configured(),
-            })
-          },
-          list() {
-            const m = current()
-            if (!m) return []
-            if (!m.variants) return []
-            return Object.keys(m.variants)
-          },
-          set(value: string | undefined) {
-            const m = current()
-            if (!m) return
-            models.variant.set({ providerID: m.provider.id, modelID: m.id }, value)
-          },
-          cycle() {
-            const variants = this.list()
-            if (variants.length === 0) return
-            this.set(
-              cycleModelVariant({
-                variants,
-                selected: this.selected(),
-                configured: this.configured(),
-              }),
-            )
-          },
+            }),
+          )
         },
-      }
-    })()
+      },
+    }
 
     const result = {
       slug: createMemo(() => base64Encode(sdk.directory)),
       model,
       agent,
+      session: {
+        reset() {
+          setStore("draft", undefined)
+        },
+        promote(dir: string, session: string) {
+          const next = clone(snapshot())
+          if (!next) return
+
+          if (dir === sdk.directory) {
+            setSaved("session", session, next)
+            setStore("draft", undefined)
+            return
+          }
+
+          handoff.set(handoffKey(dir, session), next)
+          setStore("draft", undefined)
+        },
+        restore(msg: { sessionID: string; agent: string; model: ModelKey; variant?: string }) {
+          const session = id()
+          if (!session) return
+          if (msg.sessionID !== session) return
+          if (saved.session[session] !== undefined) return
+          if (handoff.has(handoffKey(sdk.directory, session))) return
+
+          setSaved("session", session, {
+            agent: msg.agent,
+            model: msg.model,
+            variant: msg.variant ?? null,
+          })
+        },
+      },
     }
+
+    if (modelEnabled()) {
+      createEffect(() => {
+        const agent = result.agent.current()
+        const model = result.model.current()
+        modelProbe.set({
+          dir: sdk.directory,
+          sessionID: id(),
+          last: store.last,
+          agent: agent?.name,
+          model: model
+            ? {
+                providerID: model.provider.id,
+                modelID: model.id,
+                name: model.name,
+              }
+            : undefined,
+          variant: result.model.variant.current() ?? null,
+          selected: result.model.variant.selected(),
+          configured: result.model.variant.configured(),
+          pick: scope(),
+          base: undefined,
+          current: store.current,
+        })
+      })
+
+      onCleanup(() => modelProbe.clear())
+    }
+
     return result
   },
 })

+ 20 - 0
packages/app/src/context/model-variant.test.ts

@@ -44,6 +44,16 @@ describe("model variant", () => {
     expect(value).toBe("high")
   })
 
+  test("lets an explicit default override the configured variant", () => {
+    const value = resolveModelVariant({
+      variants: ["low", "high", "xhigh"],
+      selected: null,
+      configured: "xhigh",
+    })
+
+    expect(value).toBeUndefined()
+  })
+
   test("cycles from configured variant to next", () => {
     const value = cycleModelVariant({
       variants: ["low", "high", "xhigh"],
@@ -63,4 +73,14 @@ describe("model variant", () => {
 
     expect(value).toBe("low")
   })
+
+  test("cycles from an explicit default to the first variant", () => {
+    const value = cycleModelVariant({
+      variants: ["low", "high", "xhigh"],
+      selected: null,
+      configured: "xhigh",
+    })
+
+    expect(value).toBe("low")
+  })
 })

+ 3 - 1
packages/app/src/context/model-variant.ts

@@ -14,7 +14,7 @@ type Model = AgentModel & {
 
 type VariantInput = {
   variants: string[]
-  selected: string | undefined
+  selected: string | null | undefined
   configured: string | undefined
 }
 
@@ -29,6 +29,7 @@ export function getConfiguredAgentVariant(input: { agent: Agent | undefined; mod
 }
 
 export function resolveModelVariant(input: VariantInput) {
+  if (input.selected === null) return undefined
   if (input.selected && input.variants.includes(input.selected)) return input.selected
   if (input.configured && input.variants.includes(input.configured)) return input.configured
   return undefined
@@ -36,6 +37,7 @@ export function resolveModelVariant(input: VariantInput) {
 
 export function cycleModelVariant(input: VariantInput) {
   if (input.variants.length === 0) return undefined
+  if (input.selected === null) return input.variants[0]
   if (input.selected && input.variants.includes(input.selected)) {
     const index = input.variants.indexOf(input.selected)
     if (index === input.variants.length - 1) return undefined

+ 3 - 3
packages/app/src/pages/directory-layout.tsx

@@ -80,11 +80,11 @@ export default function Layout(props: ParentProps) {
   })
 
   return (
-    <Show when={state.resolved}>
+    <Show when={state.resolved} keyed>
       {(resolved) => (
-        <SDKProvider directory={resolved}>
+        <SDKProvider directory={() => resolved}>
           <SyncProvider>
-            <DirectoryDataProvider directory={resolved()}>{props.children}</DirectoryDataProvider>
+            <DirectoryDataProvider directory={resolved}>{props.children}</DirectoryDataProvider>
           </SyncProvider>
         </SDKProvider>
       )}

+ 2 - 2
packages/app/src/pages/session.tsx

@@ -44,7 +44,7 @@ import { createOpenReviewFile, createSessionTabs, createSizing, focusTerminalByI
 import { MessageTimeline } from "@/pages/session/message-timeline"
 import { type DiffStyle, SessionReviewTab, type SessionReviewTabProps } from "@/pages/session/review-tab"
 import { useSessionLayout } from "@/pages/session/session-layout"
-import { resetSessionModel, syncSessionModel } from "@/pages/session/session-model-helpers"
+import { syncSessionModel } from "@/pages/session/session-model-helpers"
 import { SessionSidePanel } from "@/pages/session/session-side-panel"
 import { TerminalPanel } from "@/pages/session/terminal-panel"
 import { useSessionCommands } from "@/pages/session/use-session-commands"
@@ -490,7 +490,7 @@ export default function Page() {
       (next, prev) => {
         if (!prev) return
         if (next.dir === prev.dir && next.id === prev.id) return
-        if (!next.id) resetSessionModel(local)
+        if (prev.id && !next.id) local.session.reset()
       },
       { defer: true },
     ),

+ 13 - 120
packages/app/src/pages/session/session-model-helpers.test.ts

@@ -14,145 +14,38 @@ const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant"
   }) as UserMessage
 
 describe("syncSessionModel", () => {
-  test("restores the last message model and variant", () => {
+  test("restores the last message through session state", () => {
     const calls: unknown[] = []
 
     syncSessionModel(
       {
-        agent: {
-          current() {
-            return undefined
-          },
-          set(value) {
-            calls.push(["agent", value])
-          },
-        },
-        model: {
-          set(value) {
-            calls.push(["model", value])
-          },
-          current() {
-            return { id: "claude-sonnet-4", provider: { id: "anthropic" } }
-          },
-          variant: {
-            set(value) {
-              calls.push(["variant", value])
-            },
-          },
-        },
-      },
-      message({ variant: "high" }),
-    )
-
-    expect(calls).toEqual([
-      ["agent", "build"],
-      ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
-      ["variant", "high"],
-    ])
-  })
-
-  test("skips variant when the model falls back", () => {
-    const calls: unknown[] = []
-
-    syncSessionModel(
-      {
-        agent: {
-          current() {
-            return undefined
-          },
-          set(value) {
-            calls.push(["agent", value])
-          },
-        },
-        model: {
-          set(value) {
-            calls.push(["model", value])
-          },
-          current() {
-            return { id: "gpt-5", provider: { id: "openai" } }
-          },
-          variant: {
-            set(value) {
-              calls.push(["variant", value])
-            },
+        session: {
+          restore(value) {
+            calls.push(value)
           },
+          reset() {},
         },
       },
       message({ variant: "high" }),
     )
 
-    expect(calls).toEqual([
-      ["agent", "build"],
-      ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
-    ])
+    expect(calls).toEqual([message({ variant: "high" })])
   })
 })
 
 describe("resetSessionModel", () => {
-  test("restores the current agent defaults", () => {
-    const calls: unknown[] = []
+  test("clears draft session state", () => {
+    const calls: string[] = []
 
     resetSessionModel({
-      agent: {
-        current() {
-          return {
-            model: { providerID: "anthropic", modelID: "claude-sonnet-4" },
-            variant: "high",
-          }
-        },
-        set() {},
-      },
-      model: {
-        set(value) {
-          calls.push(["model", value])
-        },
-        current() {
-          return undefined
-        },
-        variant: {
-          set(value) {
-            calls.push(["variant", value])
-          },
-        },
-      },
-    })
-
-    expect(calls).toEqual([
-      ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
-      ["variant", "high"],
-    ])
-  })
-
-  test("clears the variant when the agent has none", () => {
-    const calls: unknown[] = []
-
-    resetSessionModel({
-      agent: {
-        current() {
-          return {
-            model: { providerID: "anthropic", modelID: "claude-sonnet-4" },
-          }
-        },
-        set() {},
-      },
-      model: {
-        set(value) {
-          calls.push(["model", value])
-        },
-        current() {
-          return undefined
-        },
-        variant: {
-          set(value) {
-            calls.push(["variant", value])
-          },
+      session: {
+        reset() {
+          calls.push("reset")
         },
+        restore() {},
       },
     })
 
-    expect(calls).toEqual([
-      ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
-      ["variant", undefined],
-    ])
+    expect(calls).toEqual(["reset"])
   })
 })

+ 5 - 37
packages/app/src/pages/session/session-model-helpers.ts

@@ -1,48 +1,16 @@
 import type { UserMessage } from "@opencode-ai/sdk/v2"
-import { batch } from "solid-js"
 
 type Local = {
-  agent: {
-    current():
-      | {
-          model?: UserMessage["model"]
-          variant?: string
-        }
-      | undefined
-    set(name: string | undefined): void
-  }
-  model: {
-    set(model: UserMessage["model"] | undefined): void
-    current():
-      | {
-          id: string
-          provider: { id: string }
-        }
-      | undefined
-    variant: {
-      set(value: string | undefined): void
-    }
+  session: {
+    reset(): void
+    restore(msg: UserMessage): void
   }
 }
 
 export const resetSessionModel = (local: Local) => {
-  const agent = local.agent.current()
-  if (!agent) return
-  batch(() => {
-    local.model.set(agent.model)
-    local.model.variant.set(agent.variant)
-  })
+  local.session.reset()
 }
 
 export const syncSessionModel = (local: Local, msg: UserMessage) => {
-  batch(() => {
-    local.agent.set(msg.agent)
-    local.model.set(msg.model)
-  })
-
-  const model = local.model.current()
-  if (!model) return
-  if (model.provider.id !== msg.model.providerID) return
-  if (model.id !== msg.model.modelID) return
-  local.model.variant.set(msg.variant)
+  local.session.restore(msg)
 }

+ 1 - 1
packages/app/src/pages/session/use-session-commands.tsx

@@ -351,7 +351,7 @@ export const useSessionCommands = (actions: SessionCommandContext) => {
         description: language.t("command.model.choose.description"),
         keybind: "mod+'",
         slash: "model",
-        onSelect: () => dialog.show(() => <DialogSelectModel />),
+        onSelect: () => dialog.show(() => <DialogSelectModel model={local.model} />),
       }),
       mcpCommand({
         id: "mcp.toggle",

+ 80 - 0
packages/app/src/testing/model-selection.ts

@@ -0,0 +1,80 @@
+type ModelKey = {
+  providerID: string
+  modelID: string
+}
+
+type State = {
+  agent?: string
+  model?: ModelKey | null
+  variant?: string | null
+}
+
+export type ModelProbeState = {
+  dir?: string
+  sessionID?: string
+  last?: {
+    type: "agent" | "model" | "variant"
+    agent?: string
+    model?: ModelKey | null
+    variant?: string | null
+  }
+  agent?: string
+  model?: (ModelKey & { name?: string }) | undefined
+  variant?: string | null
+  selected?: string | null
+  configured?: string
+  pick?: State
+  base?: State
+  current?: string
+}
+
+export type ModelWindow = Window & {
+  __opencode_e2e?: {
+    model?: {
+      enabled?: boolean
+      current?: ModelProbeState
+    }
+  }
+}
+
+const clone = (state?: State) => {
+  if (!state) return undefined
+  return {
+    ...state,
+    model: state.model ? { ...state.model } : state.model,
+  }
+}
+
+export const modelEnabled = () => {
+  if (typeof window === "undefined") return false
+  return (window as ModelWindow).__opencode_e2e?.model?.enabled === true
+}
+
+const root = () => {
+  if (!modelEnabled()) return
+  return (window as ModelWindow).__opencode_e2e?.model
+}
+
+export const modelProbe = {
+  set(input: ModelProbeState) {
+    const state = root()
+    if (!state) return
+    state.current = {
+      ...input,
+      model: input.model ? { ...input.model } : undefined,
+      last: input.last
+        ? {
+            ...input.last,
+            model: input.last.model ? { ...input.last.model } : input.last.model,
+          }
+        : undefined,
+      pick: clone(input.pick),
+      base: clone(input.base),
+    }
+  },
+  clear() {
+    const state = root()
+    if (!state) return
+    state.current = undefined
+  },
+}

+ 6 - 0
packages/app/src/testing/terminal.ts

@@ -1,3 +1,5 @@
+import type { ModelProbeState } from "./model-selection"
+
 export const terminalAttr = "data-pty-id"
 
 export type TerminalProbeState = {
@@ -13,6 +15,10 @@ type TerminalProbeControl = {
 
 export type E2EWindow = Window & {
   __opencode_e2e?: {
+    model?: {
+      enabled?: boolean
+      current?: ModelProbeState
+    }
     terminal?: {
       enabled?: boolean
       terminals?: Record<string, TerminalProbeState>

+ 3 - 0
packages/ui/src/components/select.tsx

@@ -19,6 +19,7 @@ export type SelectProps<T> = Omit<ComponentProps<typeof Kobalte<T>>, "value" | "
   children?: (item: T | undefined) => JSX.Element
   triggerStyle?: JSX.CSSProperties
   triggerVariant?: "settings"
+  triggerProps?: Record<string, string | number | boolean | undefined>
 }
 
 export function Select<T>(props: SelectProps<T> & Omit<ButtonProps, "children">) {
@@ -38,6 +39,7 @@ export function Select<T>(props: SelectProps<T> & Omit<ButtonProps, "children">)
     "children",
     "triggerStyle",
     "triggerVariant",
+    "triggerProps",
   ])
 
   const state = {
@@ -131,6 +133,7 @@ export function Select<T>(props: SelectProps<T> & Omit<ButtonProps, "children">)
       }}
     >
       <Kobalte.Trigger
+        {...local.triggerProps}
         disabled={props.disabled}
         data-slot="select-select-trigger"
         as={Button}