Browse Source

feat: add model management to ACP sessions (#3358)

Yesh Yendamuri 4 months ago
parent
commit
f81e28c673

+ 66 - 3
packages/opencode/src/acp/agent.ts

@@ -12,6 +12,8 @@ import type {
   NewSessionResponse,
   PromptRequest,
   PromptResponse,
+  SetSessionModelRequest,
+  SetSessionModelResponse,
 } from "@agentclientprotocol/sdk"
 import { Log } from "../util/log"
 import { ACPSessionManager } from "./session"
@@ -55,10 +57,16 @@ export class OpenCodeAgent implements Agent {
   async newSession(params: NewSessionRequest): Promise<NewSessionResponse> {
     this.log.info("newSession", { cwd: params.cwd, mcpServers: params.mcpServers.length })
 
-    const session = await this.sessionManager.create(params.cwd, params.mcpServers)
+    const model = await this.defaultModel()
+    const session = await this.sessionManager.create(params.cwd, params.mcpServers, model)
+    const availableModels = await this.availableModels()
 
     return {
       sessionId: session.id,
+      models: {
+        currentModelId: `${model.providerID}/${model.modelID}`,
+        availableModels,
+      },
       _meta: {},
     }
   }
@@ -66,13 +74,64 @@ export class OpenCodeAgent implements Agent {
   async loadSession(params: LoadSessionRequest): Promise<LoadSessionResponse> {
     this.log.info("loadSession", { sessionId: params.sessionId, cwd: params.cwd })
 
-    await this.sessionManager.load(params.sessionId, params.cwd, params.mcpServers)
+    const defaultModel = await this.defaultModel()
+    const session = await this.sessionManager.load(params.sessionId, params.cwd, params.mcpServers, defaultModel)
+    const availableModels = await this.availableModels()
 
     return {
+      models: {
+        currentModelId: `${session.model.providerID}/${session.model.modelID}`,
+        availableModels,
+      },
       _meta: {},
     }
   }
 
+  async setSessionModel(params: SetSessionModelRequest): Promise<SetSessionModelResponse> {
+    this.log.info("setSessionModel", { sessionId: params.sessionId, modelId: params.modelId })
+
+    const session = this.sessionManager.get(params.sessionId)
+    if (!session) {
+      throw new Error(`Session not found: ${params.sessionId}`)
+    }
+
+    const parsed = Provider.parseModel(params.modelId)
+    const model = await Provider.getModel(parsed.providerID, parsed.modelID)
+
+    this.sessionManager.setModel(session.id, {
+      providerID: model.providerID,
+      modelID: model.modelID,
+    })
+
+    return {
+      _meta: {},
+    }
+  }
+
+  private async defaultModel() {
+    const configured = this.config.defaultModel
+    if (configured) return configured
+    return Provider.defaultModel()
+  }
+
+  private async availableModels() {
+    const providers = await Provider.list()
+    const entries = Object.entries(providers).sort((a, b) => {
+      const nameA = a[1].info.name.toLowerCase()
+      const nameB = b[1].info.name.toLowerCase()
+      if (nameA < nameB) return -1
+      if (nameA > nameB) return 1
+      return 0
+    })
+    return entries.flatMap(([providerID, provider]) => {
+      const models = Provider.sort(Object.values(provider.info.models))
+      return models.map((model) => ({
+        modelId: `${providerID}/${model.id}`,
+        name: `${provider.info.name}/${model.name}`,
+      }))
+    })
+  }
+
   async prompt(params: PromptRequest): Promise<PromptResponse> {
     this.log.info("prompt", {
       sessionId: params.sessionId,
@@ -84,7 +143,11 @@ export class OpenCodeAgent implements Agent {
       throw new Error(`Session not found: ${params.sessionId}`)
     }
 
-    const model = this.config.defaultModel || (await Provider.defaultModel())
+    const current = acpSession.model
+    const model = current ?? (await this.defaultModel())
+    if (!current) {
+      this.sessionManager.setModel(acpSession.id, model)
+    }
 
     const parts = params.prompt.map((content) => {
       if (content.type === "text") {

+ 35 - 2
packages/opencode/src/acp/session.ts

@@ -1,14 +1,20 @@
 import type { McpServer } from "@agentclientprotocol/sdk"
 import { Identifier } from "../id/id"
 import { Session } from "../session"
+import { Provider } from "../provider/provider"
 import type { ACPSessionState } from "./types"
 
 export class ACPSessionManager {
   private sessions = new Map<string, ACPSessionState>()
 
-  async create(cwd: string, mcpServers: McpServer[]): Promise<ACPSessionState> {
+  async create(
+    cwd: string,
+    mcpServers: McpServer[],
+    model?: ACPSessionState["model"],
+  ): Promise<ACPSessionState> {
     const sessionId = `acp_${Identifier.ascending("session")}`
     const openCodeSession = await Session.create({ title: `ACP Session ${sessionId}` })
+    const resolvedModel = model ?? (await Provider.defaultModel())
 
     const state: ACPSessionState = {
       id: sessionId,
@@ -16,6 +22,7 @@ export class ACPSessionManager {
       mcpServers,
       openCodeSessionId: openCodeSession.id,
       createdAt: new Date(),
+      model: resolvedModel,
     }
 
     this.sessions.set(sessionId, state)
@@ -38,13 +45,24 @@ export class ACPSessionManager {
     return this.sessions.has(sessionId)
   }
 
-  async load(sessionId: string, cwd: string, mcpServers: McpServer[]): Promise<ACPSessionState> {
+  async load(
+    sessionId: string,
+    cwd: string,
+    mcpServers: McpServer[],
+    model?: ACPSessionState["model"],
+  ): Promise<ACPSessionState> {
     const existing = this.sessions.get(sessionId)
     if (existing) {
+      if (!existing.model) {
+        const resolved = model ?? (await Provider.defaultModel())
+        existing.model = resolved
+        this.sessions.set(sessionId, existing)
+      }
       return existing
     }
 
     const openCodeSession = await Session.create({ title: `ACP Session ${sessionId} (loaded)` })
+    const resolvedModel = model ?? (await Provider.defaultModel())
 
     const state: ACPSessionState = {
       id: sessionId,
@@ -52,9 +70,24 @@ export class ACPSessionManager {
       mcpServers,
       openCodeSessionId: openCodeSession.id,
       createdAt: new Date(),
+      model: resolvedModel,
     }
 
     this.sessions.set(sessionId, state)
     return state
   }
+
+  getModel(sessionId: string) {
+    const session = this.sessions.get(sessionId)
+    if (!session) return
+    return session.model
+  }
+
+  setModel(sessionId: string, model: ACPSessionState["model"]) {
+    const session = this.sessions.get(sessionId)
+    if (!session) return
+    session.model = model
+    this.sessions.set(sessionId, session)
+    return session
+  }
 }

+ 4 - 0
packages/opencode/src/acp/types.ts

@@ -6,6 +6,10 @@ export interface ACPSessionState {
   mcpServers: McpServer[]
   openCodeSessionId: string
   createdAt: Date
+  model: {
+    providerID: string
+    modelID: string
+  }
 }
 
 export interface ACPConfig {