|
|
@@ -73,7 +73,7 @@ func NewAgent(
|
|
|
return nil, err
|
|
|
}
|
|
|
var titleProvider provider.Provider
|
|
|
- // Only generate titles for the coder agent
|
|
|
+ // Only generate titles for the primary agent
|
|
|
if agentName == config.AgentPrimary {
|
|
|
titleProvider, err = createAgentProvider(config.AgentTitle)
|
|
|
if err != nil {
|
|
|
@@ -197,11 +197,11 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac
|
|
|
return events, nil
|
|
|
}
|
|
|
|
|
|
-func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
|
|
|
+func (a *agent) prepareMessageHistory(ctx context.Context, sessionID string) (session.Session, []message.Message, error) {
|
|
|
// Get the current session to check for summary
|
|
|
currentSession, err := a.sessions.Get(ctx, sessionID)
|
|
|
if err != nil {
|
|
|
- return a.err(fmt.Errorf("failed to get session: %w", err))
|
|
|
+ return currentSession, nil, fmt.Errorf("failed to get session: %w", err)
|
|
|
}
|
|
|
|
|
|
// Fetch messages based on whether a summary exists
|
|
|
@@ -210,34 +210,16 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
|
|
|
// If summary exists, only fetch messages after the summarization timestamp
|
|
|
sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
|
|
|
if err != nil {
|
|
|
- return a.err(fmt.Errorf("failed to list messages after summary: %w", err))
|
|
|
+ return currentSession, nil, fmt.Errorf("failed to list messages after summary: %w", err)
|
|
|
}
|
|
|
} else {
|
|
|
// If no summary, fetch all messages
|
|
|
sessionMessages, err = a.messages.List(ctx, sessionID)
|
|
|
if err != nil {
|
|
|
- return a.err(fmt.Errorf("failed to list messages: %w", err))
|
|
|
+ return currentSession, nil, fmt.Errorf("failed to list messages: %w", err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // If this is a new session, start title generation asynchronously
|
|
|
- if len(sessionMessages) == 0 && currentSession.Summary == "" {
|
|
|
- go func() {
|
|
|
- defer logging.RecoverPanic("agent.Run", func() {
|
|
|
- status.Error("panic while generating title")
|
|
|
- })
|
|
|
- titleErr := a.generateTitle(context.Background(), sessionID, content)
|
|
|
- if titleErr != nil {
|
|
|
- status.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
|
|
|
- }
|
|
|
- }()
|
|
|
- }
|
|
|
-
|
|
|
- userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
|
|
|
- if err != nil {
|
|
|
- return a.err(fmt.Errorf("failed to create user message: %w", err))
|
|
|
- }
|
|
|
-
|
|
|
// Prepare the message history for the LLM
|
|
|
var messages []message.Message
|
|
|
if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
|
|
|
@@ -255,8 +237,40 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
|
|
|
messages = sessionMessages
|
|
|
}
|
|
|
|
|
|
+ return currentSession, messages, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (a *agent) triggerTitleGeneration(sessionID string, content string) {
|
|
|
+ go func() {
|
|
|
+ defer logging.RecoverPanic("agent.Run", func() {
|
|
|
+ status.Error("panic while generating title")
|
|
|
+ })
|
|
|
+ titleErr := a.generateTitle(context.Background(), sessionID, content)
|
|
|
+ if titleErr != nil {
|
|
|
+ status.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
|
|
|
+ }
|
|
|
+ }()
|
|
|
+}
|
|
|
+
|
|
|
+func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
|
|
|
+ // Get message history and session info
|
|
|
+ currentSession, sessionMessages, err := a.prepareMessageHistory(ctx, sessionID)
|
|
|
+ if err != nil {
|
|
|
+ return a.err(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // If this is a new session, start title generation asynchronously
|
|
|
+ if len(sessionMessages) == 0 && currentSession.Summary == "" {
|
|
|
+ a.triggerTitleGeneration(sessionID, content)
|
|
|
+ }
|
|
|
+
|
|
|
+ userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
|
|
|
+ if err != nil {
|
|
|
+ return a.err(fmt.Errorf("failed to create user message: %w", err))
|
|
|
+ }
|
|
|
+
|
|
|
// Append the new user message to the conversation history
|
|
|
- messages = append(messages, userMsg)
|
|
|
+ messages := append(sessionMessages, userMsg)
|
|
|
|
|
|
for {
|
|
|
// Check for cancellation before each iteration
|
|
|
@@ -296,6 +310,27 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string
|
|
|
})
|
|
|
}
|
|
|
|
|
|
+func (a *agent) createToolResponseMessage(ctx context.Context, sessionID string, toolResults []message.ToolResult) (*message.Message, error) {
|
|
|
+ if len(toolResults) == 0 {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ parts := make([]message.ContentPart, 0, len(toolResults))
|
|
|
+ for _, tr := range toolResults {
|
|
|
+ parts = append(parts, tr)
|
|
|
+ }
|
|
|
+
|
|
|
+ msg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
|
|
|
+ Role: message.Tool,
|
|
|
+ Parts: parts,
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("failed to create tool response message: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ return &msg, nil
|
|
|
+}
|
|
|
+
|
|
|
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
|
|
|
eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
|
|
|
|
|
|
@@ -324,12 +359,37 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
|
|
|
- toolCalls := assistantMsg.ToolCalls()
|
|
|
+ // If the assistant wants to use tools, execute them
|
|
|
+ if assistantMsg.FinishReason() == message.FinishReasonToolUse {
|
|
|
+ toolCalls := assistantMsg.ToolCalls()
|
|
|
+ if len(toolCalls) > 0 {
|
|
|
+ toolResults, err := a.executeToolCalls(ctx, toolCalls)
|
|
|
+ if err != nil {
|
|
|
+ if errors.Is(err, context.Canceled) {
|
|
|
+ a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
|
|
|
+ }
|
|
|
+ return assistantMsg, nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ // Create a message with the tool results
|
|
|
+ toolResponseMsg, err := a.createToolResponseMessage(ctx, sessionID, toolResults)
|
|
|
+ if err != nil {
|
|
|
+ return assistantMsg, nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return assistantMsg, toolResponseMsg, nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return assistantMsg, nil, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (a *agent) executeToolCalls(ctx context.Context, toolCalls []message.ToolCall) ([]message.ToolResult, error) {
|
|
|
+ toolResults := make([]message.ToolResult, len(toolCalls))
|
|
|
+
|
|
|
for i, toolCall := range toolCalls {
|
|
|
select {
|
|
|
case <-ctx.Done():
|
|
|
- a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
|
|
|
// Make all future tool calls cancelled
|
|
|
for j := i; j < len(toolCalls); j++ {
|
|
|
toolResults[j] = message.ToolResult{
|
|
|
@@ -338,7 +398,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
|
|
|
IsError: true,
|
|
|
}
|
|
|
}
|
|
|
- goto out
|
|
|
+ return toolResults, ctx.Err()
|
|
|
default:
|
|
|
// Continue processing
|
|
|
var tool tools.BaseTool
|
|
|
@@ -357,11 +417,13 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
|
|
|
}
|
|
|
continue
|
|
|
}
|
|
|
+
|
|
|
toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
|
|
|
ID: toolCall.ID,
|
|
|
Name: toolCall.Name,
|
|
|
Input: toolCall.Input,
|
|
|
})
|
|
|
+
|
|
|
if toolErr != nil {
|
|
|
if errors.Is(toolErr, permission.ErrorPermissionDenied) {
|
|
|
toolResults[i] = message.ToolResult{
|
|
|
@@ -369,6 +431,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
|
|
|
Content: "Permission denied",
|
|
|
IsError: true,
|
|
|
}
|
|
|
+ // Cancel all remaining tool calls if permission is denied
|
|
|
for j := i + 1; j < len(toolCalls); j++ {
|
|
|
toolResults[j] = message.ToolResult{
|
|
|
ToolCallID: toolCalls[j].ID,
|
|
|
@@ -376,10 +439,18 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
|
|
|
IsError: true,
|
|
|
}
|
|
|
}
|
|
|
- a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
|
|
|
- break
|
|
|
+ return toolResults, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // Handle other errors
|
|
|
+ toolResults[i] = message.ToolResult{
|
|
|
+ ToolCallID: toolCall.ID,
|
|
|
+ Content: toolErr.Error(),
|
|
|
+ IsError: true,
|
|
|
}
|
|
|
+ continue
|
|
|
}
|
|
|
+
|
|
|
toolResults[i] = message.ToolResult{
|
|
|
ToolCallID: toolCall.ID,
|
|
|
Content: toolResult.Content,
|
|
|
@@ -388,23 +459,8 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
-out:
|
|
|
- if len(toolResults) == 0 {
|
|
|
- return assistantMsg, nil, nil
|
|
|
- }
|
|
|
- parts := make([]message.ContentPart, 0)
|
|
|
- for _, tr := range toolResults {
|
|
|
- parts = append(parts, tr)
|
|
|
- }
|
|
|
- msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
|
|
|
- Role: message.Tool,
|
|
|
- Parts: parts,
|
|
|
- })
|
|
|
- if err != nil {
|
|
|
- return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
|
|
|
- }
|
|
|
|
|
|
- return assistantMsg, &msg, err
|
|
|
+ return toolResults, nil
|
|
|
}
|
|
|
|
|
|
func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
|
|
|
@@ -475,7 +531,7 @@ func (a *agent) GetUsage(ctx context.Context, sessionID string) (*int64, error)
|
|
|
return &usage, nil
|
|
|
}
|
|
|
|
|
|
-func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { //nolint:lll
|
|
|
+func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
|
|
|
sess, err := a.sessions.Get(ctx, sessionID)
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to get session: %w", err)
|