|
|
@@ -7,7 +7,6 @@ import (
|
|
|
"strings"
|
|
|
"sync"
|
|
|
|
|
|
- "github.com/kujtimiihoxha/termai/internal/app"
|
|
|
"github.com/kujtimiihoxha/termai/internal/config"
|
|
|
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
|
|
"github.com/kujtimiihoxha/termai/internal/llm/prompt"
|
|
|
@@ -15,22 +14,118 @@ import (
|
|
|
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
|
|
"github.com/kujtimiihoxha/termai/internal/logging"
|
|
|
"github.com/kujtimiihoxha/termai/internal/message"
|
|
|
+ "github.com/kujtimiihoxha/termai/internal/session"
|
|
|
)
|
|
|
|
|
|
-type Agent interface {
|
|
|
+// Common errors
|
|
|
+var (
|
|
|
+ ErrProviderNotEnabled = errors.New("provider is not enabled")
|
|
|
+ ErrRequestCancelled = errors.New("request cancelled by user")
|
|
|
+ ErrSessionBusy = errors.New("session is currently processing another request")
|
|
|
+)
|
|
|
+
|
|
|
+// Service defines the interface for generating responses
|
|
|
+type Service interface {
|
|
|
Generate(ctx context.Context, sessionID string, content string) error
|
|
|
+ Cancel(sessionID string) error
|
|
|
}
|
|
|
|
|
|
type agent struct {
|
|
|
- *app.App
|
|
|
+ sessions session.Service
|
|
|
+ messages message.Service
|
|
|
model models.Model
|
|
|
tools []tools.BaseTool
|
|
|
agent provider.Provider
|
|
|
titleGenerator provider.Provider
|
|
|
+ activeRequests sync.Map // map[sessionID]context.CancelFunc
|
|
|
+}
|
|
|
+
|
|
|
+// NewAgent creates a new agent instance with the given model and tools
|
|
|
+func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) {
|
|
|
+ agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("failed to initialize providers: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ return &agent{
|
|
|
+ model: model,
|
|
|
+ tools: tools,
|
|
|
+ sessions: sessions,
|
|
|
+ messages: messages,
|
|
|
+ agent: agentProvider,
|
|
|
+ titleGenerator: titleGenerator,
|
|
|
+ activeRequests: sync.Map{},
|
|
|
+ }, nil
|
|
|
+}
|
|
|
+
|
|
|
+// Cancel cancels an active request by session ID
|
|
|
+func (a *agent) Cancel(sessionID string) error {
|
|
|
+ if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
|
|
|
+ if cancel, ok := cancelFunc.(context.CancelFunc); ok {
|
|
|
+ logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
|
|
|
+ cancel()
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return errors.New("no active request found for this session")
|
|
|
}
|
|
|
|
|
|
-func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
|
|
|
- response, err := c.titleGenerator.SendMessages(
|
|
|
+// Generate starts the generation process
|
|
|
+func (a *agent) Generate(ctx context.Context, sessionID string, content string) error {
|
|
|
+ // Check if this session already has an active request
|
|
|
+ if _, busy := a.activeRequests.Load(sessionID); busy {
|
|
|
+ return ErrSessionBusy
|
|
|
+ }
|
|
|
+
|
|
|
+ // Create a cancellable context
|
|
|
+ genCtx, cancel := context.WithCancel(ctx)
|
|
|
+
|
|
|
+ // Store cancel function to allow user cancellation
|
|
|
+ a.activeRequests.Store(sessionID, cancel)
|
|
|
+
|
|
|
+ // Launch the generation in a goroutine
|
|
|
+ go func() {
|
|
|
+ defer func() {
|
|
|
+ if r := recover(); r != nil {
|
|
|
+ logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r))
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ defer a.activeRequests.Delete(sessionID)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ if err := a.generate(genCtx, sessionID, content); err != nil {
|
|
|
+ if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) {
|
|
|
+ // Log the error (avoid logging cancellations as they're expected)
|
|
|
+ logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err))
|
|
|
+
|
|
|
+ // You may want to create an error message in the chat
|
|
|
+ bgCtx := context.Background()
|
|
|
+ errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err)
|
|
|
+ _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{
|
|
|
+ Role: message.System,
|
|
|
+ Parts: []message.ContentPart{
|
|
|
+ message.TextContent{
|
|
|
+ Text: errorMsg,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ })
|
|
|
+ if createErr != nil {
|
|
|
+ logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// IsSessionBusy checks if a session currently has an active request
|
|
|
+func (a *agent) IsSessionBusy(sessionID string) bool {
|
|
|
+ _, busy := a.activeRequests.Load(sessionID)
|
|
|
+ return busy
|
|
|
+} // handleTitleGeneration asynchronously generates a title for new sessions
|
|
|
+func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
|
|
|
+ response, err := a.titleGenerator.SendMessages(
|
|
|
ctx,
|
|
|
[]message.Message{
|
|
|
{
|
|
|
@@ -45,25 +140,30 @@ func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content st
|
|
|
nil,
|
|
|
)
|
|
|
if err != nil {
|
|
|
+ logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err))
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- session, err := c.Sessions.Get(ctx, sessionID)
|
|
|
+ session, err := a.sessions.Get(ctx, sessionID)
|
|
|
if err != nil {
|
|
|
+ logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err))
|
|
|
return
|
|
|
}
|
|
|
+
|
|
|
if response.Content != "" {
|
|
|
- session.Title = response.Content
|
|
|
- session.Title = strings.TrimSpace(session.Title)
|
|
|
+ session.Title = strings.TrimSpace(response.Content)
|
|
|
session.Title = strings.ReplaceAll(session.Title, "\n", " ")
|
|
|
- c.Sessions.Save(ctx, session)
|
|
|
+ if _, err := a.sessions.Save(ctx, session); err != nil {
|
|
|
+ logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err))
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
|
|
|
- session, err := c.Sessions.Get(ctx, sessionID)
|
|
|
+// TrackUsage updates token usage statistics for the session
|
|
|
+func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
|
|
|
+ session, err := a.sessions.Get(ctx, sessionID)
|
|
|
if err != nil {
|
|
|
- return err
|
|
|
+ return fmt.Errorf("failed to get session: %w", err)
|
|
|
}
|
|
|
|
|
|
cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
|
|
|
@@ -75,189 +175,241 @@ func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.M
|
|
|
session.CompletionTokens += usage.OutputTokens
|
|
|
session.PromptTokens += usage.InputTokens
|
|
|
|
|
|
- _, err = c.Sessions.Save(ctx, session)
|
|
|
- return err
|
|
|
+ _, err = a.sessions.Save(ctx, session)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("failed to save session: %w", err)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
-func (c *agent) processEvent(
|
|
|
+// processEvent handles different types of events during generation
|
|
|
+func (a *agent) processEvent(
|
|
|
ctx context.Context,
|
|
|
sessionID string,
|
|
|
assistantMsg *message.Message,
|
|
|
event provider.ProviderEvent,
|
|
|
) error {
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ return ctx.Err()
|
|
|
+ default:
|
|
|
+ // Continue processing
|
|
|
+ }
|
|
|
+
|
|
|
switch event.Type {
|
|
|
case provider.EventThinkingDelta:
|
|
|
assistantMsg.AppendReasoningContent(event.Content)
|
|
|
- return c.Messages.Update(ctx, *assistantMsg)
|
|
|
+ return a.messages.Update(ctx, *assistantMsg)
|
|
|
case provider.EventContentDelta:
|
|
|
assistantMsg.AppendContent(event.Content)
|
|
|
- return c.Messages.Update(ctx, *assistantMsg)
|
|
|
+ return a.messages.Update(ctx, *assistantMsg)
|
|
|
case provider.EventError:
|
|
|
if errors.Is(event.Error, context.Canceled) {
|
|
|
- return nil
|
|
|
+ logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
|
|
|
+ return context.Canceled
|
|
|
}
|
|
|
logging.ErrorPersist(event.Error.Error())
|
|
|
return event.Error
|
|
|
case provider.EventWarning:
|
|
|
logging.WarnPersist(event.Info)
|
|
|
- return nil
|
|
|
case provider.EventInfo:
|
|
|
logging.InfoPersist(event.Info)
|
|
|
case provider.EventComplete:
|
|
|
assistantMsg.SetToolCalls(event.Response.ToolCalls)
|
|
|
assistantMsg.AddFinish(event.Response.FinishReason)
|
|
|
- err := c.Messages.Update(ctx, *assistantMsg)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
+ if err := a.messages.Update(ctx, *assistantMsg); err != nil {
|
|
|
+ return fmt.Errorf("failed to update message: %w", err)
|
|
|
}
|
|
|
- return c.TrackUsage(ctx, sessionID, c.model, event.Response.Usage)
|
|
|
+ return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage)
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
|
|
|
- var wg sync.WaitGroup
|
|
|
+// ExecuteTools runs all tool calls sequentially and returns the results
|
|
|
+func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
|
|
|
toolResults := make([]message.ToolResult, len(toolCalls))
|
|
|
- mutex := &sync.Mutex{}
|
|
|
- errChan := make(chan error, 1)
|
|
|
|
|
|
// Create a child context that can be canceled
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
defer cancel()
|
|
|
|
|
|
- for i, tc := range toolCalls {
|
|
|
- wg.Add(1)
|
|
|
- go func(index int, toolCall message.ToolCall) {
|
|
|
- defer wg.Done()
|
|
|
+ // Check if already canceled before starting any execution
|
|
|
+ if ctx.Err() != nil {
|
|
|
+ // Mark all tools as canceled
|
|
|
+ for i, toolCall := range toolCalls {
|
|
|
+ toolResults[i] = message.ToolResult{
|
|
|
+ ToolCallID: toolCall.ID,
|
|
|
+ Content: "Tool execution canceled by user",
|
|
|
+ IsError: true,
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return toolResults, ctx.Err()
|
|
|
+ }
|
|
|
|
|
|
- // Check if context is already canceled
|
|
|
- select {
|
|
|
- case <-ctx.Done():
|
|
|
- mutex.Lock()
|
|
|
- toolResults[index] = message.ToolResult{
|
|
|
- ToolCallID: toolCall.ID,
|
|
|
- Content: "Tool execution canceled",
|
|
|
+ for i, toolCall := range toolCalls {
|
|
|
+ // Check for cancellation before executing each tool
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ // Mark this and all remaining tools as canceled
|
|
|
+ for j := i; j < len(toolCalls); j++ {
|
|
|
+ toolResults[j] = message.ToolResult{
|
|
|
+ ToolCallID: toolCalls[j].ID,
|
|
|
+ Content: "Tool execution canceled by user",
|
|
|
IsError: true,
|
|
|
}
|
|
|
- mutex.Unlock()
|
|
|
-
|
|
|
- // Send cancellation error to error channel if it's empty
|
|
|
- select {
|
|
|
- case errChan <- ctx.Err():
|
|
|
- default:
|
|
|
- }
|
|
|
- return
|
|
|
- default:
|
|
|
}
|
|
|
+ return toolResults, ctx.Err()
|
|
|
+ default:
|
|
|
+ // Continue processing
|
|
|
+ }
|
|
|
|
|
|
- response := ""
|
|
|
- isError := false
|
|
|
- found := false
|
|
|
-
|
|
|
- for _, tool := range tls {
|
|
|
- if tool.Info().Name == toolCall.Name {
|
|
|
- found = true
|
|
|
- toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
|
|
|
- ID: toolCall.ID,
|
|
|
- Name: toolCall.Name,
|
|
|
- Input: toolCall.Input,
|
|
|
- })
|
|
|
-
|
|
|
- if toolErr != nil {
|
|
|
- if errors.Is(toolErr, context.Canceled) {
|
|
|
- response = "Tool execution canceled"
|
|
|
-
|
|
|
- // Send cancellation error to error channel if it's empty
|
|
|
- select {
|
|
|
- case errChan <- ctx.Err():
|
|
|
- default:
|
|
|
- }
|
|
|
- } else {
|
|
|
- response = fmt.Sprintf("error running tool: %s", toolErr)
|
|
|
- }
|
|
|
- isError = true
|
|
|
+ response := ""
|
|
|
+ isError := false
|
|
|
+ found := false
|
|
|
+
|
|
|
+ // Find and execute the appropriate tool
|
|
|
+ for _, tool := range tls {
|
|
|
+ if tool.Info().Name == toolCall.Name {
|
|
|
+ found = true
|
|
|
+ toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
|
|
|
+ ID: toolCall.ID,
|
|
|
+ Name: toolCall.Name,
|
|
|
+ Input: toolCall.Input,
|
|
|
+ })
|
|
|
+
|
|
|
+ if toolErr != nil {
|
|
|
+ if errors.Is(toolErr, context.Canceled) {
|
|
|
+ response = "Tool execution canceled by user"
|
|
|
} else {
|
|
|
- response = toolResult.Content
|
|
|
- isError = toolResult.IsError
|
|
|
+ response = fmt.Sprintf("Error running tool: %s", toolErr)
|
|
|
}
|
|
|
- break
|
|
|
+ isError = true
|
|
|
+ } else {
|
|
|
+ response = toolResult.Content
|
|
|
+ isError = toolResult.IsError
|
|
|
}
|
|
|
+ break
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- if !found {
|
|
|
- response = fmt.Sprintf("tool not found: %s", toolCall.Name)
|
|
|
- isError = true
|
|
|
- }
|
|
|
-
|
|
|
- mutex.Lock()
|
|
|
- defer mutex.Unlock()
|
|
|
-
|
|
|
- toolResults[index] = message.ToolResult{
|
|
|
- ToolCallID: toolCall.ID,
|
|
|
- Content: response,
|
|
|
- IsError: isError,
|
|
|
- }
|
|
|
- }(i, tc)
|
|
|
- }
|
|
|
-
|
|
|
- // Wait for all goroutines to finish or context to be canceled
|
|
|
- done := make(chan struct{})
|
|
|
- go func() {
|
|
|
- wg.Wait()
|
|
|
- close(done)
|
|
|
- }()
|
|
|
+ if !found {
|
|
|
+ response = fmt.Sprintf("Tool not found: %s", toolCall.Name)
|
|
|
+ isError = true
|
|
|
+ }
|
|
|
|
|
|
- select {
|
|
|
- case <-done:
|
|
|
- // All tools completed successfully
|
|
|
- case err := <-errChan:
|
|
|
- // One of the tools encountered a cancellation
|
|
|
- return toolResults, err
|
|
|
- case <-ctx.Done():
|
|
|
- // Context was canceled externally
|
|
|
- return toolResults, ctx.Err()
|
|
|
+ toolResults[i] = message.ToolResult{
|
|
|
+ ToolCallID: toolCall.ID,
|
|
|
+ Content: response,
|
|
|
+ IsError: isError,
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
return toolResults, nil
|
|
|
}
|
|
|
|
|
|
-func (c *agent) handleToolExecution(
|
|
|
+// handleToolExecution processes tool calls and creates tool result messages
|
|
|
+func (a *agent) handleToolExecution(
|
|
|
ctx context.Context,
|
|
|
assistantMsg message.Message,
|
|
|
) (*message.Message, error) {
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ // If cancelled, create tool results that indicate cancellation
|
|
|
+ if len(assistantMsg.ToolCalls()) > 0 {
|
|
|
+ toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls()))
|
|
|
+ for _, tc := range assistantMsg.ToolCalls() {
|
|
|
+ toolResults = append(toolResults, message.ToolResult{
|
|
|
+ ToolCallID: tc.ID,
|
|
|
+ Content: "Tool execution canceled by user",
|
|
|
+ IsError: true,
|
|
|
+ })
|
|
|
+ }
|
|
|
+
|
|
|
+ // Use background context to ensure the message is created even if original context is cancelled
|
|
|
+ bgCtx := context.Background()
|
|
|
+ parts := make([]message.ContentPart, 0)
|
|
|
+ for _, toolResult := range toolResults {
|
|
|
+ parts = append(parts, toolResult)
|
|
|
+ }
|
|
|
+ msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
|
|
|
+ Role: message.Tool,
|
|
|
+ Parts: parts,
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
|
|
|
+ }
|
|
|
+ return &msg, ctx.Err()
|
|
|
+ }
|
|
|
+ return nil, ctx.Err()
|
|
|
+ default:
|
|
|
+ // Continue processing
|
|
|
+ }
|
|
|
+
|
|
|
if len(assistantMsg.ToolCalls()) == 0 {
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
- toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
|
|
|
+ toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools)
|
|
|
if err != nil {
|
|
|
+ // If error is from cancellation, still return the partial results we have
|
|
|
+ if errors.Is(err, context.Canceled) {
|
|
|
+ // Use background context to ensure the message is created even if original context is cancelled
|
|
|
+ bgCtx := context.Background()
|
|
|
+ parts := make([]message.ContentPart, 0)
|
|
|
+ for _, toolResult := range toolResults {
|
|
|
+ parts = append(parts, toolResult)
|
|
|
+ }
|
|
|
+
|
|
|
+ msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
|
|
|
+ Role: message.Tool,
|
|
|
+ Parts: parts,
|
|
|
+ })
|
|
|
+ if createErr != nil {
|
|
|
+ logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr))
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return &msg, err
|
|
|
+ }
|
|
|
return nil, err
|
|
|
}
|
|
|
- parts := make([]message.ContentPart, 0)
|
|
|
+
|
|
|
+ parts := make([]message.ContentPart, 0, len(toolResults))
|
|
|
for _, toolResult := range toolResults {
|
|
|
parts = append(parts, toolResult)
|
|
|
}
|
|
|
- msg, err := c.Messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
|
|
|
+
|
|
|
+ msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
|
|
|
Role: message.Tool,
|
|
|
Parts: parts,
|
|
|
})
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("failed to create tool message: %w", err)
|
|
|
+ }
|
|
|
|
|
|
- return &msg, err
|
|
|
+ return &msg, nil
|
|
|
}
|
|
|
|
|
|
-func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
|
|
|
+// generate handles the main generation workflow
|
|
|
+func (a *agent) generate(ctx context.Context, sessionID string, content string) error {
|
|
|
ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
|
|
|
- messages, err := c.Messages.List(ctx, sessionID)
|
|
|
+
|
|
|
+ // Handle context cancellation at any point
|
|
|
+ if err := ctx.Err(); err != nil {
|
|
|
+ return ErrRequestCancelled
|
|
|
+ }
|
|
|
+
|
|
|
+ messages, err := a.messages.List(ctx, sessionID)
|
|
|
if err != nil {
|
|
|
- return err
|
|
|
+ return fmt.Errorf("failed to list messages: %w", err)
|
|
|
}
|
|
|
|
|
|
if len(messages) == 0 {
|
|
|
- go c.handleTitleGeneration(ctx, sessionID, content)
|
|
|
+ titleCtx := context.Background()
|
|
|
+ go a.handleTitleGeneration(titleCtx, sessionID, content)
|
|
|
}
|
|
|
|
|
|
- userMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
|
|
|
+ userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
|
|
|
Role: message.User,
|
|
|
Parts: []message.ContentPart{
|
|
|
message.TextContent{
|
|
|
@@ -266,133 +418,125 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string)
|
|
|
},
|
|
|
})
|
|
|
if err != nil {
|
|
|
- return err
|
|
|
+ return fmt.Errorf("failed to create user message: %w", err)
|
|
|
}
|
|
|
|
|
|
messages = append(messages, userMsg)
|
|
|
+
|
|
|
for {
|
|
|
+ // Check for cancellation before each iteration
|
|
|
select {
|
|
|
case <-ctx.Done():
|
|
|
- assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
|
|
|
- Role: message.Assistant,
|
|
|
- Parts: []message.ContentPart{},
|
|
|
- })
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- assistantMsg.AddFinish("canceled")
|
|
|
- c.Messages.Update(ctx, assistantMsg)
|
|
|
- return context.Canceled
|
|
|
+ return ErrRequestCancelled
|
|
|
default:
|
|
|
// Continue processing
|
|
|
}
|
|
|
|
|
|
- eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
|
|
|
+ eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools)
|
|
|
if err != nil {
|
|
|
if errors.Is(err, context.Canceled) {
|
|
|
- assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
|
|
|
- Role: message.Assistant,
|
|
|
- Parts: []message.ContentPart{},
|
|
|
- })
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- assistantMsg.AddFinish("canceled")
|
|
|
- c.Messages.Update(ctx, assistantMsg)
|
|
|
- return context.Canceled
|
|
|
+ return ErrRequestCancelled
|
|
|
}
|
|
|
- return err
|
|
|
+ return fmt.Errorf("failed to stream response: %w", err)
|
|
|
}
|
|
|
|
|
|
- assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
|
|
|
+ assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
|
|
|
Role: message.Assistant,
|
|
|
Parts: []message.ContentPart{},
|
|
|
- Model: c.model.ID,
|
|
|
+ Model: a.model.ID,
|
|
|
})
|
|
|
if err != nil {
|
|
|
- return err
|
|
|
+ return fmt.Errorf("failed to create assistant message: %w", err)
|
|
|
}
|
|
|
|
|
|
ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
|
|
|
+
|
|
|
+ // Process events from the LLM provider
|
|
|
for event := range eventChan {
|
|
|
- err = c.processEvent(ctx, sessionID, &assistantMsg, event)
|
|
|
- if err != nil {
|
|
|
+ if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil {
|
|
|
if errors.Is(err, context.Canceled) {
|
|
|
+ // Mark as canceled but don't create separate message
|
|
|
assistantMsg.AddFinish("canceled")
|
|
|
- c.Messages.Update(ctx, assistantMsg)
|
|
|
- return context.Canceled
|
|
|
+ _ = a.messages.Update(context.Background(), assistantMsg)
|
|
|
+ return ErrRequestCancelled
|
|
|
}
|
|
|
assistantMsg.AddFinish("error:" + err.Error())
|
|
|
- c.Messages.Update(ctx, assistantMsg)
|
|
|
- return err
|
|
|
+ _ = a.messages.Update(ctx, assistantMsg)
|
|
|
+ return fmt.Errorf("event processing error: %w", err)
|
|
|
}
|
|
|
|
|
|
+ // Check for cancellation during event processing
|
|
|
select {
|
|
|
case <-ctx.Done():
|
|
|
+ // Mark as canceled
|
|
|
assistantMsg.AddFinish("canceled")
|
|
|
- c.Messages.Update(ctx, assistantMsg)
|
|
|
- return context.Canceled
|
|
|
+ _ = a.messages.Update(context.Background(), assistantMsg)
|
|
|
+ return ErrRequestCancelled
|
|
|
default:
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // Check for context cancellation before tool execution
|
|
|
+ // Check for cancellation before tool execution
|
|
|
select {
|
|
|
case <-ctx.Done():
|
|
|
- assistantMsg.AddFinish("canceled")
|
|
|
- c.Messages.Update(ctx, assistantMsg)
|
|
|
- return context.Canceled
|
|
|
+ assistantMsg.AddFinish("canceled_by_user")
|
|
|
+ _ = a.messages.Update(context.Background(), assistantMsg)
|
|
|
+ return ErrRequestCancelled
|
|
|
default:
|
|
|
- // Continue processing
|
|
|
}
|
|
|
|
|
|
- msg, err := c.handleToolExecution(ctx, assistantMsg)
|
|
|
+ // Execute any tool calls
|
|
|
+ toolMsg, err := a.handleToolExecution(ctx, assistantMsg)
|
|
|
if err != nil {
|
|
|
if errors.Is(err, context.Canceled) {
|
|
|
- assistantMsg.AddFinish("canceled")
|
|
|
- c.Messages.Update(ctx, assistantMsg)
|
|
|
- return context.Canceled
|
|
|
+ assistantMsg.AddFinish("canceled_by_user")
|
|
|
+ _ = a.messages.Update(context.Background(), assistantMsg)
|
|
|
+ return ErrRequestCancelled
|
|
|
}
|
|
|
- return err
|
|
|
+ return fmt.Errorf("tool execution error: %w", err)
|
|
|
}
|
|
|
|
|
|
- c.Messages.Update(ctx, assistantMsg)
|
|
|
+ if err := a.messages.Update(ctx, assistantMsg); err != nil {
|
|
|
+ return fmt.Errorf("failed to update assistant message: %w", err)
|
|
|
+ }
|
|
|
|
|
|
+ // If no tool calls, we're done
|
|
|
if len(assistantMsg.ToolCalls()) == 0 {
|
|
|
break
|
|
|
}
|
|
|
|
|
|
+ // Add messages for next iteration
|
|
|
messages = append(messages, assistantMsg)
|
|
|
- if msg != nil {
|
|
|
- messages = append(messages, *msg)
|
|
|
+ if toolMsg != nil {
|
|
|
+ messages = append(messages, *toolMsg)
|
|
|
}
|
|
|
|
|
|
- // Check for context cancellation after tool execution
|
|
|
+ // Check for cancellation after tool execution
|
|
|
select {
|
|
|
case <-ctx.Done():
|
|
|
- assistantMsg.AddFinish("canceled")
|
|
|
- c.Messages.Update(ctx, assistantMsg)
|
|
|
- return context.Canceled
|
|
|
+ return ErrRequestCancelled
|
|
|
default:
|
|
|
- // Continue processing
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+// getAgentProviders initializes the LLM providers based on the chosen model
|
|
|
func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
|
|
|
maxTokens := config.Get().Model.CoderMaxTokens
|
|
|
|
|
|
providerConfig, ok := config.Get().Providers[model.Provider]
|
|
|
if !ok || providerConfig.Disabled {
|
|
|
- return nil, nil, errors.New("provider is not enabled")
|
|
|
+ return nil, nil, ErrProviderNotEnabled
|
|
|
}
|
|
|
+
|
|
|
var agentProvider provider.Provider
|
|
|
var titleGenerator provider.Provider
|
|
|
+ var err error
|
|
|
|
|
|
switch model.Provider {
|
|
|
case models.ProviderOpenAI:
|
|
|
- var err error
|
|
|
agentProvider, err = provider.NewOpenAIProvider(
|
|
|
provider.WithOpenAISystemMessage(
|
|
|
prompt.CoderOpenAISystemPrompt(),
|
|
|
@@ -402,8 +546,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
|
|
|
provider.WithOpenAIKey(providerConfig.APIKey),
|
|
|
)
|
|
|
if err != nil {
|
|
|
- return nil, nil, err
|
|
|
+ return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err)
|
|
|
}
|
|
|
+
|
|
|
titleGenerator, err = provider.NewOpenAIProvider(
|
|
|
provider.WithOpenAISystemMessage(
|
|
|
prompt.TitlePrompt(),
|
|
|
@@ -413,10 +558,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
|
|
|
provider.WithOpenAIKey(providerConfig.APIKey),
|
|
|
)
|
|
|
if err != nil {
|
|
|
- return nil, nil, err
|
|
|
+ return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err)
|
|
|
}
|
|
|
+
|
|
|
case models.ProviderAnthropic:
|
|
|
- var err error
|
|
|
agentProvider, err = provider.NewAnthropicProvider(
|
|
|
provider.WithAnthropicSystemMessage(
|
|
|
prompt.CoderAnthropicSystemPrompt(),
|
|
|
@@ -426,8 +571,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
|
|
|
provider.WithAnthropicModel(model),
|
|
|
)
|
|
|
if err != nil {
|
|
|
- return nil, nil, err
|
|
|
+ return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err)
|
|
|
}
|
|
|
+
|
|
|
titleGenerator, err = provider.NewAnthropicProvider(
|
|
|
provider.WithAnthropicSystemMessage(
|
|
|
prompt.TitlePrompt(),
|
|
|
@@ -437,11 +583,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
|
|
|
provider.WithAnthropicModel(model),
|
|
|
)
|
|
|
if err != nil {
|
|
|
- return nil, nil, err
|
|
|
+ return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err)
|
|
|
}
|
|
|
|
|
|
case models.ProviderGemini:
|
|
|
- var err error
|
|
|
agentProvider, err = provider.NewGeminiProvider(
|
|
|
ctx,
|
|
|
provider.WithGeminiSystemMessage(
|
|
|
@@ -452,8 +597,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
|
|
|
provider.WithGeminiModel(model),
|
|
|
)
|
|
|
if err != nil {
|
|
|
- return nil, nil, err
|
|
|
+ return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err)
|
|
|
}
|
|
|
+
|
|
|
titleGenerator, err = provider.NewGeminiProvider(
|
|
|
ctx,
|
|
|
provider.WithGeminiSystemMessage(
|
|
|
@@ -464,10 +610,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
|
|
|
provider.WithGeminiModel(model),
|
|
|
)
|
|
|
if err != nil {
|
|
|
- return nil, nil, err
|
|
|
+ return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err)
|
|
|
}
|
|
|
+
|
|
|
case models.ProviderGROQ:
|
|
|
- var err error
|
|
|
agentProvider, err = provider.NewOpenAIProvider(
|
|
|
provider.WithOpenAISystemMessage(
|
|
|
prompt.CoderAnthropicSystemPrompt(),
|
|
|
@@ -478,8 +624,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
|
|
|
provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
|
|
|
)
|
|
|
if err != nil {
|
|
|
- return nil, nil, err
|
|
|
+ return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err)
|
|
|
}
|
|
|
+
|
|
|
titleGenerator, err = provider.NewOpenAIProvider(
|
|
|
provider.WithOpenAISystemMessage(
|
|
|
prompt.TitlePrompt(),
|
|
|
@@ -490,11 +637,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
|
|
|
provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
|
|
|
)
|
|
|
if err != nil {
|
|
|
- return nil, nil, err
|
|
|
+ return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err)
|
|
|
}
|
|
|
|
|
|
case models.ProviderBedrock:
|
|
|
- var err error
|
|
|
agentProvider, err = provider.NewBedrockProvider(
|
|
|
provider.WithBedrockSystemMessage(
|
|
|
prompt.CoderAnthropicSystemPrompt(),
|
|
|
@@ -503,19 +649,21 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
|
|
|
provider.WithBedrockModel(model),
|
|
|
)
|
|
|
if err != nil {
|
|
|
- return nil, nil, err
|
|
|
+ return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err)
|
|
|
}
|
|
|
+
|
|
|
titleGenerator, err = provider.NewBedrockProvider(
|
|
|
provider.WithBedrockSystemMessage(
|
|
|
prompt.TitlePrompt(),
|
|
|
),
|
|
|
- provider.WithBedrockMaxTokens(maxTokens),
|
|
|
+ provider.WithBedrockMaxTokens(80),
|
|
|
provider.WithBedrockModel(model),
|
|
|
)
|
|
|
if err != nil {
|
|
|
- return nil, nil, err
|
|
|
+ return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err)
|
|
|
}
|
|
|
-
|
|
|
+ default:
|
|
|
+ return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider)
|
|
|
}
|
|
|
|
|
|
return agentProvider, titleGenerator, nil
|