| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511 |
- package agent
- import (
- "context"
- "errors"
- "fmt"
- "strings"
- "sync"
- "github.com/opencode-ai/opencode/internal/config"
- "github.com/opencode-ai/opencode/internal/llm/models"
- "github.com/opencode-ai/opencode/internal/llm/prompt"
- "github.com/opencode-ai/opencode/internal/llm/provider"
- "github.com/opencode-ai/opencode/internal/llm/tools"
- "github.com/opencode-ai/opencode/internal/logging"
- "github.com/opencode-ai/opencode/internal/message"
- "github.com/opencode-ai/opencode/internal/permission"
- "github.com/opencode-ai/opencode/internal/session"
- )
// ErrRequestCancelled is returned when a request's context is cancelled while
// the agent is streaming a response.
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned by Run when the session already has a request in
// flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
- type AgentEvent struct {
- message message.Message
- err error
- }
- func (e *AgentEvent) Err() error {
- return e.err
- }
- func (e *AgentEvent) Response() message.Message {
- return e.message
- }
- type Service interface {
- Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
- Cancel(sessionID string)
- IsSessionBusy(sessionID string) bool
- IsBusy() bool
- Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
- }
- type agent struct {
- sessions session.Service
- messages message.Service
- tools []tools.BaseTool
- provider provider.Provider
- titleProvider provider.Provider
- activeRequests sync.Map
- }
- func NewAgent(
- agentName config.AgentName,
- sessions session.Service,
- messages message.Service,
- agentTools []tools.BaseTool,
- ) (Service, error) {
- agentProvider, err := createAgentProvider(agentName)
- if err != nil {
- return nil, err
- }
- var titleProvider provider.Provider
- // Only generate titles for the coder agent
- if agentName == config.AgentCoder {
- titleProvider, err = createAgentProvider(config.AgentTitle)
- if err != nil {
- return nil, err
- }
- }
- agent := &agent{
- provider: agentProvider,
- messages: messages,
- sessions: sessions,
- tools: agentTools,
- titleProvider: titleProvider,
- activeRequests: sync.Map{},
- }
- return agent, nil
- }
- func (a *agent) Cancel(sessionID string) {
- if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
- if cancel, ok := cancelFunc.(context.CancelFunc); ok {
- logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
- cancel()
- }
- }
- }
- func (a *agent) IsBusy() bool {
- busy := false
- a.activeRequests.Range(func(key, value interface{}) bool {
- if cancelFunc, ok := value.(context.CancelFunc); ok {
- if cancelFunc != nil {
- busy = true
- return false // Stop iterating
- }
- }
- return true // Continue iterating
- })
- return busy
- }
- func (a *agent) IsSessionBusy(sessionID string) bool {
- _, busy := a.activeRequests.Load(sessionID)
- return busy
- }
- func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
- if a.titleProvider == nil {
- return nil
- }
- session, err := a.sessions.Get(ctx, sessionID)
- if err != nil {
- return err
- }
- response, err := a.titleProvider.SendMessages(
- ctx,
- []message.Message{
- {
- Role: message.User,
- Parts: []message.ContentPart{
- message.TextContent{
- Text: content,
- },
- },
- },
- },
- make([]tools.BaseTool, 0),
- )
- if err != nil {
- return err
- }
- title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
- if title == "" {
- return nil
- }
- session.Title = title
- _, err = a.sessions.Save(ctx, session)
- return err
- }
- func (a *agent) err(err error) AgentEvent {
- return AgentEvent{
- err: err,
- }
- }
- func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
- events := make(chan AgentEvent)
- if a.IsSessionBusy(sessionID) {
- return nil, ErrSessionBusy
- }
- genCtx, cancel := context.WithCancel(ctx)
- a.activeRequests.Store(sessionID, cancel)
- go func() {
- logging.Debug("Request started", "sessionID", sessionID)
- defer logging.RecoverPanic("agent.Run", func() {
- events <- a.err(fmt.Errorf("panic while running the agent"))
- })
- result := a.processGeneration(genCtx, sessionID, content)
- if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
- logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
- }
- logging.Debug("Request completed", "sessionID", sessionID)
- a.activeRequests.Delete(sessionID)
- cancel()
- events <- result
- close(events)
- }()
- return events, nil
- }
- func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
- // List existing messages; if none, start title generation asynchronously.
- msgs, err := a.messages.List(ctx, sessionID)
- if err != nil {
- return a.err(fmt.Errorf("failed to list messages: %w", err))
- }
- if len(msgs) == 0 {
- go func() {
- defer logging.RecoverPanic("agent.Run", func() {
- logging.ErrorPersist("panic while generating title")
- })
- titleErr := a.generateTitle(context.Background(), sessionID, content)
- if titleErr != nil {
- logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
- }
- }()
- }
- userMsg, err := a.createUserMessage(ctx, sessionID, content)
- if err != nil {
- return a.err(fmt.Errorf("failed to create user message: %w", err))
- }
- // Append the new user message to the conversation history.
- msgHistory := append(msgs, userMsg)
- for {
- // Check for cancellation before each iteration
- select {
- case <-ctx.Done():
- return a.err(ctx.Err())
- default:
- // Continue processing
- }
- agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
- if err != nil {
- if errors.Is(err, context.Canceled) {
- agentMessage.AddFinish(message.FinishReasonCanceled)
- a.messages.Update(context.Background(), agentMessage)
- return a.err(ErrRequestCancelled)
- }
- return a.err(fmt.Errorf("failed to process events: %w", err))
- }
- logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
- if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
- // We are not done, we need to respond with the tool response
- msgHistory = append(msgHistory, agentMessage, *toolResults)
- continue
- }
- return AgentEvent{
- message: agentMessage,
- }
- }
- }
- func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
- return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
- Role: message.User,
- Parts: []message.ContentPart{
- message.TextContent{Text: content},
- },
- })
- }
- func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
- eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
- assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
- Role: message.Assistant,
- Parts: []message.ContentPart{},
- Model: a.provider.Model().ID,
- })
- if err != nil {
- return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
- }
- // Add the session and message ID into the context if needed by tools.
- ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
- ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
- // Process each event in the stream.
- for event := range eventChan {
- if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
- a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
- return assistantMsg, nil, processErr
- }
- if ctx.Err() != nil {
- a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
- return assistantMsg, nil, ctx.Err()
- }
- }
- toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
- toolCalls := assistantMsg.ToolCalls()
- for i, toolCall := range toolCalls {
- select {
- case <-ctx.Done():
- a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
- // Make all future tool calls cancelled
- for j := i; j < len(toolCalls); j++ {
- toolResults[j] = message.ToolResult{
- ToolCallID: toolCalls[j].ID,
- Content: "Tool execution canceled by user",
- IsError: true,
- }
- }
- goto out
- default:
- // Continue processing
- var tool tools.BaseTool
- for _, availableTools := range a.tools {
- if availableTools.Info().Name == toolCall.Name {
- tool = availableTools
- }
- }
- // Tool not found
- if tool == nil {
- toolResults[i] = message.ToolResult{
- ToolCallID: toolCall.ID,
- Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
- IsError: true,
- }
- continue
- }
- toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
- ID: toolCall.ID,
- Name: toolCall.Name,
- Input: toolCall.Input,
- })
- if toolErr != nil {
- if errors.Is(toolErr, permission.ErrorPermissionDenied) {
- toolResults[i] = message.ToolResult{
- ToolCallID: toolCall.ID,
- Content: "Permission denied",
- IsError: true,
- }
- for j := i + 1; j < len(toolCalls); j++ {
- toolResults[j] = message.ToolResult{
- ToolCallID: toolCalls[j].ID,
- Content: "Tool execution canceled by user",
- IsError: true,
- }
- }
- a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
- break
- }
- }
- toolResults[i] = message.ToolResult{
- ToolCallID: toolCall.ID,
- Content: toolResult.Content,
- Metadata: toolResult.Metadata,
- IsError: toolResult.IsError,
- }
- }
- }
- out:
- if len(toolResults) == 0 {
- return assistantMsg, nil, nil
- }
- parts := make([]message.ContentPart, 0)
- for _, tr := range toolResults {
- parts = append(parts, tr)
- }
- msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
- Role: message.Tool,
- Parts: parts,
- })
- if err != nil {
- return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
- }
- return assistantMsg, &msg, err
- }
- func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
- msg.AddFinish(finishReson)
- _ = a.messages.Update(ctx, *msg)
- }
- func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
- select {
- case <-ctx.Done():
- return ctx.Err()
- default:
- // Continue processing.
- }
- switch event.Type {
- case provider.EventThinkingDelta:
- assistantMsg.AppendReasoningContent(event.Content)
- return a.messages.Update(ctx, *assistantMsg)
- case provider.EventContentDelta:
- assistantMsg.AppendContent(event.Content)
- return a.messages.Update(ctx, *assistantMsg)
- case provider.EventToolUseStart:
- assistantMsg.AddToolCall(*event.ToolCall)
- return a.messages.Update(ctx, *assistantMsg)
- // TODO: see how to handle this
- // case provider.EventToolUseDelta:
- // tm := time.Unix(assistantMsg.UpdatedAt, 0)
- // assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
- // if time.Since(tm) > 1000*time.Millisecond {
- // err := a.messages.Update(ctx, *assistantMsg)
- // assistantMsg.UpdatedAt = time.Now().Unix()
- // return err
- // }
- case provider.EventToolUseStop:
- assistantMsg.FinishToolCall(event.ToolCall.ID)
- return a.messages.Update(ctx, *assistantMsg)
- case provider.EventError:
- if errors.Is(event.Error, context.Canceled) {
- logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
- return context.Canceled
- }
- logging.ErrorPersist(event.Error.Error())
- return event.Error
- case provider.EventComplete:
- assistantMsg.SetToolCalls(event.Response.ToolCalls)
- assistantMsg.AddFinish(event.Response.FinishReason)
- if err := a.messages.Update(ctx, *assistantMsg); err != nil {
- return fmt.Errorf("failed to update message: %w", err)
- }
- return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
- }
- return nil
- }
- func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
- sess, err := a.sessions.Get(ctx, sessionID)
- if err != nil {
- return fmt.Errorf("failed to get session: %w", err)
- }
- cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
- model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
- model.CostPer1MIn/1e6*float64(usage.InputTokens) +
- model.CostPer1MOut/1e6*float64(usage.OutputTokens)
- sess.Cost += cost
- sess.CompletionTokens += usage.OutputTokens
- sess.PromptTokens += usage.InputTokens
- _, err = a.sessions.Save(ctx, sess)
- if err != nil {
- return fmt.Errorf("failed to save session: %w", err)
- }
- return nil
- }
- func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
- if a.IsBusy() {
- return models.Model{}, fmt.Errorf("cannot change model while processing requests")
- }
- if err := config.UpdateAgentModel(agentName, modelID); err != nil {
- return models.Model{}, fmt.Errorf("failed to update config: %w", err)
- }
- provider, err := createAgentProvider(agentName)
- if err != nil {
- return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
- }
- a.provider = provider
- return a.provider.Model(), nil
- }
- func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
- cfg := config.Get()
- agentConfig, ok := cfg.Agents[agentName]
- if !ok {
- return nil, fmt.Errorf("agent %s not found", agentName)
- }
- model, ok := models.SupportedModels[agentConfig.Model]
- if !ok {
- return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
- }
- providerCfg, ok := cfg.Providers[model.Provider]
- if !ok {
- return nil, fmt.Errorf("provider %s not supported", model.Provider)
- }
- if providerCfg.Disabled {
- return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
- }
- maxTokens := model.DefaultMaxTokens
- if agentConfig.MaxTokens > 0 {
- maxTokens = agentConfig.MaxTokens
- }
- opts := []provider.ProviderClientOption{
- provider.WithAPIKey(providerCfg.APIKey),
- provider.WithModel(model),
- provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
- provider.WithMaxTokens(maxTokens),
- }
- if model.Provider == models.ProviderOpenAI && model.CanReason {
- opts = append(
- opts,
- provider.WithOpenAIOptions(
- provider.WithReasoningEffort(agentConfig.ReasoningEffort),
- ),
- )
- } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
- opts = append(
- opts,
- provider.WithAnthropicOptions(
- provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
- ),
- )
- }
- agentProvider, err := provider.NewProvider(
- model.Provider,
- opts...,
- )
- if err != nil {
- return nil, fmt.Errorf("could not create provider: %v", err)
- }
- return agentProvider, nil
- }
|