|
|
@@ -17,7 +17,6 @@ export const BatchTool = Tool.define("batch", async () => {
|
|
|
}),
|
|
|
)
|
|
|
.min(1, "Provide at least one tool call")
|
|
|
- .max(10, "Too many tools in batch. Maximum allowed is 10.")
|
|
|
.describe("Array of tool calls to execute in parallel"),
|
|
|
}),
|
|
|
formatValidationError(error) {
|
|
|
@@ -34,34 +33,16 @@ export const BatchTool = Tool.define("batch", async () => {
|
|
|
const { Session } = await import("../session")
|
|
|
const { Identifier } = await import("../id/id")
|
|
|
|
|
|
- const toolCalls = params.tool_calls
|
|
|
+ const toolCalls = params.tool_calls.slice(0, 10)
|
|
|
+ const discardedCalls = params.tool_calls.slice(10)
|
|
|
|
|
|
const { ToolRegistry } = await import("./registry")
|
|
|
const availableTools = await ToolRegistry.tools("", "")
|
|
|
const toolMap = new Map(availableTools.map((t) => [t.id, t]))
|
|
|
|
|
|
- const partIDs = new Map<(typeof toolCalls)[0], string>()
|
|
|
- for (const call of toolCalls) {
|
|
|
- const partID = Identifier.ascending("part")
|
|
|
- partIDs.set(call, partID)
|
|
|
- Session.updatePart({
|
|
|
- id: partID,
|
|
|
- messageID: ctx.messageID,
|
|
|
- sessionID: ctx.sessionID,
|
|
|
- type: "tool",
|
|
|
- tool: call.tool,
|
|
|
- callID: partID,
|
|
|
- state: {
|
|
|
- status: "pending",
|
|
|
- input: call.parameters,
|
|
|
- raw: JSON.stringify(call),
|
|
|
- },
|
|
|
- })
|
|
|
- }
|
|
|
-
|
|
|
const executeCall = async (call: (typeof toolCalls)[0]) => {
|
|
|
const callStartTime = Date.now()
|
|
|
- const partID = partIDs.get(call)!
|
|
|
+ const partID = Identifier.ascending("part")
|
|
|
|
|
|
try {
|
|
|
if (DISALLOWED.has(call.tool)) {
|
|
|
@@ -77,6 +58,22 @@ export const BatchTool = Tool.define("batch", async () => {
|
|
|
}
|
|
|
const validatedParams = tool.parameters.parse(call.parameters)
|
|
|
|
|
|
+ await Session.updatePart({
|
|
|
+ id: partID,
|
|
|
+ messageID: ctx.messageID,
|
|
|
+ sessionID: ctx.sessionID,
|
|
|
+ type: "tool",
|
|
|
+ tool: call.tool,
|
|
|
+ callID: partID,
|
|
|
+ state: {
|
|
|
+ status: "running",
|
|
|
+ input: call.parameters,
|
|
|
+ time: {
|
|
|
+ start: callStartTime,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
const result = await tool.execute(validatedParams, { ...ctx, callID: partID })
|
|
|
|
|
|
await Session.updatePart({
|
|
|
@@ -126,31 +123,48 @@ export const BatchTool = Tool.define("batch", async () => {
|
|
|
|
|
|
const results = await Promise.all(toolCalls.map((call) => executeCall(call)))
|
|
|
|
|
|
- const successfulCalls = results.filter((r) => r.success).length
|
|
|
- const failedCalls = toolCalls.length - successfulCalls
|
|
|
+ // Add discarded calls as errors
|
|
|
+ const now = Date.now()
|
|
|
+ for (const call of discardedCalls) {
|
|
|
+ const partID = Identifier.ascending("part")
|
|
|
+ await Session.updatePart({
|
|
|
+ id: partID,
|
|
|
+ messageID: ctx.messageID,
|
|
|
+ sessionID: ctx.sessionID,
|
|
|
+ type: "tool",
|
|
|
+ tool: call.tool,
|
|
|
+ callID: partID,
|
|
|
+ state: {
|
|
|
+ status: "error",
|
|
|
+ input: call.parameters,
|
|
|
+ error: "Maximum of 10 tools allowed in batch",
|
|
|
+ time: { start: now, end: now },
|
|
|
+ },
|
|
|
+ })
|
|
|
+ results.push({
|
|
|
+ success: false as const,
|
|
|
+ tool: call.tool,
|
|
|
+ error: new Error("Maximum of 10 tools allowed in batch"),
|
|
|
+ })
|
|
|
+ }
|
|
|
|
|
|
- const outputParts = results.map((r) => {
|
|
|
- if (r.success) {
|
|
|
- return `<tool_result name="${r.tool}">\n${r.result.output}\n</tool_result>`
|
|
|
- }
|
|
|
- const errorMessage = r.error instanceof Error ? r.error.message : String(r.error)
|
|
|
- return `<tool_result name="${r.tool}">\nError: ${errorMessage}\n</tool_result>`
|
|
|
- })
|
|
|
+ const successfulCalls = results.filter((r) => r.success).length
|
|
|
+ const failedCalls = results.length - successfulCalls
|
|
|
|
|
|
const outputMessage =
|
|
|
failedCalls > 0
|
|
|
- ? `Executed ${successfulCalls}/${toolCalls.length} tools successfully. ${failedCalls} failed.\n\n${outputParts.join("\n\n")}`
|
|
|
- : `All ${successfulCalls} tools executed successfully.\n\n${outputParts.join("\n\n")}\n\nKeep using the batch tool for optimal performance in your next response!`
|
|
|
+ ? `Executed ${successfulCalls}/${results.length} tools successfully. ${failedCalls} failed.`
|
|
|
+ : `All ${successfulCalls} tools executed successfully.\n\nKeep using the batch tool for optimal performance in your next response!`
|
|
|
|
|
|
return {
|
|
|
- title: `Batch execution (${successfulCalls}/${toolCalls.length} successful)`,
|
|
|
+ title: `Batch execution (${successfulCalls}/${results.length} successful)`,
|
|
|
output: outputMessage,
|
|
|
attachments: results.filter((result) => result.success).flatMap((r) => r.result.attachments ?? []),
|
|
|
metadata: {
|
|
|
- totalCalls: toolCalls.length,
|
|
|
+ totalCalls: results.length,
|
|
|
successful: successfulCalls,
|
|
|
failed: failedCalls,
|
|
|
- tools: toolCalls.map((c) => c.tool),
|
|
|
+ tools: params.tool_calls.map((c) => c.tool),
|
|
|
details: results.map((r) => ({ tool: r.tool, success: r.success })),
|
|
|
},
|
|
|
}
|