Просмотр исходного кода

fix(opencode): keep user message variants scoped to model (#21332)

Dax 1 неделя назад
Родитель
Commit
1f94c48bdd

+ 2 - 3
packages/app/src/components/prompt-input/submit.test.ts

@@ -146,7 +146,7 @@ beforeAll(async () => {
           add: (value: {
             directory?: string
             sessionID?: string
-            message: { agent: string; model: { providerID: string; modelID: string }; variant?: string }
+            message: { agent: string; model: { providerID: string; modelID: string; variant?: string } }
           }) => {
             optimistic.push(value)
             optimisticSeeded.push(
@@ -310,8 +310,7 @@ describe("prompt submit worktree selection", () => {
     expect(optimistic[0]).toMatchObject({
       message: {
         agent: "agent",
-        model: { providerID: "provider", modelID: "model" },
-        variant: "high",
+        model: { providerID: "provider", modelID: "model", variant: "high" },
       },
     })
   })

+ 1 - 2
packages/app/src/components/prompt-input/submit.ts

@@ -121,8 +121,7 @@ export async function sendFollowupDraft(input: FollowupSendInput) {
     role: "user",
     time: { created: Date.now() },
     agent: input.draft.agent,
-    model: input.draft.model,
-    variant: input.draft.variant,
+    model: { ...input.draft.model, variant: input.draft.variant },
   }
 
   const add = () =>

+ 3 - 3
packages/app/src/context/local.tsx

@@ -11,7 +11,7 @@ import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } fro
 import { useSDK } from "./sdk"
 import { useSync } from "./sync"
 
-export type ModelKey = { providerID: string; modelID: string }
+export type ModelKey = { providerID: string; modelID: string; variant?: string }
 
 type State = {
   agent?: string
@@ -373,7 +373,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
           handoff.set(handoffKey(dir, session), next)
           setStore("draft", undefined)
         },
-        restore(msg: { sessionID: string; agent: string; model: ModelKey; variant?: string }) {
+        restore(msg: { sessionID: string; agent: string; model: ModelKey }) {
           const session = id()
           if (!session) return
           if (msg.sessionID !== session) return
@@ -383,7 +383,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
           setSaved("session", session, {
             agent: msg.agent,
             model: msg.model,
-            variant: msg.variant ?? null,
+            variant: msg.model.variant ?? null,
           })
         },
       },

+ 1 - 2
packages/app/src/context/sync.tsx

@@ -416,8 +416,7 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({
             role: "user",
             time: { created: Date.now() },
             agent: input.agent,
-            model: input.model,
-            variant: input.variant,
+            model: { ...input.model, variant: input.variant },
           }
           const [, setStore] = target()
           setOptimistic(sdk.directory, input.sessionID, { message, parts: input.parts })

+ 5 - 4
packages/app/src/pages/session/session-model-helpers.test.ts

@@ -2,7 +2,7 @@ import { describe, expect, test } from "bun:test"
 import type { UserMessage } from "@opencode-ai/sdk/v2"
 import { resetSessionModel, syncSessionModel } from "./session-model-helpers"
 
-const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant">>) =>
+const message = (input?: { agent?: string; model?: UserMessage["model"] }) =>
   ({
     id: "msg",
     sessionID: "session",
@@ -10,7 +10,6 @@ const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant"
     time: { created: 1 },
     agent: input?.agent ?? "build",
     model: input?.model ?? { providerID: "anthropic", modelID: "claude-sonnet-4" },
-    variant: input?.variant,
   }) as UserMessage
 
 describe("syncSessionModel", () => {
@@ -26,10 +25,12 @@ describe("syncSessionModel", () => {
           reset() {},
         },
       },
-      message({ variant: "high" }),
+      message({ model: { providerID: "anthropic", modelID: "claude-sonnet-4", variant: "high" } }),
     )
 
-    expect(calls).toEqual([message({ variant: "high" })])
+    expect(calls).toEqual([
+      message({ model: { providerID: "anthropic", modelID: "claude-sonnet-4", variant: "high" } }),
+    ])
   })
 })
 

+ 6 - 4
packages/opencode/src/cli/cmd/tui/component/prompt/index.tsx

@@ -23,7 +23,7 @@ import { useRenderer, type JSX } from "@opentui/solid"
 import { Editor } from "@tui/util/editor"
 import { useExit } from "../../context/exit"
 import { Clipboard } from "../../util/clipboard"
-import type { AssistantMessage, FilePart } from "@opencode-ai/sdk/v2"
+import type { AssistantMessage, FilePart, UserMessage } from "@opencode-ai/sdk/v2"
 import { TuiEvent } from "../../event"
 import { iife } from "@/util/iife"
 import { Locale } from "@/util/locale"
@@ -145,7 +145,7 @@ export function Prompt(props: PromptProps) {
     if (!props.sessionID) return undefined
     const messages = sync.data.message[props.sessionID]
     if (!messages) return undefined
-    return messages.findLast((m) => m.role === "user")
+    return messages.findLast((m): m is UserMessage => m.role === "user")
   })
 
   const usage = createMemo(() => {
@@ -209,8 +209,10 @@ export function Prompt(props: PromptProps) {
       const isPrimaryAgent = local.agent.list().some((x) => x.name === msg.agent)
       if (msg.agent && isPrimaryAgent) {
         local.agent.set(msg.agent)
-        if (msg.model) local.model.set(msg.model)
-        if (msg.variant) local.model.variant.set(msg.variant)
+        if (msg.model) {
+          local.model.set(msg.model)
+          local.model.variant.set(msg.model.variant)
+        }
       }
     }
   })

+ 1 - 2
packages/opencode/src/session/compaction.ts

@@ -228,7 +228,7 @@ When constructing the summary, try to stick to this template:
           sessionID: input.sessionID,
           mode: "compaction",
           agent: "compaction",
-          variant: userMessage.variant,
+          variant: userMessage.model.variant,
           summary: true,
           path: {
             cwd: ctx.directory,
@@ -295,7 +295,6 @@ When constructing the summary, try to stick to this template:
               format: original.format,
               tools: original.tools,
               system: original.system,
-              variant: original.variant,
             })
             for (const part of replay.parts) {
               if (part.type === "compaction") continue

+ 3 - 1
packages/opencode/src/session/llm.ts

@@ -127,7 +127,9 @@ export namespace LLM {
     }
 
     const variant =
-      !input.small && input.model.variants && input.user.variant ? input.model.variants[input.user.variant] : {}
+      !input.small && input.model.variants && input.user.model.variant
+        ? input.model.variants[input.user.model.variant]
+        : {}
     const base = input.small
       ? ProviderTransform.smallOptions(input.model)
       : ProviderTransform.options({

+ 1 - 1
packages/opencode/src/session/message-v2.ts

@@ -371,10 +371,10 @@ export namespace MessageV2 {
     model: z.object({
       providerID: ProviderID.zod,
       modelID: ModelID.zod,
+      variant: z.string().optional(),
     }),
     system: z.string().optional(),
     tools: z.record(z.string(), z.boolean()).optional(),
-    variant: z.string().optional(),
   }).meta({
     ref: "UserMessage",
   })

+ 8 - 5
packages/opencode/src/session/prompt.ts

@@ -569,7 +569,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
           sessionID,
           mode: task.agent,
           agent: task.agent,
-          variant: lastUser.variant,
+          variant: lastUser.model.variant,
           path: { cwd: ctx.directory, root: ctx.worktree },
           cost: 0,
           tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },
@@ -967,17 +967,20 @@ NOTE: At any point in time through this workflow you should feel free to ask the
             : undefined
         const variant = input.variant ?? (ag.variant && full?.variants?.[ag.variant] ? ag.variant : undefined)
 
-        const info: MessageV2.Info = {
+        const info: MessageV2.User = {
           id: input.messageID ?? MessageID.ascending(),
           role: "user",
           sessionID: input.sessionID,
           time: { created: Date.now() },
           tools: input.tools,
           agent: ag.name,
-          model,
+          model: {
+            providerID: model.providerID,
+            modelID: model.modelID,
+            variant,
+          },
           system: input.system,
           format: input.format,
-          variant,
         }
 
         yield* Effect.addFinalizer(() =>
@@ -1436,7 +1439,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
               role: "assistant",
               mode: agent.name,
               agent: agent.name,
-              variant: lastUser.variant,
+              variant: lastUser.model.variant,
               path: { cwd: ctx.directory, root: ctx.worktree },
               cost: 0,
               tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },

+ 2 - 4
packages/opencode/test/session/llm.test.ts

@@ -342,8 +342,7 @@ describe("session.llm.stream", () => {
           role: "user",
           time: { created: Date.now() },
           agent: agent.name,
-          model: { providerID: ProviderID.make(providerID), modelID: resolved.id },
-          variant: "high",
+          model: { providerID: ProviderID.make(providerID), modelID: resolved.id, variant: "high" },
         } satisfies MessageV2.User
 
         const stream = await LLM.stream({
@@ -716,8 +715,7 @@ describe("session.llm.stream", () => {
           role: "user",
           time: { created: Date.now() },
           agent: agent.name,
-          model: { providerID: ProviderID.make("openai"), modelID: resolved.id },
-          variant: "high",
+          model: { providerID: ProviderID.make("openai"), modelID: resolved.id, variant: "high" },
         } satisfies MessageV2.User
 
         const stream = await LLM.stream({

+ 8 - 4
packages/opencode/test/session/prompt.test.ts

@@ -410,7 +410,7 @@ describe("session.prompt agent variant", () => {
             parts: [{ type: "text", text: "hello" }],
           })
           if (other.info.role !== "user") throw new Error("expected user message")
-          expect(other.info.variant).toBeUndefined()
+          expect(other.info.model.variant).toBeUndefined()
 
           const match = await SessionPrompt.prompt({
             sessionID: session.id,
@@ -419,8 +419,12 @@ describe("session.prompt agent variant", () => {
             parts: [{ type: "text", text: "hello again" }],
           })
           if (match.info.role !== "user") throw new Error("expected user message")
-          expect(match.info.model).toEqual({ providerID: ProviderID.make("openai"), modelID: ModelID.make("gpt-5.2") })
-          expect(match.info.variant).toBe("xhigh")
+          expect(match.info.model).toEqual({
+            providerID: ProviderID.make("openai"),
+            modelID: ModelID.make("gpt-5.2"),
+            variant: "xhigh",
+          })
+          expect(match.info.model.variant).toBe("xhigh")
 
           const override = await SessionPrompt.prompt({
             sessionID: session.id,
@@ -430,7 +434,7 @@ describe("session.prompt agent variant", () => {
             parts: [{ type: "text", text: "hello third" }],
           })
           if (override.info.role !== "user") throw new Error("expected user message")
-          expect(override.info.variant).toBe("high")
+          expect(override.info.model.variant).toBe("high")
 
           await Session.remove(session.id)
         },

+ 1 - 1
packages/sdk/js/src/v2/gen/types.gen.ts

@@ -548,12 +548,12 @@ export type UserMessage = {
   model: {
     providerID: string
     modelID: string
+    variant?: string
   }
   system?: string
   tools?: {
     [key: string]: boolean
   }
-  variant?: string
 }
 
 export type AssistantMessage = {