فهرست منبع

allow temperature to be configured per mode

Dax Raad 7 ماه پیش
والد
کامیت
e97613ef9f

+ 1 - 0
packages/opencode/src/config/config.ts

@@ -99,6 +99,7 @@ export namespace Config {
   export const Mode = z
     .object({
       model: z.string().optional(),
+      temperature: z.number().optional(),
       prompt: z.string().optional(),
       tools: z.record(z.string(), z.boolean()).optional(),
       disable: z.boolean().optional(),

+ 0 - 142
packages/opencode/src/provider/provider.ts

@@ -5,22 +5,11 @@ import { mergeDeep, sortBy } from "remeda"
 import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
 import { Log } from "../util/log"
 import { BunProc } from "../bun"
-import { BashTool } from "../tool/bash"
-import { EditTool } from "../tool/edit"
-import { WebFetchTool } from "../tool/webfetch"
-import { GlobTool } from "../tool/glob"
-import { GrepTool } from "../tool/grep"
-import { ListTool } from "../tool/ls"
-import { PatchTool } from "../tool/patch"
-import { ReadTool } from "../tool/read"
-import { WriteTool } from "../tool/write"
-import { TodoReadTool, TodoWriteTool } from "../tool/todo"
 import { AuthAnthropic } from "../auth/anthropic"
 import { AuthCopilot } from "../auth/copilot"
 import { ModelsDev } from "./models"
 import { NamedError } from "../util/error"
 import { Auth } from "../auth"
-import { TaskTool } from "../tool/task"
 
 export namespace Provider {
   const log = Log.create({ service: "provider" })
@@ -468,137 +457,6 @@ export namespace Provider {
     }
   }
 
-  const TOOLS = [
-    BashTool,
-    EditTool,
-    WebFetchTool,
-    GlobTool,
-    GrepTool,
-    ListTool,
-    // LspDiagnosticTool,
-    // LspHoverTool,
-    PatchTool,
-    ReadTool,
-    // MultiEditTool,
-    WriteTool,
-    TodoWriteTool,
-    TodoReadTool,
-    TaskTool,
-  ]
-
-  export async function tools(providerID: string) {
-    const result = await Promise.all(TOOLS.map((t) => t()))
-    switch (providerID) {
-      case "anthropic":
-        return result.filter((t) => t.id !== "patch")
-      case "openai":
-        return result.map((t) => ({
-          ...t,
-          parameters: optionalToNullable(t.parameters),
-        }))
-      case "azure":
-        return result.map((t) => ({
-          ...t,
-          parameters: optionalToNullable(t.parameters),
-        }))
-      case "google":
-        return result.map((t) => ({
-          ...t,
-          parameters: sanitizeGeminiParameters(t.parameters),
-        }))
-      default:
-        return result
-    }
-  }
-
-  function sanitizeGeminiParameters(schema: z.ZodTypeAny, visited = new Set()): z.ZodTypeAny {
-    if (!schema || visited.has(schema)) {
-      return schema
-    }
-    visited.add(schema)
-
-    if (schema instanceof z.ZodDefault) {
-      const innerSchema = schema.removeDefault()
-      // Handle Gemini's incompatibility with `default` on `anyOf` (unions).
-      if (innerSchema instanceof z.ZodUnion) {
-        // The schema was `z.union(...).default(...)`, which is not allowed.
-        // We strip the default and return the sanitized union.
-        return sanitizeGeminiParameters(innerSchema, visited)
-      }
-      // Otherwise, the default is on a regular type, which is allowed.
-      // We recurse on the inner type and then re-apply the default.
-      return sanitizeGeminiParameters(innerSchema, visited).default(schema._def.defaultValue())
-    }
-
-    if (schema instanceof z.ZodOptional) {
-      return z.optional(sanitizeGeminiParameters(schema.unwrap(), visited))
-    }
-
-    if (schema instanceof z.ZodObject) {
-      const newShape: Record<string, z.ZodTypeAny> = {}
-      for (const [key, value] of Object.entries(schema.shape)) {
-        newShape[key] = sanitizeGeminiParameters(value as z.ZodTypeAny, visited)
-      }
-      return z.object(newShape)
-    }
-
-    if (schema instanceof z.ZodArray) {
-      return z.array(sanitizeGeminiParameters(schema.element, visited))
-    }
-
-    if (schema instanceof z.ZodUnion) {
-      // This schema corresponds to `anyOf` in JSON Schema.
-      // We recursively sanitize each option in the union.
-      const sanitizedOptions = schema.options.map((option: z.ZodTypeAny) => sanitizeGeminiParameters(option, visited))
-      return z.union(sanitizedOptions as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
-    }
-
-    if (schema instanceof z.ZodString) {
-      const newSchema = z.string({ description: schema.description })
-      const safeChecks = ["min", "max", "length", "regex", "startsWith", "endsWith", "includes", "trim"]
-      // rome-ignore lint/suspicious/noExplicitAny: <explanation>
-      ;(newSchema._def as any).checks = (schema._def as z.ZodStringDef).checks.filter((check) =>
-        safeChecks.includes(check.kind),
-      )
-      return newSchema
-    }
-
-    return schema
-  }
-  function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
-    if (schema instanceof z.ZodObject) {
-      const shape = schema.shape
-      const newShape: Record<string, z.ZodTypeAny> = {}
-
-      for (const [key, value] of Object.entries(shape)) {
-        const zodValue = value as z.ZodTypeAny
-        if (zodValue instanceof z.ZodOptional) {
-          newShape[key] = zodValue.unwrap().nullable()
-        } else {
-          newShape[key] = optionalToNullable(zodValue)
-        }
-      }
-
-      return z.object(newShape)
-    }
-
-    if (schema instanceof z.ZodArray) {
-      return z.array(optionalToNullable(schema.element))
-    }
-
-    if (schema instanceof z.ZodUnion) {
-      return z.union(
-        schema.options.map((option: z.ZodTypeAny) => optionalToNullable(option)) as [
-          z.ZodTypeAny,
-          z.ZodTypeAny,
-          ...z.ZodTypeAny[],
-        ],
-      )
-    }
-
-    return schema
-  }
-
   export const ModelNotFoundError = NamedError.create(
     "ProviderModelNotFoundError",
     z.object({

+ 5 - 0
packages/opencode/src/provider/transform.ts

@@ -44,4 +44,9 @@ export namespace ProviderTransform {
     }
     return msgs
   }
+
+  export function temperature(_providerID: string, modelID: string) {
+    if (modelID.includes("qwen")) return 0.55
+    return 0
+  }
 }

+ 13 - 7
packages/opencode/src/session/index.ts

@@ -39,7 +39,8 @@ import { MessageV2 } from "./message-v2"
 import { Mode } from "./mode"
 import { LSP } from "../lsp"
 import { ReadTool } from "../tool/read"
-import { splitWhen } from "remeda"
+import { mergeDeep, pipe, splitWhen } from "remeda"
+import { ToolRegistry } from "../tool/registry"
 
 export namespace Session {
   const log = Log.create({ service: "session" })
@@ -430,7 +431,7 @@ export namespace Session {
                   }
                 }
                 const args = { filePath, offset, limit }
-                const result = await ReadTool().then((t) =>
+                const result = await ReadTool.init().then((t) =>
                   t.execute(args, {
                     sessionID: input.sessionID,
                     abort: new AbortController().signal,
@@ -660,10 +661,13 @@ export namespace Session {
 
     const processor = createProcessor(assistantMsg, model.info)
 
-    for (const item of await Provider.tools(input.providerID)) {
-      if (mode.tools[item.id] === false) continue
-      if (input.tools?.[item.id] === false) continue
-      if (session.parentID && item.id === "task") continue
+    const enabledTools = pipe(
+      mode.tools,
+      mergeDeep(ToolRegistry.enabled(input.providerID, input.modelID)),
+      mergeDeep(input.tools ?? {}),
+    )
+    for (const item of await ToolRegistry.tools(input.providerID, input.modelID)) {
+      if (enabledTools[item.id] === false) continue
       tools[item.id] = tool({
         id: item.id as any,
         description: item.description,
@@ -791,7 +795,9 @@ export namespace Session {
         ),
         ...MessageV2.toModelMessage(msgs),
       ],
-      temperature: model.info.temperature ? 0 : undefined,
+      temperature: model.info.temperature
+        ? (mode.temperature ?? ProviderTransform.temperature(input.providerID, input.modelID))
+        : undefined,
       tools: model.info.tool_call === false ? undefined : tools,
       model: wrapLanguageModel({
         model: model.language,

+ 2 - 0
packages/opencode/src/session/mode.ts

@@ -7,6 +7,7 @@ export namespace Mode {
   export const Info = z
     .object({
       name: z.string(),
+      temperature: z.number().optional(),
       model: z
         .object({
           modelID: z.string(),
@@ -50,6 +51,7 @@ export namespace Mode {
       item.name = key
       if (value.model) item.model = Provider.parseModel(value.model)
       if (value.prompt) item.prompt = value.prompt
+      if (value.temperature) item.temperature = value.temperature
       if (value.tools)
         item.tools = {
           ...value.tools,

+ 1 - 2
packages/opencode/src/tool/bash.ts

@@ -7,8 +7,7 @@ const MAX_OUTPUT_LENGTH = 30000
 const DEFAULT_TIMEOUT = 1 * 60 * 1000
 const MAX_TIMEOUT = 10 * 60 * 1000
 
-export const BashTool = Tool.define({
-  id: "bash",
+export const BashTool = Tool.define("bash", {
   description: DESCRIPTION,
   parameters: z.object({
     command: z.string().describe("The command to execute"),

+ 1 - 2
packages/opencode/src/tool/edit.ts

@@ -14,8 +14,7 @@ import { File } from "../file"
 import { Bus } from "../bus"
 import { FileTime } from "../file/time"
 
-export const EditTool = Tool.define({
-  id: "edit",
+export const EditTool = Tool.define("edit", {
   description: DESCRIPTION,
   parameters: z.object({
     filePath: z.string().describe("The absolute path to the file to modify"),

+ 1 - 2
packages/opencode/src/tool/glob.ts

@@ -5,8 +5,7 @@ import { App } from "../app/app"
 import DESCRIPTION from "./glob.txt"
 import { Ripgrep } from "../file/ripgrep"
 
-export const GlobTool = Tool.define({
-  id: "glob",
+export const GlobTool = Tool.define("glob", {
   description: DESCRIPTION,
   parameters: z.object({
     pattern: z.string().describe("The glob pattern to match files against"),

+ 1 - 2
packages/opencode/src/tool/grep.ts

@@ -5,8 +5,7 @@ import { Ripgrep } from "../file/ripgrep"
 
 import DESCRIPTION from "./grep.txt"
 
-export const GrepTool = Tool.define({
-  id: "grep",
+export const GrepTool = Tool.define("grep", {
   description: DESCRIPTION,
   parameters: z.object({
     pattern: z.string().describe("The regex pattern to search for in file contents"),

+ 1 - 2
packages/opencode/src/tool/ls.ts

@@ -33,8 +33,7 @@ export const IGNORE_PATTERNS = [
 
 const LIMIT = 100
 
-export const ListTool = Tool.define({
-  id: "list",
+export const ListTool = Tool.define("list", {
   description: DESCRIPTION,
   parameters: z.object({
     path: z.string().describe("The absolute path to the directory to list (must be absolute, not relative)").optional(),

+ 1 - 2
packages/opencode/src/tool/lsp-diagnostics.ts

@@ -5,8 +5,7 @@ import { LSP } from "../lsp"
 import { App } from "../app/app"
 import DESCRIPTION from "./lsp-diagnostics.txt"
 
-export const LspDiagnosticTool = Tool.define({
-  id: "lsp_diagnostics",
+export const LspDiagnosticTool = Tool.define("lsp_diagnostics", {
   description: DESCRIPTION,
   parameters: z.object({
     path: z.string().describe("The path to the file to get diagnostics."),

+ 1 - 2
packages/opencode/src/tool/lsp-hover.ts

@@ -5,8 +5,7 @@ import { LSP } from "../lsp"
 import { App } from "../app/app"
 import DESCRIPTION from "./lsp-hover.txt"
 
-export const LspHoverTool = Tool.define({
-  id: "lsp_hover",
+export const LspHoverTool = Tool.define("lsp_hover", {
   description: DESCRIPTION,
   parameters: z.object({
     file: z.string().describe("The path to the file to get diagnostics."),

+ 2 - 3
packages/opencode/src/tool/multiedit.ts

@@ -5,8 +5,7 @@ import DESCRIPTION from "./multiedit.txt"
 import path from "path"
 import { App } from "../app/app"
 
-export const MultiEditTool = Tool.define({
-  id: "multiedit",
+export const MultiEditTool = Tool.define("multiedit", {
   description: DESCRIPTION,
   parameters: z.object({
     filePath: z.string().describe("The absolute path to the file to modify"),
@@ -22,7 +21,7 @@ export const MultiEditTool = Tool.define({
       .describe("Array of edit operations to perform sequentially on the file"),
   }),
   async execute(params, ctx) {
-    const tool = await EditTool()
+    const tool = await EditTool.init()
     const results = []
     for (const [, edit] of params.edits.entries()) {
       const result = await tool.execute(

+ 1 - 2
packages/opencode/src/tool/patch.ts

@@ -210,8 +210,7 @@ async function applyCommit(
   }
 }
 
-export const PatchTool = Tool.define({
-  id: "patch",
+export const PatchTool = Tool.define("patch", {
   description: DESCRIPTION,
   parameters: PatchParams,
   execute: async (params, ctx) => {

+ 1 - 2
packages/opencode/src/tool/read.ts

@@ -10,8 +10,7 @@ import { App } from "../app/app"
 const DEFAULT_READ_LIMIT = 2000
 const MAX_LINE_LENGTH = 2000
 
-export const ReadTool = Tool.define({
-  id: "read",
+export const ReadTool = Tool.define("read", {
   description: DESCRIPTION,
   parameters: z.object({
     filePath: z.string().describe("The path to the file to read"),

+ 170 - 0
packages/opencode/src/tool/registry.ts

@@ -0,0 +1,170 @@
+import z from "zod"
+import { BashTool } from "./bash"
+import { EditTool } from "./edit"
+import { GlobTool } from "./glob"
+import { GrepTool } from "./grep"
+import { ListTool } from "./ls"
+import { PatchTool } from "./patch"
+import { ReadTool } from "./read"
+import { TaskTool } from "./task"
+import { TodoWriteTool, TodoReadTool } from "./todo"
+import { WebFetchTool } from "./webfetch"
+import { WriteTool } from "./write"
+
+export namespace ToolRegistry {
+  const ALL = [
+    BashTool,
+    EditTool,
+    WebFetchTool,
+    GlobTool,
+    GrepTool,
+    ListTool,
+    PatchTool,
+    ReadTool,
+    WriteTool,
+    TodoWriteTool,
+    TodoReadTool,
+    TaskTool,
+  ]
+
+  export function ids() {
+    return ALL.map((t) => t.id)
+  }
+
+  export async function tools(providerID: string, _modelID: string) {
+    const result = await Promise.all(
+      ALL.map(async (t) => ({
+        id: t.id,
+        ...(await t.init()),
+      })),
+    )
+
+    if (providerID === "openai") {
+      return result.map((t) => ({
+        ...t,
+        parameters: optionalToNullable(t.parameters),
+      }))
+    }
+
+    if (providerID === "azure") {
+      return result.map((t) => ({
+        ...t,
+        parameters: optionalToNullable(t.parameters),
+      }))
+    }
+
+    if (providerID === "google") {
+      return result.map((t) => ({
+        ...t,
+        parameters: sanitizeGeminiParameters(t.parameters),
+      }))
+    }
+
+    return result
+  }
+
+  export function enabled(_providerID: string, modelID: string): Record<string, boolean> {
+    if (modelID.includes("claude")) {
+      return {
+        patch: false,
+      }
+    }
+    if (modelID.includes("qwen")) {
+      return {
+        patch: false,
+        todowrite: false,
+        todoread: false,
+      }
+    }
+    return {}
+  }
+
+  function sanitizeGeminiParameters(schema: z.ZodTypeAny, visited = new Set()): z.ZodTypeAny {
+    if (!schema || visited.has(schema)) {
+      return schema
+    }
+    visited.add(schema)
+
+    if (schema instanceof z.ZodDefault) {
+      const innerSchema = schema.removeDefault()
+      // Handle Gemini's incompatibility with `default` on `anyOf` (unions).
+      if (innerSchema instanceof z.ZodUnion) {
+        // The schema was `z.union(...).default(...)`, which is not allowed.
+        // We strip the default and return the sanitized union.
+        return sanitizeGeminiParameters(innerSchema, visited)
+      }
+      // Otherwise, the default is on a regular type, which is allowed.
+      // We recurse on the inner type and then re-apply the default.
+      return sanitizeGeminiParameters(innerSchema, visited).default(schema._def.defaultValue())
+    }
+
+    if (schema instanceof z.ZodOptional) {
+      return z.optional(sanitizeGeminiParameters(schema.unwrap(), visited))
+    }
+
+    if (schema instanceof z.ZodObject) {
+      const newShape: Record<string, z.ZodTypeAny> = {}
+      for (const [key, value] of Object.entries(schema.shape)) {
+        newShape[key] = sanitizeGeminiParameters(value as z.ZodTypeAny, visited)
+      }
+      return z.object(newShape)
+    }
+
+    if (schema instanceof z.ZodArray) {
+      return z.array(sanitizeGeminiParameters(schema.element, visited))
+    }
+
+    if (schema instanceof z.ZodUnion) {
+      // This schema corresponds to `anyOf` in JSON Schema.
+      // We recursively sanitize each option in the union.
+      const sanitizedOptions = schema.options.map((option: z.ZodTypeAny) => sanitizeGeminiParameters(option, visited))
+      return z.union(sanitizedOptions as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
+    }
+
+    if (schema instanceof z.ZodString) {
+      const newSchema = z.string({ description: schema.description })
+      const safeChecks = ["min", "max", "length", "regex", "startsWith", "endsWith", "includes", "trim"]
+      // rome-ignore lint/suspicious/noExplicitAny: <explanation>
+      ;(newSchema._def as any).checks = (schema._def as z.ZodStringDef).checks.filter((check) =>
+        safeChecks.includes(check.kind),
+      )
+      return newSchema
+    }
+
+    return schema
+  }
+
+  function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
+    if (schema instanceof z.ZodObject) {
+      const shape = schema.shape
+      const newShape: Record<string, z.ZodTypeAny> = {}
+
+      for (const [key, value] of Object.entries(shape)) {
+        const zodValue = value as z.ZodTypeAny
+        if (zodValue instanceof z.ZodOptional) {
+          newShape[key] = zodValue.unwrap().nullable()
+        } else {
+          newShape[key] = optionalToNullable(zodValue)
+        }
+      }
+
+      return z.object(newShape)
+    }
+
+    if (schema instanceof z.ZodArray) {
+      return z.array(optionalToNullable(schema.element))
+    }
+
+    if (schema instanceof z.ZodUnion) {
+      return z.union(
+        schema.options.map((option: z.ZodTypeAny) => optionalToNullable(option)) as [
+          z.ZodTypeAny,
+          z.ZodTypeAny,
+          ...z.ZodTypeAny[],
+        ],
+      )
+    }
+
+    return schema
+  }
+}

+ 5 - 3
packages/opencode/src/tool/task.ts

@@ -7,11 +7,10 @@ import { MessageV2 } from "../session/message-v2"
 import { Identifier } from "../id/id"
 import { Agent } from "../agent/agent"
 
-export const TaskTool = Tool.define(async () => {
+export const TaskTool = Tool.define("task", async () => {
   const agents = await Agent.list()
   const description = DESCRIPTION.replace("{agents}", agents.map((a) => `- ${a.name}: ${a.description}`).join("\n"))
   return {
-    id: "task",
     description,
     parameters: z.object({
       description: z.string().describe("A short (3-5 words) description of the task"),
@@ -53,7 +52,10 @@ export const TaskTool = Tool.define(async () => {
         providerID: model.providerID,
         mode: msg.mode,
         system: agent.prompt,
-        tools: agent.tools,
+        tools: {
+          ...agent.tools,
+          task: false,
+        },
         parts: [
           {
             id: Identifier.ascending("part"),

+ 2 - 4
packages/opencode/src/tool/todo.ts

@@ -18,8 +18,7 @@ const state = App.state("todo-tool", () => {
   return todos
 })
 
-export const TodoWriteTool = Tool.define({
-  id: "todowrite",
+export const TodoWriteTool = Tool.define("todowrite", {
   description: DESCRIPTION_WRITE,
   parameters: z.object({
     todos: z.array(TodoInfo).describe("The updated todo list"),
@@ -37,8 +36,7 @@ export const TodoWriteTool = Tool.define({
   },
 })
 
-export const TodoReadTool = Tool.define({
-  id: "todoread",
+export const TodoReadTool = Tool.define("todoread", {
   description: "Use this tool to read your todo list",
   parameters: z.object({}),
   async execute(_params, opts) {

+ 21 - 12
packages/opencode/src/tool/tool.ts

@@ -12,21 +12,30 @@ export namespace Tool {
   }
   export interface Info<Parameters extends StandardSchemaV1 = StandardSchemaV1, M extends Metadata = Metadata> {
     id: string
-    description: string
-    parameters: Parameters
-    execute(
-      args: StandardSchemaV1.InferOutput<Parameters>,
-      ctx: Context,
-    ): Promise<{
-      title: string
-      metadata: M
-      output: string
+    init: () => Promise<{
+      description: string
+      parameters: Parameters
+      execute(
+        args: StandardSchemaV1.InferOutput<Parameters>,
+        ctx: Context,
+      ): Promise<{
+        title: string
+        metadata: M
+        output: string
+      }>
     }>
   }
 
   export function define<Parameters extends StandardSchemaV1, Result extends Metadata>(
-    input: Info<Parameters, Result> | (() => Promise<Info<Parameters, Result>>),
-  ): () => Promise<Info<Parameters, Result>> {
-    return input instanceof Function ? input : async () => input
+    id: string,
+    init: Info<Parameters, Result>["init"] | Awaited<ReturnType<Info<Parameters, Result>["init"]>>,
+  ): Info<Parameters, Result> {
+    return {
+      id,
+      init: async () => {
+        if (init instanceof Function) return init()
+        return init
+      },
+    }
   }
 }

+ 1 - 2
packages/opencode/src/tool/webfetch.ts

@@ -7,8 +7,7 @@ const MAX_RESPONSE_SIZE = 5 * 1024 * 1024 // 5MB
 const DEFAULT_TIMEOUT = 30 * 1000 // 30 seconds
 const MAX_TIMEOUT = 120 * 1000 // 2 minutes
 
-export const WebFetchTool = Tool.define({
-  id: "webfetch",
+export const WebFetchTool = Tool.define("webfetch", {
   description: DESCRIPTION,
   parameters: z.object({
     url: z.string().describe("The URL to fetch content from"),

+ 1 - 2
packages/opencode/src/tool/write.ts

@@ -9,8 +9,7 @@ import { Bus } from "../bus"
 import { File } from "../file"
 import { FileTime } from "../file/time"
 
-export const WriteTool = Tool.define({
-  id: "write",
+export const WriteTool = Tool.define("write", {
   description: DESCRIPTION,
   parameters: z.object({
     filePath: z.string().describe("The absolute path to the file to write (must be absolute, not relative)"),

+ 2 - 2
packages/opencode/test/tool/tool.test.ts

@@ -9,8 +9,8 @@ const ctx = {
   abort: AbortSignal.any([]),
   metadata: () => {},
 }
-const glob = await GlobTool()
-const list = await ListTool()
+const glob = await GlobTool.init()
+const list = await ListTool.init()
 
 describe("tool.glob", () => {
   test("truncate", async () => {