registry.ts 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import z from "zod"
  2. import { BashTool } from "./bash"
  3. import { EditTool } from "./edit"
  4. import { GlobTool } from "./glob"
  5. import { GrepTool } from "./grep"
  6. import { ListTool } from "./ls"
  7. import { PatchTool } from "./patch"
  8. import { ReadTool } from "./read"
  9. import { TaskTool } from "./task"
  10. import { TodoWriteTool, TodoReadTool } from "./todo"
  11. import { WebFetchTool } from "./webfetch"
  12. import { WriteTool } from "./write"
  export namespace ToolRegistry {
    /** Every registered tool, in the order `ids()` and `tools()` report them. */
    const ALL = [
      BashTool,
      EditTool,
      WebFetchTool,
      GlobTool,
      GrepTool,
      ListTool,
      PatchTool,
      ReadTool,
      WriteTool,
      TodoWriteTool,
      TodoReadTool,
      TaskTool,
    ]

    /** String checks kept when sanitizing a schema for Gemini; every other check is dropped. */
    const GEMINI_STRING_CHECKS = new Set<string>([
      "min",
      "max",
      "length",
      "regex",
      "startsWith",
      "endsWith",
      "includes",
      "trim",
    ])

    /** Ids of every registered tool, in registration order. */
    export function ids() {
      return ALL.map((t) => t.id)
    }

    /**
     * Initializes every registered tool and adapts its parameter schema to the provider.
     *
     * All `init()` calls run concurrently; each result is spread next to the tool's `id`.
     * For `openai` and `azure`, optional object properties in `parameters` become nullable;
     * for `google`, `parameters` is rewritten by `sanitizeGeminiParameters`. Any other
     * provider gets the initialized tools unchanged.
     *
     * @param providerID provider whose schema dialect to target
     * @param _modelID currently unused; kept for interface compatibility
     */
    export async function tools(providerID: string, _modelID: string) {
      const result = await Promise.all(
        ALL.map(async (t) => ({
          id: t.id,
          ...(await t.init()),
        })),
      )
      const transform = parameterTransform(providerID)
      if (!transform) {
        return result
      }
      return result.map((t) => ({
        ...t,
        parameters: transform(t.parameters),
      }))
    }

    /**
     * Per-model tool overrides: every returned key maps to `false` to disable it for `modelID`.
     * An empty object means no overrides. Keys presumably match tool ids — confirm against callers.
     */
    export function enabled(_providerID: string, modelID: string): Record<string, boolean> {
      const lowered = modelID.toLowerCase()
      if (lowered.includes("claude")) {
        return { patch: false }
      }
      // NOTE(review): only the "claude"/"qwen" checks are case-insensitive; "gpt-", "o1" and "o3"
      // are matched against the raw id, and "o1"/"o3" are loose substring matches.
      const disableTodos =
        lowered.includes("qwen") || ["gpt-", "o1", "o3"].some((marker) => modelID.includes(marker))
      if (disableTodos) {
        return {
          patch: false,
          todowrite: false,
          todoread: false,
        }
      }
      return {}
    }

    /** Picks the parameter-schema rewrite a provider needs, or `undefined` when none applies. */
    function parameterTransform(providerID: string): ((schema: z.ZodTypeAny) => z.ZodTypeAny) | undefined {
      switch (providerID) {
        case "openai":
        case "azure":
          return optionalToNullable
        case "google":
          return (schema) => sanitizeGeminiParameters(schema)
        default:
          return undefined
      }
    }

    /**
     * Re-attaches `source`'s description to a schema rebuilt from it. Constructors such as
     * `z.object(shape)` or `z.array(el)` create nodes without a description, so without this the
     * model would lose the parameter docs attached via `.describe()`.
     */
    function carryDescription<T extends z.ZodTypeAny>(source: z.ZodTypeAny, target: T): T {
      return source.description === undefined ? target : target.describe(source.description)
    }

    /**
     * Rewrites a tool parameter schema into the subset accepted for Gemini.
     *
     * - A `default` wrapping a union is dropped (the sanitized union is returned); other
     *   defaults are kept.
     * - Only the string checks in `GEMINI_STRING_CHECKS` survive.
     * - Objects, arrays, unions, optionals and nullables are rebuilt recursively; descriptions
     *   are preserved. Any other schema kind is returned as-is.
     *
     * @param schema schema to sanitize
     * @param cache sanitized result per original node, so a sub-schema instance shared by several
     *   parents is sanitized everywhere it appears, and a re-entrant visit terminates
     */
    function sanitizeGeminiParameters(
      schema: z.ZodTypeAny,
      cache: Map<z.ZodTypeAny, z.ZodTypeAny> = new Map(),
    ): z.ZodTypeAny {
      if (!schema) {
        return schema
      }
      const cached = cache.get(schema)
      if (cached) {
        return cached
      }
      // Provisional entry: reaching this node again while it is still being rebuilt returns the
      // original instead of recursing forever. It is replaced by the real result below.
      cache.set(schema, schema)
      const sanitized = sanitizeGeminiNode(schema, cache)
      cache.set(schema, sanitized)
      return sanitized
    }

    /** Rebuilds a single node for `sanitizeGeminiParameters`, recursing through the shared cache. */
    function sanitizeGeminiNode(schema: z.ZodTypeAny, cache: Map<z.ZodTypeAny, z.ZodTypeAny>): z.ZodTypeAny {
      if (schema instanceof z.ZodDefault) {
        const inner = schema.removeDefault()
        if (inner instanceof z.ZodUnion) {
          // `default` on a union (`anyOf`) is not accepted for Gemini: drop the default and
          // keep only the sanitized union.
          return carryDescription(schema, sanitizeGeminiParameters(inner, cache))
        }
        // Hand the default *factory* through rather than calling it once here, so each parse
        // still gets a fresh default value instead of one shared (possibly mutable) instance.
        return carryDescription(
          schema,
          sanitizeGeminiParameters(inner, cache).default(schema._def.defaultValue),
        )
      }
      if (schema instanceof z.ZodOptional) {
        return carryDescription(schema, sanitizeGeminiParameters(schema.unwrap(), cache).optional())
      }
      if (schema instanceof z.ZodNullable) {
        return carryDescription(schema, sanitizeGeminiParameters(schema.unwrap(), cache).nullable())
      }
      if (schema instanceof z.ZodObject) {
        const shape: Record<string, z.ZodTypeAny> = {}
        for (const [key, value] of Object.entries(schema.shape as z.ZodRawShape)) {
          shape[key] = sanitizeGeminiParameters(value, cache)
        }
        return carryDescription(schema, z.object(shape))
      }
      if (schema instanceof z.ZodArray) {
        return carryDescription(schema, z.array(sanitizeGeminiParameters(schema.element, cache)))
      }
      if (schema instanceof z.ZodUnion) {
        // Corresponds to `anyOf` in JSON Schema; sanitize each member.
        const options = (schema.options as z.ZodTypeAny[]).map((option) => sanitizeGeminiParameters(option, cache))
        return carryDescription(schema, z.union(options as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]]))
      }
      if (schema instanceof z.ZodString) {
        // Rebuild from the original definition (keeps description and other def fields);
        // only the list of checks is filtered.
        return new z.ZodString({
          ...schema._def,
          checks: schema._def.checks.filter((check) => GEMINI_STRING_CHECKS.has(check.kind)),
        })
      }
      return schema
    }

    /**
     * Turns every optional object property into a required-but-nullable one.
     *
     * Recurses through objects, arrays, unions and nullables — including into the type wrapped by
     * an optional property, so optional fields nested inside an optional object are converted too.
     * Descriptions of rebuilt nodes are preserved. Any other schema kind is returned as-is.
     */
    function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
      if (schema instanceof z.ZodObject) {
        const shape: Record<string, z.ZodTypeAny> = {}
        for (const [key, value] of Object.entries(schema.shape as z.ZodRawShape)) {
          if (value instanceof z.ZodOptional) {
            shape[key] = carryDescription(value, optionalToNullable(value.unwrap()).nullable())
          } else {
            shape[key] = optionalToNullable(value)
          }
        }
        return carryDescription(schema, z.object(shape))
      }
      if (schema instanceof z.ZodNullable) {
        return carryDescription(schema, optionalToNullable(schema.unwrap()).nullable())
      }
      if (schema instanceof z.ZodArray) {
        return carryDescription(schema, z.array(optionalToNullable(schema.element)))
      }
      if (schema instanceof z.ZodUnion) {
        const options = (schema.options as z.ZodTypeAny[]).map((option) => optionalToNullable(option))
        return carryDescription(schema, z.union(options as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]]))
      }
      return schema
    }
  }