|
|
@@ -5,22 +5,11 @@ import { mergeDeep, sortBy } from "remeda"
|
|
|
import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
|
|
|
import { Log } from "../util/log"
|
|
|
import { BunProc } from "../bun"
|
|
|
-import { BashTool } from "../tool/bash"
|
|
|
-import { EditTool } from "../tool/edit"
|
|
|
-import { WebFetchTool } from "../tool/webfetch"
|
|
|
-import { GlobTool } from "../tool/glob"
|
|
|
-import { GrepTool } from "../tool/grep"
|
|
|
-import { ListTool } from "../tool/ls"
|
|
|
-import { PatchTool } from "../tool/patch"
|
|
|
-import { ReadTool } from "../tool/read"
|
|
|
-import { WriteTool } from "../tool/write"
|
|
|
-import { TodoReadTool, TodoWriteTool } from "../tool/todo"
|
|
|
import { AuthAnthropic } from "../auth/anthropic"
|
|
|
import { AuthCopilot } from "../auth/copilot"
|
|
|
import { ModelsDev } from "./models"
|
|
|
import { NamedError } from "../util/error"
|
|
|
import { Auth } from "../auth"
|
|
|
-import { TaskTool } from "../tool/task"
|
|
|
|
|
|
export namespace Provider {
|
|
|
const log = Log.create({ service: "provider" })
|
|
|
@@ -468,137 +457,6 @@ export namespace Provider {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- const TOOLS = [
|
|
|
- BashTool,
|
|
|
- EditTool,
|
|
|
- WebFetchTool,
|
|
|
- GlobTool,
|
|
|
- GrepTool,
|
|
|
- ListTool,
|
|
|
- // LspDiagnosticTool,
|
|
|
- // LspHoverTool,
|
|
|
- PatchTool,
|
|
|
- ReadTool,
|
|
|
- // MultiEditTool,
|
|
|
- WriteTool,
|
|
|
- TodoWriteTool,
|
|
|
- TodoReadTool,
|
|
|
- TaskTool,
|
|
|
- ]
|
|
|
-
|
|
|
- export async function tools(providerID: string) {
|
|
|
- const result = await Promise.all(TOOLS.map((t) => t()))
|
|
|
- switch (providerID) {
|
|
|
- case "anthropic":
|
|
|
- return result.filter((t) => t.id !== "patch")
|
|
|
- case "openai":
|
|
|
- return result.map((t) => ({
|
|
|
- ...t,
|
|
|
- parameters: optionalToNullable(t.parameters),
|
|
|
- }))
|
|
|
- case "azure":
|
|
|
- return result.map((t) => ({
|
|
|
- ...t,
|
|
|
- parameters: optionalToNullable(t.parameters),
|
|
|
- }))
|
|
|
- case "google":
|
|
|
- return result.map((t) => ({
|
|
|
- ...t,
|
|
|
- parameters: sanitizeGeminiParameters(t.parameters),
|
|
|
- }))
|
|
|
- default:
|
|
|
- return result
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- function sanitizeGeminiParameters(schema: z.ZodTypeAny, visited = new Set()): z.ZodTypeAny {
|
|
|
- if (!schema || visited.has(schema)) {
|
|
|
- return schema
|
|
|
- }
|
|
|
- visited.add(schema)
|
|
|
-
|
|
|
- if (schema instanceof z.ZodDefault) {
|
|
|
- const innerSchema = schema.removeDefault()
|
|
|
- // Handle Gemini's incompatibility with `default` on `anyOf` (unions).
|
|
|
- if (innerSchema instanceof z.ZodUnion) {
|
|
|
- // The schema was `z.union(...).default(...)`, which is not allowed.
|
|
|
- // We strip the default and return the sanitized union.
|
|
|
- return sanitizeGeminiParameters(innerSchema, visited)
|
|
|
- }
|
|
|
- // Otherwise, the default is on a regular type, which is allowed.
|
|
|
- // We recurse on the inner type and then re-apply the default.
|
|
|
- return sanitizeGeminiParameters(innerSchema, visited).default(schema._def.defaultValue())
|
|
|
- }
|
|
|
-
|
|
|
- if (schema instanceof z.ZodOptional) {
|
|
|
- return z.optional(sanitizeGeminiParameters(schema.unwrap(), visited))
|
|
|
- }
|
|
|
-
|
|
|
- if (schema instanceof z.ZodObject) {
|
|
|
- const newShape: Record<string, z.ZodTypeAny> = {}
|
|
|
- for (const [key, value] of Object.entries(schema.shape)) {
|
|
|
- newShape[key] = sanitizeGeminiParameters(value as z.ZodTypeAny, visited)
|
|
|
- }
|
|
|
- return z.object(newShape)
|
|
|
- }
|
|
|
-
|
|
|
- if (schema instanceof z.ZodArray) {
|
|
|
- return z.array(sanitizeGeminiParameters(schema.element, visited))
|
|
|
- }
|
|
|
-
|
|
|
- if (schema instanceof z.ZodUnion) {
|
|
|
- // This schema corresponds to `anyOf` in JSON Schema.
|
|
|
- // We recursively sanitize each option in the union.
|
|
|
- const sanitizedOptions = schema.options.map((option: z.ZodTypeAny) => sanitizeGeminiParameters(option, visited))
|
|
|
- return z.union(sanitizedOptions as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
|
|
|
- }
|
|
|
-
|
|
|
- if (schema instanceof z.ZodString) {
|
|
|
- const newSchema = z.string({ description: schema.description })
|
|
|
- const safeChecks = ["min", "max", "length", "regex", "startsWith", "endsWith", "includes", "trim"]
|
|
|
- // rome-ignore lint/suspicious/noExplicitAny: <explanation>
|
|
|
- ;(newSchema._def as any).checks = (schema._def as z.ZodStringDef).checks.filter((check) =>
|
|
|
- safeChecks.includes(check.kind),
|
|
|
- )
|
|
|
- return newSchema
|
|
|
- }
|
|
|
-
|
|
|
- return schema
|
|
|
- }
|
|
|
- function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
|
|
|
- if (schema instanceof z.ZodObject) {
|
|
|
- const shape = schema.shape
|
|
|
- const newShape: Record<string, z.ZodTypeAny> = {}
|
|
|
-
|
|
|
- for (const [key, value] of Object.entries(shape)) {
|
|
|
- const zodValue = value as z.ZodTypeAny
|
|
|
- if (zodValue instanceof z.ZodOptional) {
|
|
|
- newShape[key] = zodValue.unwrap().nullable()
|
|
|
- } else {
|
|
|
- newShape[key] = optionalToNullable(zodValue)
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- return z.object(newShape)
|
|
|
- }
|
|
|
-
|
|
|
- if (schema instanceof z.ZodArray) {
|
|
|
- return z.array(optionalToNullable(schema.element))
|
|
|
- }
|
|
|
-
|
|
|
- if (schema instanceof z.ZodUnion) {
|
|
|
- return z.union(
|
|
|
- schema.options.map((option: z.ZodTypeAny) => optionalToNullable(option)) as [
|
|
|
- z.ZodTypeAny,
|
|
|
- z.ZodTypeAny,
|
|
|
- ...z.ZodTypeAny[],
|
|
|
- ],
|
|
|
- )
|
|
|
- }
|
|
|
-
|
|
|
- return schema
|
|
|
- }
|
|
|
-
|
|
|
export const ModelNotFoundError = NamedError.create(
|
|
|
"ProviderModelNotFoundError",
|
|
|
z.object({
|