Browse Source

chore: refactor agent.go

adamdottv 9 months ago
parent
commit
d941be3f1f
1 changed files with 103 additions and 47 deletions
  1. 103 47
      internal/llm/agent/agent.go

+ 103 - 47
internal/llm/agent/agent.go

@@ -73,7 +73,7 @@ func NewAgent(
 		return nil, err
 		return nil, err
 	}
 	}
 	var titleProvider provider.Provider
 	var titleProvider provider.Provider
-	// Only generate titles for the coder agent
+	// Only generate titles for the primary agent
 	if agentName == config.AgentPrimary {
 	if agentName == config.AgentPrimary {
 		titleProvider, err = createAgentProvider(config.AgentTitle)
 		titleProvider, err = createAgentProvider(config.AgentTitle)
 		if err != nil {
 		if err != nil {
@@ -197,11 +197,11 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac
 	return events, nil
 	return events, nil
 }
 }
 
 
-func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
+func (a *agent) prepareMessageHistory(ctx context.Context, sessionID string) (session.Session, []message.Message, error) {
 	// Get the current session to check for summary
 	// Get the current session to check for summary
 	currentSession, err := a.sessions.Get(ctx, sessionID)
 	currentSession, err := a.sessions.Get(ctx, sessionID)
 	if err != nil {
 	if err != nil {
-		return a.err(fmt.Errorf("failed to get session: %w", err))
+		return currentSession, nil, fmt.Errorf("failed to get session: %w", err)
 	}
 	}
 
 
 	// Fetch messages based on whether a summary exists
 	// Fetch messages based on whether a summary exists
@@ -210,34 +210,16 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
 		// If summary exists, only fetch messages after the summarization timestamp
 		// If summary exists, only fetch messages after the summarization timestamp
 		sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
 		sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
 		if err != nil {
 		if err != nil {
-			return a.err(fmt.Errorf("failed to list messages after summary: %w", err))
+			return currentSession, nil, fmt.Errorf("failed to list messages after summary: %w", err)
 		}
 		}
 	} else {
 	} else {
 		// If no summary, fetch all messages
 		// If no summary, fetch all messages
 		sessionMessages, err = a.messages.List(ctx, sessionID)
 		sessionMessages, err = a.messages.List(ctx, sessionID)
 		if err != nil {
 		if err != nil {
-			return a.err(fmt.Errorf("failed to list messages: %w", err))
+			return currentSession, nil, fmt.Errorf("failed to list messages: %w", err)
 		}
 		}
 	}
 	}
 
 
-	// If this is a new session, start title generation asynchronously
-	if len(sessionMessages) == 0 && currentSession.Summary == "" {
-		go func() {
-			defer logging.RecoverPanic("agent.Run", func() {
-				status.Error("panic while generating title")
-			})
-			titleErr := a.generateTitle(context.Background(), sessionID, content)
-			if titleErr != nil {
-				status.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
-			}
-		}()
-	}
-
-	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
-	if err != nil {
-		return a.err(fmt.Errorf("failed to create user message: %w", err))
-	}
-
 	// Prepare the message history for the LLM
 	// Prepare the message history for the LLM
 	var messages []message.Message
 	var messages []message.Message
 	if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
 	if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
@@ -255,8 +237,40 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
 		messages = sessionMessages
 		messages = sessionMessages
 	}
 	}
 
 
+	return currentSession, messages, nil
+}
+
+func (a *agent) triggerTitleGeneration(sessionID string, content string) {
+	go func() {
+		defer logging.RecoverPanic("agent.Run", func() {
+			status.Error("panic while generating title")
+		})
+		titleErr := a.generateTitle(context.Background(), sessionID, content)
+		if titleErr != nil {
+			status.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
+		}
+	}()
+}
+
+func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
+	// Get message history and session info
+	currentSession, sessionMessages, err := a.prepareMessageHistory(ctx, sessionID)
+	if err != nil {
+		return a.err(err)
+	}
+
+	// If this is a new session, start title generation asynchronously
+	if len(sessionMessages) == 0 && currentSession.Summary == "" {
+		a.triggerTitleGeneration(sessionID, content)
+	}
+
+	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
+	if err != nil {
+		return a.err(fmt.Errorf("failed to create user message: %w", err))
+	}
+
 	// Append the new user message to the conversation history
 	// Append the new user message to the conversation history
-	messages = append(messages, userMsg)
+	messages := append(sessionMessages, userMsg)
 
 
 	for {
 	for {
 		// Check for cancellation before each iteration
 		// Check for cancellation before each iteration
@@ -296,6 +310,27 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string
 	})
 	})
 }
 }
 
 
+func (a *agent) createToolResponseMessage(ctx context.Context, sessionID string, toolResults []message.ToolResult) (*message.Message, error) {
+	if len(toolResults) == 0 {
+		return nil, nil
+	}
+
+	parts := make([]message.ContentPart, 0, len(toolResults))
+	for _, tr := range toolResults {
+		parts = append(parts, tr)
+	}
+
+	msg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
+		Role:  message.Tool,
+		Parts: parts,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to create tool response message: %w", err)
+	}
+
+	return &msg, nil
+}
+
 func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
 	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
 
 
@@ -324,12 +359,37 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
 		}
 		}
 	}
 	}
 
 
-	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
-	toolCalls := assistantMsg.ToolCalls()
+	// If the assistant wants to use tools, execute them
+	if assistantMsg.FinishReason() == message.FinishReasonToolUse {
+		toolCalls := assistantMsg.ToolCalls()
+		if len(toolCalls) > 0 {
+			toolResults, err := a.executeToolCalls(ctx, toolCalls)
+			if err != nil {
+				if errors.Is(err, context.Canceled) {
+					a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
+				}
+				return assistantMsg, nil, err
+			}
+
+			// Create a message with the tool results
+			toolResponseMsg, err := a.createToolResponseMessage(ctx, sessionID, toolResults)
+			if err != nil {
+				return assistantMsg, nil, err
+			}
+
+			return assistantMsg, toolResponseMsg, nil
+		}
+	}
+
+	return assistantMsg, nil, nil
+}
+
+func (a *agent) executeToolCalls(ctx context.Context, toolCalls []message.ToolCall) ([]message.ToolResult, error) {
+	toolResults := make([]message.ToolResult, len(toolCalls))
+
 	for i, toolCall := range toolCalls {
 	for i, toolCall := range toolCalls {
 		select {
 		select {
 		case <-ctx.Done():
 		case <-ctx.Done():
-			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
 			// Make all future tool calls cancelled
 			// Make all future tool calls cancelled
 			for j := i; j < len(toolCalls); j++ {
 			for j := i; j < len(toolCalls); j++ {
 				toolResults[j] = message.ToolResult{
 				toolResults[j] = message.ToolResult{
@@ -338,7 +398,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
 					IsError:    true,
 					IsError:    true,
 				}
 				}
 			}
 			}
-			goto out
+			return toolResults, ctx.Err()
 		default:
 		default:
 			// Continue processing
 			// Continue processing
 			var tool tools.BaseTool
 			var tool tools.BaseTool
@@ -357,11 +417,13 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
 				}
 				}
 				continue
 				continue
 			}
 			}
+
 			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
 			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
 				ID:    toolCall.ID,
 				ID:    toolCall.ID,
 				Name:  toolCall.Name,
 				Name:  toolCall.Name,
 				Input: toolCall.Input,
 				Input: toolCall.Input,
 			})
 			})
+
 			if toolErr != nil {
 			if toolErr != nil {
 				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 					toolResults[i] = message.ToolResult{
 					toolResults[i] = message.ToolResult{
@@ -369,6 +431,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
 						Content:    "Permission denied",
 						Content:    "Permission denied",
 						IsError:    true,
 						IsError:    true,
 					}
 					}
+					// Cancel all remaining tool calls if permission is denied
 					for j := i + 1; j < len(toolCalls); j++ {
 					for j := i + 1; j < len(toolCalls); j++ {
 						toolResults[j] = message.ToolResult{
 						toolResults[j] = message.ToolResult{
 							ToolCallID: toolCalls[j].ID,
 							ToolCallID: toolCalls[j].ID,
@@ -376,10 +439,18 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
 							IsError:    true,
 							IsError:    true,
 						}
 						}
 					}
 					}
-					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
-					break
+					return toolResults, nil
+				}
+
+				// Handle other errors
+				toolResults[i] = message.ToolResult{
+					ToolCallID: toolCall.ID,
+					Content:    toolErr.Error(),
+					IsError:    true,
 				}
 				}
+				continue
 			}
 			}
+
 			toolResults[i] = message.ToolResult{
 			toolResults[i] = message.ToolResult{
 				ToolCallID: toolCall.ID,
 				ToolCallID: toolCall.ID,
 				Content:    toolResult.Content,
 				Content:    toolResult.Content,
@@ -388,23 +459,8 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
 			}
 			}
 		}
 		}
 	}
 	}
-out:
-	if len(toolResults) == 0 {
-		return assistantMsg, nil, nil
-	}
-	parts := make([]message.ContentPart, 0)
-	for _, tr := range toolResults {
-		parts = append(parts, tr)
-	}
-	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
-		Role:  message.Tool,
-		Parts: parts,
-	})
-	if err != nil {
-		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
-	}
 
 
-	return assistantMsg, &msg, err
+	return toolResults, nil
 }
 }
 
 
 func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
 func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
@@ -475,7 +531,7 @@ func (a *agent) GetUsage(ctx context.Context, sessionID string) (*int64, error)
 	return &usage, nil
 	return &usage, nil
 }
 }
 
 
-func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { //nolint:lll
+func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
 	sess, err := a.sessions.Get(ctx, sessionID)
 	sess, err := a.sessions.Get(ctx, sessionID)
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("failed to get session: %w", err)
 		return fmt.Errorf("failed to get session: %w", err)