|
|
@@ -4,8 +4,6 @@ import (
|
|
|
"context"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
- "os"
|
|
|
- "runtime/debug"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
|
|
|
@@ -16,133 +14,101 @@ import (
|
|
|
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
|
|
"github.com/kujtimiihoxha/termai/internal/logging"
|
|
|
"github.com/kujtimiihoxha/termai/internal/message"
|
|
|
+ "github.com/kujtimiihoxha/termai/internal/permission"
|
|
|
"github.com/kujtimiihoxha/termai/internal/session"
|
|
|
)
|
|
|
|
|
|
// Common errors
|
|
|
var (
|
|
|
- ErrProviderNotEnabled = errors.New("provider is not enabled")
|
|
|
- ErrRequestCancelled = errors.New("request cancelled by user")
|
|
|
- ErrSessionBusy = errors.New("session is currently processing another request")
|
|
|
+ ErrRequestCancelled = errors.New("request cancelled by user")
|
|
|
+ ErrSessionBusy = errors.New("session is currently processing another request")
|
|
|
)
|
|
|
|
|
|
-// Service defines the interface for generating responses
|
|
|
+type AgentEvent struct {
|
|
|
+ message message.Message
|
|
|
+ err error
|
|
|
+}
|
|
|
+
|
|
|
+func (e *AgentEvent) Err() error {
|
|
|
+ return e.err
|
|
|
+}
|
|
|
+
|
|
|
+func (e *AgentEvent) Response() message.Message {
|
|
|
+ return e.message
|
|
|
+}
|
|
|
+
|
|
|
type Service interface {
|
|
|
- Generate(ctx context.Context, sessionID string, content string) error
|
|
|
- Cancel(sessionID string) error
|
|
|
+ Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
|
|
|
+ Cancel(sessionID string)
|
|
|
+ IsSessionBusy(sessionID string) bool
|
|
|
}
|
|
|
|
|
|
type agent struct {
|
|
|
- sessions session.Service
|
|
|
- messages message.Service
|
|
|
- model models.Model
|
|
|
- tools []tools.BaseTool
|
|
|
- agent provider.Provider
|
|
|
- titleGenerator provider.Provider
|
|
|
- activeRequests sync.Map // map[sessionID]context.CancelFunc
|
|
|
+ sessions session.Service
|
|
|
+ messages message.Service
|
|
|
+
|
|
|
+ tools []tools.BaseTool
|
|
|
+ provider provider.Provider
|
|
|
+
|
|
|
+ titleProvider provider.Provider
|
|
|
+
|
|
|
+ activeRequests sync.Map
|
|
|
}
|
|
|
|
|
|
-// NewAgent creates a new agent instance with the given model and tools
|
|
|
-func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) {
|
|
|
- agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
|
|
|
+func NewAgent(
|
|
|
+ agentName config.AgentName,
|
|
|
+ sessions session.Service,
|
|
|
+ messages message.Service,
|
|
|
+ agentTools []tools.BaseTool,
|
|
|
+) (Service, error) {
|
|
|
+ agentProvider, err := createAgentProvider(agentName)
|
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("failed to initialize providers: %w", err)
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ var titleProvider provider.Provider
|
|
|
+ // Only generate titles for the coder agent
|
|
|
+ if agentName == config.AgentCoder {
|
|
|
+ titleProvider, err = createAgentProvider(config.AgentTitle)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- return &agent{
|
|
|
- model: model,
|
|
|
- tools: tools,
|
|
|
- sessions: sessions,
|
|
|
+ agent := &agent{
|
|
|
+ provider: agentProvider,
|
|
|
messages: messages,
|
|
|
- agent: agentProvider,
|
|
|
- titleGenerator: titleGenerator,
|
|
|
+ sessions: sessions,
|
|
|
+ tools: agentTools,
|
|
|
+ titleProvider: titleProvider,
|
|
|
activeRequests: sync.Map{},
|
|
|
- }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ return agent, nil
|
|
|
}
|
|
|
|
|
|
-// Cancel cancels an active request by session ID
|
|
|
-func (a *agent) Cancel(sessionID string) error {
|
|
|
+func (a *agent) Cancel(sessionID string) {
|
|
|
if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
|
|
|
if cancel, ok := cancelFunc.(context.CancelFunc); ok {
|
|
|
logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
|
|
|
cancel()
|
|
|
- return nil
|
|
|
}
|
|
|
}
|
|
|
- return errors.New("no active request found for this session")
|
|
|
}
|
|
|
|
|
|
-// Generate starts the generation process
|
|
|
-func (a *agent) Generate(ctx context.Context, sessionID string, content string) error {
|
|
|
- // Check if this session already has an active request
|
|
|
- if _, busy := a.activeRequests.Load(sessionID); busy {
|
|
|
- return ErrSessionBusy
|
|
|
- }
|
|
|
-
|
|
|
- // Create a cancellable context
|
|
|
- genCtx, cancel := context.WithCancel(ctx)
|
|
|
-
|
|
|
- // Store cancel function to allow user cancellation
|
|
|
- a.activeRequests.Store(sessionID, cancel)
|
|
|
-
|
|
|
- // Launch the generation in a goroutine
|
|
|
- go func() {
|
|
|
- defer func() {
|
|
|
- if r := recover(); r != nil {
|
|
|
- logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r))
|
|
|
-
|
|
|
- // dump stack trace into a file
|
|
|
- file, err := os.Create("panic.log")
|
|
|
- if err != nil {
|
|
|
- logging.ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err))
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- defer file.Close()
|
|
|
-
|
|
|
- stackTrace := debug.Stack()
|
|
|
- if _, err := file.Write(stackTrace); err != nil {
|
|
|
- logging.ErrorPersist(fmt.Sprintf("Failed to write panic log: %v", err))
|
|
|
- }
|
|
|
-
|
|
|
- }
|
|
|
- }()
|
|
|
- defer a.activeRequests.Delete(sessionID)
|
|
|
- defer cancel()
|
|
|
-
|
|
|
- if err := a.generate(genCtx, sessionID, content); err != nil {
|
|
|
- if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) {
|
|
|
- // Log the error (avoid logging cancellations as they're expected)
|
|
|
- logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err))
|
|
|
-
|
|
|
- // You may want to create an error message in the chat
|
|
|
- bgCtx := context.Background()
|
|
|
- errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err)
|
|
|
- _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{
|
|
|
- Role: message.System,
|
|
|
- Parts: []message.ContentPart{
|
|
|
- message.TextContent{
|
|
|
- Text: errorMsg,
|
|
|
- },
|
|
|
- },
|
|
|
- })
|
|
|
- if createErr != nil {
|
|
|
- logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr))
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }()
|
|
|
-
|
|
|
- return nil
|
|
|
-}
|
|
|
-
|
|
|
-// IsSessionBusy checks if a session currently has an active request
|
|
|
func (a *agent) IsSessionBusy(sessionID string) bool {
|
|
|
_, busy := a.activeRequests.Load(sessionID)
|
|
|
return busy
|
|
|
-} // handleTitleGeneration asynchronously generates a title for new sessions
|
|
|
-func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
|
|
|
- response, err := a.titleGenerator.SendMessages(
|
|
|
+}
|
|
|
+
|
|
|
+func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
|
|
|
+ if a.titleProvider == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ session, err := a.sessions.Get(ctx, sessionID)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ response, err := a.titleProvider.SendMessages(
|
|
|
ctx,
|
|
|
[]message.Message{
|
|
|
{
|
|
|
@@ -154,121 +120,152 @@ func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content st
|
|
|
},
|
|
|
},
|
|
|
},
|
|
|
- nil,
|
|
|
+ make([]tools.BaseTool, 0),
|
|
|
)
|
|
|
if err != nil {
|
|
|
- logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err))
|
|
|
- return
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
- session, err := a.sessions.Get(ctx, sessionID)
|
|
|
- if err != nil {
|
|
|
- logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err))
|
|
|
- return
|
|
|
+ title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
|
|
|
+ if title == "" {
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
- if response.Content != "" {
|
|
|
- session.Title = strings.TrimSpace(response.Content)
|
|
|
- session.Title = strings.ReplaceAll(session.Title, "\n", " ")
|
|
|
- if _, err := a.sessions.Save(ctx, session); err != nil {
|
|
|
- logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err))
|
|
|
- }
|
|
|
+ session.Title = title
|
|
|
+ _, err = a.sessions.Save(ctx, session)
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+func (a *agent) err(err error) AgentEvent {
|
|
|
+ return AgentEvent{
|
|
|
+ err: err,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// TrackUsage updates token usage statistics for the session
|
|
|
-func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
|
|
|
- session, err := a.sessions.Get(ctx, sessionID)
|
|
|
- if err != nil {
|
|
|
- return fmt.Errorf("failed to get session: %w", err)
|
|
|
+func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
|
|
|
+ events := make(chan AgentEvent)
|
|
|
+ if a.IsSessionBusy(sessionID) {
|
|
|
+ return nil, ErrSessionBusy
|
|
|
}
|
|
|
|
|
|
- cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
|
|
|
- model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
|
|
|
- model.CostPer1MIn/1e6*float64(usage.InputTokens) +
|
|
|
- model.CostPer1MOut/1e6*float64(usage.OutputTokens)
|
|
|
+ genCtx, cancel := context.WithCancel(ctx)
|
|
|
+
|
|
|
+ a.activeRequests.Store(sessionID, cancel)
|
|
|
+ go func() {
|
|
|
+ logging.Debug("Request started", "sessionID", sessionID)
|
|
|
+ defer logging.RecoverPanic("agent.Run", func() {
|
|
|
+ events <- a.err(fmt.Errorf("panic while running the agent"))
|
|
|
+ })
|
|
|
|
|
|
- session.Cost += cost
|
|
|
- session.CompletionTokens += usage.OutputTokens
|
|
|
- session.PromptTokens += usage.InputTokens
|
|
|
+ result := a.processGeneration(genCtx, sessionID, content)
|
|
|
+ if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
|
|
|
+ logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
|
|
|
+ }
|
|
|
+ logging.Debug("Request completed", "sessionID", sessionID)
|
|
|
+ a.activeRequests.Delete(sessionID)
|
|
|
+ cancel()
|
|
|
+ events <- result
|
|
|
+ close(events)
|
|
|
+ }()
|
|
|
+ return events, nil
|
|
|
+}
|
|
|
|
|
|
- _, err = a.sessions.Save(ctx, session)
|
|
|
+func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
|
|
|
+ // List existing messages; if none, start title generation asynchronously.
|
|
|
+ msgs, err := a.messages.List(ctx, sessionID)
|
|
|
if err != nil {
|
|
|
- return fmt.Errorf("failed to save session: %w", err)
|
|
|
+ return a.err(fmt.Errorf("failed to list messages: %w", err))
|
|
|
+ }
|
|
|
+ if len(msgs) == 0 {
|
|
|
+ go func() {
|
|
|
+ defer logging.RecoverPanic("agent.Run", func() {
|
|
|
+ logging.ErrorPersist("panic while generating title")
|
|
|
+ })
|
|
|
+ titleErr := a.generateTitle(context.Background(), sessionID, content)
|
|
|
+ if titleErr != nil {
|
|
|
+ logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
|
|
|
+ }
|
|
|
+ }()
|
|
|
}
|
|
|
- return nil
|
|
|
-}
|
|
|
|
|
|
-// processEvent handles different types of events during generation
|
|
|
-func (a *agent) processEvent(
|
|
|
- ctx context.Context,
|
|
|
- sessionID string,
|
|
|
- assistantMsg *message.Message,
|
|
|
- event provider.ProviderEvent,
|
|
|
-) error {
|
|
|
- select {
|
|
|
- case <-ctx.Done():
|
|
|
- return ctx.Err()
|
|
|
- default:
|
|
|
- // Continue processing
|
|
|
+ userMsg, err := a.createUserMessage(ctx, sessionID, content)
|
|
|
+ if err != nil {
|
|
|
+ return a.err(fmt.Errorf("failed to create user message: %w", err))
|
|
|
}
|
|
|
|
|
|
- switch event.Type {
|
|
|
- case provider.EventThinkingDelta:
|
|
|
- assistantMsg.AppendReasoningContent(event.Content)
|
|
|
- return a.messages.Update(ctx, *assistantMsg)
|
|
|
- case provider.EventContentDelta:
|
|
|
- assistantMsg.AppendContent(event.Content)
|
|
|
- return a.messages.Update(ctx, *assistantMsg)
|
|
|
- case provider.EventError:
|
|
|
- if errors.Is(event.Error, context.Canceled) {
|
|
|
- logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
|
|
|
- return context.Canceled
|
|
|
+ // Append the new user message to the conversation history.
|
|
|
+ msgHistory := append(msgs, userMsg)
|
|
|
+ for {
|
|
|
+ // Check for cancellation before each iteration
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ return a.err(ctx.Err())
|
|
|
+ default:
|
|
|
+ // Continue processing
|
|
|
}
|
|
|
- logging.ErrorPersist(event.Error.Error())
|
|
|
- return event.Error
|
|
|
- case provider.EventWarning:
|
|
|
- logging.WarnPersist(event.Info)
|
|
|
- case provider.EventInfo:
|
|
|
- logging.InfoPersist(event.Info)
|
|
|
- case provider.EventComplete:
|
|
|
- assistantMsg.SetToolCalls(event.Response.ToolCalls)
|
|
|
- assistantMsg.AddFinish(event.Response.FinishReason)
|
|
|
- if err := a.messages.Update(ctx, *assistantMsg); err != nil {
|
|
|
- return fmt.Errorf("failed to update message: %w", err)
|
|
|
+ agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
|
|
|
+ if err != nil {
|
|
|
+ if errors.Is(err, context.Canceled) {
|
|
|
+ return a.err(ErrRequestCancelled)
|
|
|
+ }
|
|
|
+ return a.err(fmt.Errorf("failed to process events: %w", err))
|
|
|
+ }
|
|
|
+ logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
|
|
|
+ if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
|
|
|
+ // We are not done, we need to respond with the tool response
|
|
|
+ msgHistory = append(msgHistory, agentMessage, *toolResults)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ return AgentEvent{
|
|
|
+ message: agentMessage,
|
|
|
}
|
|
|
- return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage)
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- return nil
|
|
|
+func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
|
|
|
+ return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
|
|
|
+ Role: message.User,
|
|
|
+ Parts: []message.ContentPart{
|
|
|
+ message.TextContent{Text: content},
|
|
|
+ },
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
-// ExecuteTools runs all tool calls sequentially and returns the results
|
|
|
-func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
|
|
|
- toolResults := make([]message.ToolResult, len(toolCalls))
|
|
|
+func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
|
|
|
+ eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
|
|
|
+
|
|
|
+ assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
|
|
|
+ Role: message.Assistant,
|
|
|
+ Parts: []message.ContentPart{},
|
|
|
+ Model: a.provider.Model().ID,
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
|
|
|
+ }
|
|
|
|
|
|
- // Create a child context that can be canceled
|
|
|
- ctx, cancel := context.WithCancel(ctx)
|
|
|
- defer cancel()
|
|
|
+ // Add the session and message ID into the context if needed by tools.
|
|
|
+ ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
|
|
|
+ ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
|
|
|
|
|
|
- // Check if already canceled before starting any execution
|
|
|
- if ctx.Err() != nil {
|
|
|
- // Mark all tools as canceled
|
|
|
- for i, toolCall := range toolCalls {
|
|
|
- toolResults[i] = message.ToolResult{
|
|
|
- ToolCallID: toolCall.ID,
|
|
|
- Content: "Tool execution canceled by user",
|
|
|
- IsError: true,
|
|
|
- }
|
|
|
+ // Process each event in the stream.
|
|
|
+ for event := range eventChan {
|
|
|
+ if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
|
|
|
+ a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
|
|
|
+ return assistantMsg, nil, processErr
|
|
|
+ }
|
|
|
+ if ctx.Err() != nil {
|
|
|
+ a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
|
|
|
+ return assistantMsg, nil, ctx.Err()
|
|
|
}
|
|
|
- return toolResults, ctx.Err()
|
|
|
}
|
|
|
|
|
|
+ toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
|
|
|
+ toolCalls := assistantMsg.ToolCalls()
|
|
|
for i, toolCall := range toolCalls {
|
|
|
- // Check for cancellation before executing each tool
|
|
|
select {
|
|
|
case <-ctx.Done():
|
|
|
- // Mark this and all remaining tools as canceled
|
|
|
+ a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
|
|
|
+ // Make all future tool calls cancelled
|
|
|
for j := i; j < len(toolCalls); j++ {
|
|
|
toolResults[j] = message.ToolResult{
|
|
|
ToolCallID: toolCalls[j].ID,
|
|
|
@@ -276,412 +273,180 @@ func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall,
|
|
|
IsError: true,
|
|
|
}
|
|
|
}
|
|
|
- return toolResults, ctx.Err()
|
|
|
+ goto out
|
|
|
default:
|
|
|
// Continue processing
|
|
|
- }
|
|
|
-
|
|
|
- response := ""
|
|
|
- isError := false
|
|
|
- found := false
|
|
|
-
|
|
|
- // Find and execute the appropriate tool
|
|
|
- for _, tool := range tls {
|
|
|
- if tool.Info().Name == toolCall.Name {
|
|
|
- found = true
|
|
|
- toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
|
|
|
- ID: toolCall.ID,
|
|
|
- Name: toolCall.Name,
|
|
|
- Input: toolCall.Input,
|
|
|
- })
|
|
|
-
|
|
|
- if toolErr != nil {
|
|
|
- if errors.Is(toolErr, context.Canceled) {
|
|
|
- response = "Tool execution canceled by user"
|
|
|
- } else {
|
|
|
- response = fmt.Sprintf("Error running tool: %s", toolErr)
|
|
|
- }
|
|
|
- isError = true
|
|
|
- } else {
|
|
|
- response = toolResult.Content
|
|
|
- isError = toolResult.IsError
|
|
|
+ var tool tools.BaseTool
|
|
|
+ for _, availableTools := range a.tools {
|
|
|
+ if availableTools.Info().Name == toolCall.Name {
|
|
|
+ tool = availableTools
|
|
|
}
|
|
|
- break
|
|
|
}
|
|
|
- }
|
|
|
-
|
|
|
- if !found {
|
|
|
- response = fmt.Sprintf("Tool not found: %s", toolCall.Name)
|
|
|
- isError = true
|
|
|
- }
|
|
|
-
|
|
|
- toolResults[i] = message.ToolResult{
|
|
|
- ToolCallID: toolCall.ID,
|
|
|
- Content: response,
|
|
|
- IsError: isError,
|
|
|
- }
|
|
|
- }
|
|
|
|
|
|
- return toolResults, nil
|
|
|
-}
|
|
|
-
|
|
|
-// handleToolExecution processes tool calls and creates tool result messages
|
|
|
-func (a *agent) handleToolExecution(
|
|
|
- ctx context.Context,
|
|
|
- assistantMsg message.Message,
|
|
|
-) (*message.Message, error) {
|
|
|
- select {
|
|
|
- case <-ctx.Done():
|
|
|
- // If cancelled, create tool results that indicate cancellation
|
|
|
- if len(assistantMsg.ToolCalls()) > 0 {
|
|
|
- toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls()))
|
|
|
- for _, tc := range assistantMsg.ToolCalls() {
|
|
|
- toolResults = append(toolResults, message.ToolResult{
|
|
|
- ToolCallID: tc.ID,
|
|
|
- Content: "Tool execution canceled by user",
|
|
|
+ // Tool not found
|
|
|
+ if tool == nil {
|
|
|
+ toolResults[i] = message.ToolResult{
|
|
|
+ ToolCallID: toolCall.ID,
|
|
|
+ Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
|
|
|
IsError: true,
|
|
|
- })
|
|
|
+ }
|
|
|
+ continue
|
|
|
}
|
|
|
|
|
|
- // Use background context to ensure the message is created even if original context is cancelled
|
|
|
- bgCtx := context.Background()
|
|
|
- parts := make([]message.ContentPart, 0)
|
|
|
- for _, toolResult := range toolResults {
|
|
|
- parts = append(parts, toolResult)
|
|
|
- }
|
|
|
- msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
|
|
|
- Role: message.Tool,
|
|
|
- Parts: parts,
|
|
|
+ toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
|
|
|
+ ID: toolCall.ID,
|
|
|
+ Name: toolCall.Name,
|
|
|
+ Input: toolCall.Input,
|
|
|
})
|
|
|
- if err != nil {
|
|
|
- return nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
|
|
|
- }
|
|
|
- return &msg, ctx.Err()
|
|
|
- }
|
|
|
- return nil, ctx.Err()
|
|
|
- default:
|
|
|
- // Continue processing
|
|
|
- }
|
|
|
-
|
|
|
- if len(assistantMsg.ToolCalls()) == 0 {
|
|
|
- return nil, nil
|
|
|
- }
|
|
|
-
|
|
|
- toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools)
|
|
|
- if err != nil {
|
|
|
- // If error is from cancellation, still return the partial results we have
|
|
|
- if errors.Is(err, context.Canceled) {
|
|
|
- // Use background context to ensure the message is created even if original context is cancelled
|
|
|
- bgCtx := context.Background()
|
|
|
- parts := make([]message.ContentPart, 0)
|
|
|
- for _, toolResult := range toolResults {
|
|
|
- parts = append(parts, toolResult)
|
|
|
+ if toolErr != nil {
|
|
|
+ if errors.Is(toolErr, permission.ErrorPermissionDenied) {
|
|
|
+ toolResults[i] = message.ToolResult{
|
|
|
+ ToolCallID: toolCall.ID,
|
|
|
+ Content: "Permission denied",
|
|
|
+ IsError: true,
|
|
|
+ }
|
|
|
+ for j := i + 1; j < len(toolCalls); j++ {
|
|
|
+ toolResults[j] = message.ToolResult{
|
|
|
+ ToolCallID: toolCalls[j].ID,
|
|
|
+ Content: "Tool execution canceled by user",
|
|
|
+ IsError: true,
|
|
|
+ }
|
|
|
+ }
|
|
|
+ a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
|
|
|
+ } else {
|
|
|
+ toolResults[i] = message.ToolResult{
|
|
|
+ ToolCallID: toolCall.ID,
|
|
|
+ Content: toolErr.Error(),
|
|
|
+ IsError: true,
|
|
|
+ }
|
|
|
+ for j := i; j < len(toolCalls); j++ {
|
|
|
+ toolResults[j] = message.ToolResult{
|
|
|
+ ToolCallID: toolCalls[j].ID,
|
|
|
+ Content: "Previous tool failed",
|
|
|
+ IsError: true,
|
|
|
+ }
|
|
|
+ }
|
|
|
+ a.finishMessage(ctx, &assistantMsg, message.FinishReasonError)
|
|
|
+ }
|
|
|
+ // If permission is denied or an error happens we cancel all the following tools
|
|
|
+ break
|
|
|
}
|
|
|
-
|
|
|
- msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
|
|
|
- Role: message.Tool,
|
|
|
- Parts: parts,
|
|
|
- })
|
|
|
- if createErr != nil {
|
|
|
- logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr))
|
|
|
- return nil, err
|
|
|
+ toolResults[i] = message.ToolResult{
|
|
|
+ ToolCallID: toolCall.ID,
|
|
|
+ Content: toolResult.Content,
|
|
|
+ Metadata: toolResult.Metadata,
|
|
|
+ IsError: toolResult.IsError,
|
|
|
}
|
|
|
- return &msg, err
|
|
|
}
|
|
|
- return nil, err
|
|
|
}
|
|
|
-
|
|
|
- parts := make([]message.ContentPart, 0, len(toolResults))
|
|
|
- for _, toolResult := range toolResults {
|
|
|
- parts = append(parts, toolResult)
|
|
|
+out:
|
|
|
+ if len(toolResults) == 0 {
|
|
|
+ return assistantMsg, nil, nil
|
|
|
}
|
|
|
-
|
|
|
- msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
|
|
|
+ parts := make([]message.ContentPart, 0)
|
|
|
+ for _, tr := range toolResults {
|
|
|
+ parts = append(parts, tr)
|
|
|
+ }
|
|
|
+ msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
|
|
|
Role: message.Tool,
|
|
|
Parts: parts,
|
|
|
})
|
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("failed to create tool message: %w", err)
|
|
|
+ return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
|
|
|
}
|
|
|
|
|
|
- return &msg, nil
|
|
|
+ return assistantMsg, &msg, err
|
|
|
}
|
|
|
|
|
|
-// generate handles the main generation workflow
|
|
|
-func (a *agent) generate(ctx context.Context, sessionID string, content string) error {
|
|
|
- ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
|
|
|
+func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
|
|
|
+ msg.AddFinish(finishReson)
|
|
|
+ _ = a.messages.Update(ctx, *msg)
|
|
|
+}
|
|
|
|
|
|
- // Handle context cancellation at any point
|
|
|
- if err := ctx.Err(); err != nil {
|
|
|
- return ErrRequestCancelled
|
|
|
+func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ return ctx.Err()
|
|
|
+ default:
|
|
|
+ // Continue processing.
|
|
|
}
|
|
|
|
|
|
- messages, err := a.messages.List(ctx, sessionID)
|
|
|
- if err != nil {
|
|
|
- return fmt.Errorf("failed to list messages: %w", err)
|
|
|
+ switch event.Type {
|
|
|
+ case provider.EventThinkingDelta:
|
|
|
+ assistantMsg.AppendReasoningContent(event.Content)
|
|
|
+ return a.messages.Update(ctx, *assistantMsg)
|
|
|
+ case provider.EventContentDelta:
|
|
|
+ assistantMsg.AppendContent(event.Content)
|
|
|
+ return a.messages.Update(ctx, *assistantMsg)
|
|
|
+ case provider.EventError:
|
|
|
+ if errors.Is(event.Error, context.Canceled) {
|
|
|
+ logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
|
|
|
+ return context.Canceled
|
|
|
+ }
|
|
|
+ logging.ErrorPersist(event.Error.Error())
|
|
|
+ return event.Error
|
|
|
+ case provider.EventComplete:
|
|
|
+ assistantMsg.SetToolCalls(event.Response.ToolCalls)
|
|
|
+ assistantMsg.AddFinish(event.Response.FinishReason)
|
|
|
+ if err := a.messages.Update(ctx, *assistantMsg); err != nil {
|
|
|
+ return fmt.Errorf("failed to update message: %w", err)
|
|
|
+ }
|
|
|
+ return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
|
|
|
}
|
|
|
|
|
|
- if len(messages) == 0 {
|
|
|
- titleCtx := context.Background()
|
|
|
- go a.handleTitleGeneration(titleCtx, sessionID, content)
|
|
|
- }
|
|
|
+ return nil
|
|
|
+}
|
|
|
|
|
|
- userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
|
|
|
- Role: message.User,
|
|
|
- Parts: []message.ContentPart{
|
|
|
- message.TextContent{
|
|
|
- Text: content,
|
|
|
- },
|
|
|
- },
|
|
|
- })
|
|
|
+func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
|
|
|
+ sess, err := a.sessions.Get(ctx, sessionID)
|
|
|
if err != nil {
|
|
|
- return fmt.Errorf("failed to create user message: %w", err)
|
|
|
+ return fmt.Errorf("failed to get session: %w", err)
|
|
|
}
|
|
|
|
|
|
- messages = append(messages, userMsg)
|
|
|
-
|
|
|
- for {
|
|
|
- // Check for cancellation before each iteration
|
|
|
- select {
|
|
|
- case <-ctx.Done():
|
|
|
- return ErrRequestCancelled
|
|
|
- default:
|
|
|
- // Continue processing
|
|
|
- }
|
|
|
-
|
|
|
- eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools)
|
|
|
- if err != nil {
|
|
|
- if errors.Is(err, context.Canceled) {
|
|
|
- return ErrRequestCancelled
|
|
|
- }
|
|
|
- return fmt.Errorf("failed to stream response: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
|
|
|
- Role: message.Assistant,
|
|
|
- Parts: []message.ContentPart{},
|
|
|
- Model: a.model.ID,
|
|
|
- })
|
|
|
- if err != nil {
|
|
|
- return fmt.Errorf("failed to create assistant message: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
|
|
|
-
|
|
|
- // Process events from the LLM provider
|
|
|
- for event := range eventChan {
|
|
|
- if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil {
|
|
|
- if errors.Is(err, context.Canceled) {
|
|
|
- // Mark as canceled but don't create separate message
|
|
|
- assistantMsg.AddFinish("canceled")
|
|
|
- _ = a.messages.Update(context.Background(), assistantMsg)
|
|
|
- return ErrRequestCancelled
|
|
|
- }
|
|
|
- assistantMsg.AddFinish("error:" + err.Error())
|
|
|
- _ = a.messages.Update(ctx, assistantMsg)
|
|
|
- return fmt.Errorf("event processing error: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- // Check for cancellation during event processing
|
|
|
- select {
|
|
|
- case <-ctx.Done():
|
|
|
- // Mark as canceled
|
|
|
- assistantMsg.AddFinish("canceled")
|
|
|
- _ = a.messages.Update(context.Background(), assistantMsg)
|
|
|
- return ErrRequestCancelled
|
|
|
- default:
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // Check for cancellation before tool execution
|
|
|
- select {
|
|
|
- case <-ctx.Done():
|
|
|
- assistantMsg.AddFinish("canceled_by_user")
|
|
|
- _ = a.messages.Update(context.Background(), assistantMsg)
|
|
|
- return ErrRequestCancelled
|
|
|
- default:
|
|
|
- }
|
|
|
-
|
|
|
- // Execute any tool calls
|
|
|
- toolMsg, err := a.handleToolExecution(ctx, assistantMsg)
|
|
|
- if err != nil {
|
|
|
- if errors.Is(err, context.Canceled) {
|
|
|
- assistantMsg.AddFinish("canceled_by_user")
|
|
|
- _ = a.messages.Update(context.Background(), assistantMsg)
|
|
|
- return ErrRequestCancelled
|
|
|
- }
|
|
|
- return fmt.Errorf("tool execution error: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- if err := a.messages.Update(ctx, assistantMsg); err != nil {
|
|
|
- return fmt.Errorf("failed to update assistant message: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- // If no tool calls, we're done
|
|
|
- if len(assistantMsg.ToolCalls()) == 0 {
|
|
|
- break
|
|
|
- }
|
|
|
+ cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
|
|
|
+ model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
|
|
|
+ model.CostPer1MIn/1e6*float64(usage.InputTokens) +
|
|
|
+ model.CostPer1MOut/1e6*float64(usage.OutputTokens)
|
|
|
|
|
|
- // Add messages for next iteration
|
|
|
- messages = append(messages, assistantMsg)
|
|
|
- if toolMsg != nil {
|
|
|
- messages = append(messages, *toolMsg)
|
|
|
- }
|
|
|
+ sess.Cost += cost
|
|
|
+ sess.CompletionTokens += usage.OutputTokens
|
|
|
+ sess.PromptTokens += usage.InputTokens
|
|
|
|
|
|
- // Check for cancellation after tool execution
|
|
|
- select {
|
|
|
- case <-ctx.Done():
|
|
|
- return ErrRequestCancelled
|
|
|
- default:
|
|
|
- }
|
|
|
+ _, err = a.sessions.Save(ctx, sess)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("failed to save session: %w", err)
|
|
|
}
|
|
|
-
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-// getAgentProviders initializes the LLM providers based on the chosen model
|
|
|
-func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
|
|
|
- maxTokens := config.Get().Model.CoderMaxTokens
|
|
|
-
|
|
|
- providerConfig, ok := config.Get().Providers[model.Provider]
|
|
|
- if !ok || providerConfig.Disabled {
|
|
|
- return nil, nil, ErrProviderNotEnabled
|
|
|
+func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
|
|
|
+ cfg := config.Get()
|
|
|
+ agentConfig, ok := cfg.Agents[agentName]
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("agent %s not found", agentName)
|
|
|
+ }
|
|
|
+ model, ok := models.SupportedModels[agentConfig.Model]
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
|
|
|
}
|
|
|
|
|
|
- var agentProvider provider.Provider
|
|
|
- var titleGenerator provider.Provider
|
|
|
- var err error
|
|
|
-
|
|
|
- switch model.Provider {
|
|
|
- case models.ProviderOpenAI:
|
|
|
- agentProvider, err = provider.NewOpenAIProvider(
|
|
|
- provider.WithOpenAISystemMessage(
|
|
|
- prompt.CoderOpenAISystemPrompt(),
|
|
|
- ),
|
|
|
- provider.WithOpenAIMaxTokens(maxTokens),
|
|
|
- provider.WithOpenAIModel(model),
|
|
|
- provider.WithOpenAIKey(providerConfig.APIKey),
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- titleGenerator, err = provider.NewOpenAIProvider(
|
|
|
- provider.WithOpenAISystemMessage(
|
|
|
- prompt.TitlePrompt(),
|
|
|
- ),
|
|
|
- provider.WithOpenAIMaxTokens(80),
|
|
|
- provider.WithOpenAIModel(model),
|
|
|
- provider.WithOpenAIKey(providerConfig.APIKey),
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- case models.ProviderAnthropic:
|
|
|
- agentProvider, err = provider.NewAnthropicProvider(
|
|
|
- provider.WithAnthropicSystemMessage(
|
|
|
- prompt.CoderAnthropicSystemPrompt(),
|
|
|
- ),
|
|
|
- provider.WithAnthropicMaxTokens(maxTokens),
|
|
|
- provider.WithAnthropicKey(providerConfig.APIKey),
|
|
|
- provider.WithAnthropicModel(model),
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- titleGenerator, err = provider.NewAnthropicProvider(
|
|
|
- provider.WithAnthropicSystemMessage(
|
|
|
- prompt.TitlePrompt(),
|
|
|
- ),
|
|
|
- provider.WithAnthropicMaxTokens(80),
|
|
|
- provider.WithAnthropicKey(providerConfig.APIKey),
|
|
|
- provider.WithAnthropicModel(model),
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- case models.ProviderGemini:
|
|
|
- agentProvider, err = provider.NewGeminiProvider(
|
|
|
- ctx,
|
|
|
- provider.WithGeminiSystemMessage(
|
|
|
- prompt.CoderOpenAISystemPrompt(),
|
|
|
- ),
|
|
|
- provider.WithGeminiMaxTokens(int32(maxTokens)),
|
|
|
- provider.WithGeminiKey(providerConfig.APIKey),
|
|
|
- provider.WithGeminiModel(model),
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- titleGenerator, err = provider.NewGeminiProvider(
|
|
|
- ctx,
|
|
|
- provider.WithGeminiSystemMessage(
|
|
|
- prompt.TitlePrompt(),
|
|
|
- ),
|
|
|
- provider.WithGeminiMaxTokens(80),
|
|
|
- provider.WithGeminiKey(providerConfig.APIKey),
|
|
|
- provider.WithGeminiModel(model),
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- case models.ProviderGROQ:
|
|
|
- agentProvider, err = provider.NewOpenAIProvider(
|
|
|
- provider.WithOpenAISystemMessage(
|
|
|
- prompt.CoderAnthropicSystemPrompt(),
|
|
|
- ),
|
|
|
- provider.WithOpenAIMaxTokens(maxTokens),
|
|
|
- provider.WithOpenAIModel(model),
|
|
|
- provider.WithOpenAIKey(providerConfig.APIKey),
|
|
|
- provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- titleGenerator, err = provider.NewOpenAIProvider(
|
|
|
- provider.WithOpenAISystemMessage(
|
|
|
- prompt.TitlePrompt(),
|
|
|
- ),
|
|
|
- provider.WithOpenAIMaxTokens(80),
|
|
|
- provider.WithOpenAIModel(model),
|
|
|
- provider.WithOpenAIKey(providerConfig.APIKey),
|
|
|
- provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- case models.ProviderBedrock:
|
|
|
- agentProvider, err = provider.NewBedrockProvider(
|
|
|
- provider.WithBedrockSystemMessage(
|
|
|
- prompt.CoderAnthropicSystemPrompt(),
|
|
|
- ),
|
|
|
- provider.WithBedrockMaxTokens(maxTokens),
|
|
|
- provider.WithBedrockModel(model),
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- titleGenerator, err = provider.NewBedrockProvider(
|
|
|
- provider.WithBedrockSystemMessage(
|
|
|
- prompt.TitlePrompt(),
|
|
|
- ),
|
|
|
- provider.WithBedrockMaxTokens(80),
|
|
|
- provider.WithBedrockModel(model),
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err)
|
|
|
- }
|
|
|
- default:
|
|
|
- return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider)
|
|
|
+ providerCfg, ok := cfg.Providers[model.Provider]
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("provider %s not supported", model.Provider)
|
|
|
+ }
|
|
|
+ if providerCfg.Disabled {
|
|
|
+ return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
|
|
|
+ }
|
|
|
+ agentProvider, err := provider.NewProvider(
|
|
|
+ model.Provider,
|
|
|
+ provider.WithAPIKey(providerCfg.APIKey),
|
|
|
+ provider.WithModel(model),
|
|
|
+ provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
|
|
|
+ provider.WithMaxTokens(agentConfig.MaxTokens),
|
|
|
+ )
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("could not create provider: %v", err)
|
|
|
}
|
|
|
|
|
|
- return agentProvider, titleGenerator, nil
|
|
|
+ return agentProvider, nil
|
|
|
}
|