adamdottv 9 месяцев назад
Родитель
Сommit
a65e593ab4

+ 43 - 21
internal/llm/agent/tools.go

@@ -21,30 +21,41 @@ func PrimaryAgentTools(
 	ctx := context.Background()
 	mcpTools := GetMcpTools(ctx, permissions)
 
-	return append(
-		[]tools.BaseTool{
-			tools.NewBashTool(permissions),
-			tools.NewEditTool(lspClients, permissions, history),
-			tools.NewFetchTool(permissions),
-			tools.NewGlobTool(),
-			tools.NewGrepTool(),
-			tools.NewLsTool(),
-			tools.NewViewTool(lspClients),
-			tools.NewPatchTool(lspClients, permissions, history),
-			tools.NewWriteTool(lspClients, permissions, history),
-			tools.NewDiagnosticsTool(lspClients),
-			tools.NewDefinitionTool(lspClients),
-			tools.NewReferencesTool(lspClients),
-			tools.NewDocSymbolsTool(lspClients),
-			tools.NewWorkspaceSymbolsTool(lspClients),
-			tools.NewCodeActionTool(lspClients),
-			NewAgentTool(sessions, messages, lspClients),
-		}, mcpTools...,
-	)
+	// Create the list of tools
+	toolsList := []tools.BaseTool{
+		tools.NewBashTool(permissions),
+		tools.NewEditTool(lspClients, permissions, history),
+		tools.NewFetchTool(permissions),
+		tools.NewGlobTool(),
+		tools.NewGrepTool(),
+		tools.NewLsTool(),
+		tools.NewViewTool(lspClients),
+		tools.NewPatchTool(lspClients, permissions, history),
+		tools.NewWriteTool(lspClients, permissions, history),
+		tools.NewDiagnosticsTool(lspClients),
+		tools.NewDefinitionTool(lspClients),
+		tools.NewReferencesTool(lspClients),
+		tools.NewDocSymbolsTool(lspClients),
+		tools.NewWorkspaceSymbolsTool(lspClients),
+		tools.NewCodeActionTool(lspClients),
+		NewAgentTool(sessions, messages, lspClients),
+	}
+
+	// Create a map of tools for the batch tool
+	toolsMap := make(map[string]tools.BaseTool)
+	for _, tool := range toolsList {
+		toolsMap[tool.Info().Name] = tool
+	}
+
+	// Add the batch tool with access to all other tools
+	toolsList = append(toolsList, tools.NewBatchTool(toolsMap))
+
+	return append(toolsList, mcpTools...)
 }
 
 func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
-	return []tools.BaseTool{
+	// Create the list of tools
+	toolsList := []tools.BaseTool{
 		tools.NewGlobTool(),
 		tools.NewGrepTool(),
 		tools.NewLsTool(),
@@ -54,4 +65,15 @@ func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
 		tools.NewDocSymbolsTool(lspClients),
 		tools.NewWorkspaceSymbolsTool(lspClients),
 	}
+
+	// Create a map of tools for the batch tool
+	toolsMap := make(map[string]tools.BaseTool)
+	for _, tool := range toolsList {
+		toolsMap[tool.Info().Name] = tool
+	}
+
+	// Add the batch tool with access to all other tools
+	toolsList = append(toolsList, tools.NewBatchTool(toolsMap))
+
+	return toolsList
 }

+ 191 - 0
internal/llm/tools/batch.go

@@ -0,0 +1,191 @@
+package tools
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"strings"
+	"sync"
+)
+
+type BatchToolCall struct {
+	Name  string          `json:"name"`
+	Input json.RawMessage `json:"input"`
+}
+
+type BatchParams struct {
+	Calls []BatchToolCall `json:"calls"`
+}
+
+type BatchToolResult struct {
+	ToolName  string          `json:"tool_name"`
+	ToolInput json.RawMessage `json:"tool_input"`
+	Result    json.RawMessage `json:"result"`
+	Error     string          `json:"error,omitempty"`
+	// Added for better formatting and separation between results
+	Separator string          `json:"separator,omitempty"`
+}
+
+type BatchResult struct {
+	Results []BatchToolResult `json:"results"`
+}
+
+type batchTool struct {
+	tools map[string]BaseTool
+}
+
+const (
+	BatchToolName        = "batch"
+	BatchToolDescription = `Executes multiple tool calls in parallel and returns their results.
+
+WHEN TO USE THIS TOOL:
+- Use when you need to run multiple independent tool calls at once
+- Helpful for improving performance by parallelizing operations
+- Great for gathering information from multiple sources simultaneously
+
+HOW TO USE:
+- Provide an array of tool calls, each with a name and input
+- Each tool call will be executed in parallel
+- Results are returned in the same order as the input calls
+
+FEATURES:
+- Runs tool calls concurrently for better performance
+- Returns both results and errors for each call
+- Maintains the order of results to match input calls
+
+LIMITATIONS:
+- All tools must be available in the current context
+- Complex error handling may be required for some use cases
+- Not suitable for tool calls that depend on each other's results
+
+TIPS:
+- Use for independent operations like multiple file reads or searches
+- Great for batch operations like searching multiple directories
+- Combine with other tools for more complex workflows`
+)
+
+func NewBatchTool(tools map[string]BaseTool) BaseTool {
+	return &batchTool{
+		tools: tools,
+	}
+}
+
+func (b *batchTool) Info() ToolInfo {
+	return ToolInfo{
+		Name:        BatchToolName,
+		Description: BatchToolDescription,
+		Parameters: map[string]any{
+			"calls": map[string]any{
+				"type":        "array",
+				"description": "Array of tool calls to execute in parallel",
+				"items": map[string]any{
+					"type": "object",
+					"properties": map[string]any{
+						"name": map[string]any{
+							"type":        "string",
+							"description": "Name of the tool to call",
+						},
+						"input": map[string]any{
+							"type":        "object",
+							"description": "Input parameters for the tool",
+						},
+					},
+					"required": []string{"name", "input"},
+				},
+			},
+		},
+		Required: []string{"calls"},
+	}
+}
+
+func (b *batchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
+	var params BatchParams
+	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
+		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
+	}
+
+	if len(params.Calls) == 0 {
+		return NewTextErrorResponse("no tool calls provided"), nil
+	}
+
+	var wg sync.WaitGroup
+	results := make([]BatchToolResult, len(params.Calls))
+
+	for i, toolCall := range params.Calls {
+		wg.Add(1)
+		go func(index int, tc BatchToolCall) {
+			defer wg.Done()
+
+			// Create separator for better visual distinction between results
+			separator := ""
+			if index > 0 {
+				separator = fmt.Sprintf("\n%s\n", strings.Repeat("=", 80))
+			}
+
+			result := BatchToolResult{
+				ToolName:  tc.Name,
+				ToolInput: tc.Input,
+				Separator: separator,
+			}
+
+			tool, ok := b.tools[tc.Name]
+			if !ok {
+				result.Error = fmt.Sprintf("tool not found: %s", tc.Name)
+				results[index] = result
+				return
+			}
+
+			// Create a proper ToolCall object
+			callObj := ToolCall{
+				ID:    fmt.Sprintf("batch-%d", index),
+				Name:  tc.Name,
+				Input: string(tc.Input),
+			}
+
+			response, err := tool.Run(ctx, callObj)
+			if err != nil {
+				result.Error = fmt.Sprintf("error executing tool %s: %s", tc.Name, err)
+				results[index] = result
+				return
+			}
+
+			// Standardize metadata format if present
+			if response.Metadata != "" {
+				var metadata map[string]interface{}
+				if err := json.Unmarshal([]byte(response.Metadata), &metadata); err == nil {
+					// Add tool name to metadata for better context
+					metadata["tool"] = tc.Name
+					
+					// Re-marshal with consistent formatting
+					if metadataBytes, err := json.MarshalIndent(metadata, "", "  "); err == nil {
+						response.Metadata = string(metadataBytes)
+					}
+				}
+			}
+
+			// Convert the response to JSON
+			responseJSON, err := json.Marshal(response)
+			if err != nil {
+				result.Error = fmt.Sprintf("error marshaling response: %s", err)
+				results[index] = result
+				return
+			}
+
+			result.Result = responseJSON
+			results[index] = result
+		}(i, toolCall)
+	}
+
+	wg.Wait()
+
+	batchResult := BatchResult{
+		Results: results,
+	}
+
+	resultJSON, err := json.Marshal(batchResult)
+	if err != nil {
+		return NewTextErrorResponse(fmt.Sprintf("error marshaling batch result: %s", err)), nil
+	}
+
+	return NewTextResponse(string(resultJSON)), nil
+}

+ 224 - 0
internal/llm/tools/batch_test.go

@@ -0,0 +1,224 @@
+package tools
+
+import (
+	"context"
+	"encoding/json"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// MockTool is a simple tool implementation for testing
+type MockTool struct {
+	name        string
+	description string
+	response    ToolResponse
+	err         error
+}
+
+func (m *MockTool) Info() ToolInfo {
+	return ToolInfo{
+		Name:        m.name,
+		Description: m.description,
+		Parameters:  map[string]any{},
+		Required:    []string{},
+	}
+}
+
+func (m *MockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
+	return m.response, m.err
+}
+
+func TestBatchTool(t *testing.T) {
+	t.Parallel()
+
+	t.Run("successful batch execution", func(t *testing.T) {
+		t.Parallel()
+
+		// Create mock tools
+		mockTools := map[string]BaseTool{
+			"tool1": &MockTool{
+				name:        "tool1",
+				description: "Mock Tool 1",
+				response:    NewTextResponse("Tool 1 Response"),
+				err:         nil,
+			},
+			"tool2": &MockTool{
+				name:        "tool2",
+				description: "Mock Tool 2",
+				response:    NewTextResponse("Tool 2 Response"),
+				err:         nil,
+			},
+		}
+
+		// Create batch tool
+		batchTool := NewBatchTool(mockTools)
+
+		// Create batch call
+		input := `{
+			"calls": [
+				{
+					"name": "tool1",
+					"input": {}
+				},
+				{
+					"name": "tool2",
+					"input": {}
+				}
+			]
+		}`
+
+		call := ToolCall{
+			ID:    "test-batch",
+			Name:  "batch",
+			Input: input,
+		}
+
+		// Execute batch
+		response, err := batchTool.Run(context.Background(), call)
+
+		// Verify results
+		assert.NoError(t, err)
+		assert.Equal(t, ToolResponseTypeText, response.Type)
+		assert.False(t, response.IsError)
+
+		// Parse the response
+		var batchResult BatchResult
+		err = json.Unmarshal([]byte(response.Content), &batchResult)
+		assert.NoError(t, err)
+
+		// Verify batch results
+		assert.Len(t, batchResult.Results, 2)
+		assert.Empty(t, batchResult.Results[0].Error)
+		assert.Empty(t, batchResult.Results[1].Error)
+		assert.Empty(t, batchResult.Results[0].Separator)
+		assert.NotEmpty(t, batchResult.Results[1].Separator)
+
+		// Verify individual results
+		var result1 ToolResponse
+		err = json.Unmarshal(batchResult.Results[0].Result, &result1)
+		assert.NoError(t, err)
+		assert.Equal(t, "Tool 1 Response", result1.Content)
+
+		var result2 ToolResponse
+		err = json.Unmarshal(batchResult.Results[1].Result, &result2)
+		assert.NoError(t, err)
+		assert.Equal(t, "Tool 2 Response", result2.Content)
+	})
+
+	t.Run("tool not found", func(t *testing.T) {
+		t.Parallel()
+
+		// Create mock tools
+		mockTools := map[string]BaseTool{
+			"tool1": &MockTool{
+				name:        "tool1",
+				description: "Mock Tool 1",
+				response:    NewTextResponse("Tool 1 Response"),
+				err:         nil,
+			},
+		}
+
+		// Create batch tool
+		batchTool := NewBatchTool(mockTools)
+
+		// Create batch call with non-existent tool
+		input := `{
+			"calls": [
+				{
+					"name": "tool1",
+					"input": {}
+				},
+				{
+					"name": "nonexistent",
+					"input": {}
+				}
+			]
+		}`
+
+		call := ToolCall{
+			ID:    "test-batch",
+			Name:  "batch",
+			Input: input,
+		}
+
+		// Execute batch
+		response, err := batchTool.Run(context.Background(), call)
+
+		// Verify results
+		assert.NoError(t, err)
+		assert.Equal(t, ToolResponseTypeText, response.Type)
+		assert.False(t, response.IsError)
+
+		// Parse the response
+		var batchResult BatchResult
+		err = json.Unmarshal([]byte(response.Content), &batchResult)
+		assert.NoError(t, err)
+
+		// Verify batch results
+		assert.Len(t, batchResult.Results, 2)
+		assert.Empty(t, batchResult.Results[0].Error)
+		assert.Contains(t, batchResult.Results[1].Error, "tool not found: nonexistent")
+	})
+
+	t.Run("empty calls", func(t *testing.T) {
+		t.Parallel()
+
+		// Create batch tool with empty tools map
+		batchTool := NewBatchTool(map[string]BaseTool{})
+
+		// Create batch call with empty calls
+		input := `{
+			"calls": []
+		}`
+
+		call := ToolCall{
+			ID:    "test-batch",
+			Name:  "batch",
+			Input: input,
+		}
+
+		// Execute batch
+		response, err := batchTool.Run(context.Background(), call)
+
+		// Verify results
+		assert.NoError(t, err)
+		assert.Equal(t, ToolResponseTypeText, response.Type)
+		assert.True(t, response.IsError)
+		assert.Contains(t, response.Content, "no tool calls provided")
+	})
+
+	t.Run("invalid input", func(t *testing.T) {
+		t.Parallel()
+
+		// Create batch tool with empty tools map
+		batchTool := NewBatchTool(map[string]BaseTool{})
+
+		// Create batch call with invalid JSON
+		input := `{
+			"calls": [
+				{
+					"name": "tool1",
+					"input": {
+						"invalid": json
+					}
+				}
+			]
+		}`
+
+		call := ToolCall{
+			ID:    "test-batch",
+			Name:  "batch",
+			Input: input,
+		}
+
+		// Execute batch
+		response, err := batchTool.Run(context.Background(), call)
+
+		// Verify results
+		assert.NoError(t, err)
+		assert.Equal(t, ToolResponseTypeText, response.Type)
+		assert.True(t, response.IsError)
+		assert.Contains(t, response.Content, "error parsing parameters")
+	})
+}

+ 40 - 0
internal/tui/components/chat/message.go

@@ -266,6 +266,8 @@ func toolName(name string) string {
 		return "Write"
 	case tools.PatchToolName:
 		return "Patch"
+	case tools.BatchToolName:
+		return "Batch"
 	}
 	return name
 }
@@ -292,6 +294,8 @@ func getToolAction(name string) string {
 		return "Preparing write..."
 	case tools.PatchToolName:
 		return "Preparing patch..."
+	case tools.BatchToolName:
+		return "Running batch operations..."
 	}
 	return "Working..."
 }
@@ -443,6 +447,10 @@ func renderToolParams(paramWidth int, toolCall message.ToolCall) string {
 		json.Unmarshal([]byte(toolCall.Input), &params)
 		filePath := removeWorkingDirPrefix(params.FilePath)
 		return renderParams(paramWidth, filePath)
+	case tools.BatchToolName:
+		var params tools.BatchParams
+		json.Unmarshal([]byte(toolCall.Input), &params)
+		return renderParams(paramWidth, fmt.Sprintf("%d parallel calls", len(params.Calls)))
 	default:
 		input := strings.ReplaceAll(toolCall.Input, "\n", " ")
 		params = renderParams(paramWidth, input)
@@ -540,6 +548,38 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult,
 			toMarkdown(resultContent, true, width),
 			t.Background(),
 		)
+	case tools.BatchToolName:
+		var batchResult tools.BatchResult
+		if err := json.Unmarshal([]byte(resultContent), &batchResult); err != nil {
+			return baseStyle.Width(width).Foreground(t.Error()).Render(fmt.Sprintf("Error parsing batch result: %s", err))
+		}
+
+		var toolCalls []string
+		for i, result := range batchResult.Results {
+			toolName := toolName(result.ToolName)
+
+			// Format the tool input as a string
+			inputStr := string(result.ToolInput)
+
+			// Format the result
+			var resultStr string
+			if result.Error != "" {
+				resultStr = fmt.Sprintf("Error: %s", result.Error)
+			} else {
+				var toolResponse tools.ToolResponse
+				if err := json.Unmarshal(result.Result, &toolResponse); err != nil {
+					resultStr = "Error parsing tool response"
+				} else {
+					resultStr = truncateHeight(toolResponse.Content, 3)
+				}
+			}
+
+			// Format the tool call
+			toolCall := fmt.Sprintf("%d. %s: %s\n   %s", i+1, toolName, inputStr, resultStr)
+			toolCalls = append(toolCalls, toolCall)
+		}
+
+		return baseStyle.Width(width).Foreground(t.TextMuted()).Render(strings.Join(toolCalls, "\n\n"))
 	default:
 		resultContent = fmt.Sprintf("```text\n%s\n```", resultContent)
 		return styles.ForceReplaceBackgroundWithLipgloss(