batch.ts 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import z from "zod"
  2. import { Tool } from "./tool"
  3. import DESCRIPTION from "./batch.txt"
  4. const DISALLOWED = new Set(["batch"])
  5. const FILTERED_FROM_SUGGESTIONS = new Set(["invalid", "patch", ...DISALLOWED])
  6. export const BatchTool = Tool.define("batch", async () => {
  7. return {
  8. description: DESCRIPTION,
  9. parameters: z.object({
  10. tool_calls: z
  11. .array(
  12. z.object({
  13. tool: z.string().describe("The name of the tool to execute"),
  14. parameters: z.object({}).loose().describe("Parameters for the tool"),
  15. }),
  16. )
  17. .min(1, "Provide at least one tool call")
  18. .describe("Array of tool calls to execute in parallel"),
  19. }),
  20. formatValidationError(error) {
  21. const formattedErrors = error.issues
  22. .map((issue) => {
  23. const path = issue.path.length > 0 ? issue.path.join(".") : "root"
  24. return ` - ${path}: ${issue.message}`
  25. })
  26. .join("\n")
  27. return `Invalid parameters for tool 'batch':\n${formattedErrors}\n\nExpected payload format:\n [{"tool": "tool_name", "parameters": {...}}, {...}]`
  28. },
  29. async execute(params, ctx) {
  30. const { Session } = await import("../session")
  31. const { Identifier } = await import("../id/id")
  32. const toolCalls = params.tool_calls.slice(0, 10)
  33. const discardedCalls = params.tool_calls.slice(10)
  34. const { ToolRegistry } = await import("./registry")
  35. const availableTools = await ToolRegistry.tools("")
  36. const toolMap = new Map(availableTools.map((t) => [t.id, t]))
  37. const executeCall = async (call: (typeof toolCalls)[0]) => {
  38. const callStartTime = Date.now()
  39. const partID = Identifier.ascending("part")
  40. try {
  41. if (DISALLOWED.has(call.tool)) {
  42. throw new Error(
  43. `Tool '${call.tool}' is not allowed in batch. Disallowed tools: ${Array.from(DISALLOWED).join(", ")}`,
  44. )
  45. }
  46. const tool = toolMap.get(call.tool)
  47. if (!tool) {
  48. const availableToolsList = Array.from(toolMap.keys()).filter((name) => !FILTERED_FROM_SUGGESTIONS.has(name))
  49. throw new Error(
  50. `Tool '${call.tool}' not in registry. External tools (MCP, environment) cannot be batched - call them directly. Available tools: ${availableToolsList.join(", ")}`,
  51. )
  52. }
  53. const validatedParams = tool.parameters.parse(call.parameters)
  54. await Session.updatePart({
  55. id: partID,
  56. messageID: ctx.messageID,
  57. sessionID: ctx.sessionID,
  58. type: "tool",
  59. tool: call.tool,
  60. callID: partID,
  61. state: {
  62. status: "running",
  63. input: call.parameters,
  64. time: {
  65. start: callStartTime,
  66. },
  67. },
  68. })
  69. const result = await tool.execute(validatedParams, { ...ctx, callID: partID })
  70. await Session.updatePart({
  71. id: partID,
  72. messageID: ctx.messageID,
  73. sessionID: ctx.sessionID,
  74. type: "tool",
  75. tool: call.tool,
  76. callID: partID,
  77. state: {
  78. status: "completed",
  79. input: call.parameters,
  80. output: result.output,
  81. title: result.title,
  82. metadata: result.metadata,
  83. attachments: result.attachments,
  84. time: {
  85. start: callStartTime,
  86. end: Date.now(),
  87. },
  88. },
  89. })
  90. return { success: true as const, tool: call.tool, result }
  91. } catch (error) {
  92. await Session.updatePart({
  93. id: partID,
  94. messageID: ctx.messageID,
  95. sessionID: ctx.sessionID,
  96. type: "tool",
  97. tool: call.tool,
  98. callID: partID,
  99. state: {
  100. status: "error",
  101. input: call.parameters,
  102. error: error instanceof Error ? error.message : String(error),
  103. time: {
  104. start: callStartTime,
  105. end: Date.now(),
  106. },
  107. },
  108. })
  109. return { success: false as const, tool: call.tool, error }
  110. }
  111. }
  112. const results = await Promise.all(toolCalls.map((call) => executeCall(call)))
  113. // Add discarded calls as errors
  114. const now = Date.now()
  115. for (const call of discardedCalls) {
  116. const partID = Identifier.ascending("part")
  117. await Session.updatePart({
  118. id: partID,
  119. messageID: ctx.messageID,
  120. sessionID: ctx.sessionID,
  121. type: "tool",
  122. tool: call.tool,
  123. callID: partID,
  124. state: {
  125. status: "error",
  126. input: call.parameters,
  127. error: "Maximum of 10 tools allowed in batch",
  128. time: { start: now, end: now },
  129. },
  130. })
  131. results.push({
  132. success: false as const,
  133. tool: call.tool,
  134. error: new Error("Maximum of 10 tools allowed in batch"),
  135. })
  136. }
  137. const successfulCalls = results.filter((r) => r.success).length
  138. const failedCalls = results.length - successfulCalls
  139. const outputMessage =
  140. failedCalls > 0
  141. ? `Executed ${successfulCalls}/${results.length} tools successfully. ${failedCalls} failed.`
  142. : `All ${successfulCalls} tools executed successfully.\n\nKeep using the batch tool for optimal performance in your next response!`
  143. return {
  144. title: `Batch execution (${successfulCalls}/${results.length} successful)`,
  145. output: outputMessage,
  146. attachments: results.filter((result) => result.success).flatMap((r) => r.result.attachments ?? []),
  147. metadata: {
  148. totalCalls: results.length,
  149. successful: successfulCalls,
  150. failed: failedCalls,
  151. tools: params.tool_calls.map((c) => c.tool),
  152. details: results.map((r) => ({ tool: r.tool, success: r.success })),
  153. },
  154. }
  155. },
  156. }
  157. })