| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199 |
- // Package agent is the core orchestration layer for Crush AI agents.
- //
- // It provides session-based AI agent functionality for managing
- // conversations, tool execution, and message handling. It coordinates
- // interactions between language models, messages, sessions, and tools while
- // handling features like automatic summarization, queuing, and token
- // management.
- package agent
- import (
- "cmp"
- "context"
- _ "embed"
- "encoding/base64"
- "errors"
- "fmt"
- "log/slog"
- "os"
- "regexp"
- "strconv"
- "strings"
- "sync"
- "time"
- "charm.land/catwalk/pkg/catwalk"
- "charm.land/fantasy"
- "charm.land/fantasy/providers/anthropic"
- "charm.land/fantasy/providers/bedrock"
- "charm.land/fantasy/providers/google"
- "charm.land/fantasy/providers/openai"
- "charm.land/fantasy/providers/openrouter"
- "charm.land/fantasy/providers/vercel"
- "charm.land/lipgloss/v2"
- "github.com/charmbracelet/crush/internal/agent/hyper"
- "github.com/charmbracelet/crush/internal/agent/notify"
- "github.com/charmbracelet/crush/internal/agent/tools"
- "github.com/charmbracelet/crush/internal/agent/tools/mcp"
- "github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/csync"
- "github.com/charmbracelet/crush/internal/message"
- "github.com/charmbracelet/crush/internal/permission"
- "github.com/charmbracelet/crush/internal/pubsub"
- "github.com/charmbracelet/crush/internal/session"
- "github.com/charmbracelet/crush/internal/stringext"
- "github.com/charmbracelet/crush/internal/version"
- "github.com/charmbracelet/x/exp/charmtone"
- )
- const (
// DefaultSessionName is the title given to a session when no title could be
// generated for it.
- DefaultSessionName = "Untitled Session"
- // Constants for auto-summarization thresholds
// Models whose context window exceeds largeContextWindowThreshold tokens
// trigger summarization once largeContextWindowBuffer or fewer tokens remain;
// smaller models trigger it once the remaining share of the window drops to
// smallContextWindowRatio.
- largeContextWindowThreshold = 200_000
- largeContextWindowBuffer = 20_000
- smallContextWindowRatio = 0.2
- )
// userAgent is the User-Agent string handed to every fantasy agent created in
// this file; it embeds the running Crush version.
- var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)
// titlePrompt is the embedded system prompt used when generating session
// titles.
- //go:embed templates/title.md
- var titlePrompt []byte
// summaryPrompt is the embedded system prompt used when summarizing a session.
- //go:embed templates/summary.md
- var summaryPrompt []byte
- // Used to remove <think> tags from generated titles.
// The pattern has no (?s) flag, so "." does not cross newlines; callers are
// expected to flatten newlines before applying it.
- var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
- type SessionAgentCall struct {
- SessionID string
- Prompt string
- ProviderOptions fantasy.ProviderOptions
- Attachments []message.Attachment
- MaxOutputTokens int64
- Temperature *float64
- TopP *float64
- TopK *int64
- FrequencyPenalty *float64
- PresencePenalty *float64
- NonInteractive bool
- }
- type SessionAgent interface {
- Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
- SetModels(large Model, small Model)
- SetTools(tools []fantasy.AgentTool)
- SetSystemPrompt(systemPrompt string)
- Cancel(sessionID string)
- CancelAll()
- IsSessionBusy(sessionID string) bool
- IsBusy() bool
- QueuedPrompts(sessionID string) int
- QueuedPromptsList(sessionID string) []string
- ClearQueue(sessionID string)
- Summarize(context.Context, string, fantasy.ProviderOptions) error
- Model() Model
- }
- type Model struct {
- Model fantasy.LanguageModel
- CatwalkCfg catwalk.Model
- ModelCfg config.SelectedModel
- }
- type sessionAgent struct {
- largeModel *csync.Value[Model]
- smallModel *csync.Value[Model]
- systemPromptPrefix *csync.Value[string]
- systemPrompt *csync.Value[string]
- tools *csync.Slice[fantasy.AgentTool]
- isSubAgent bool
- sessions session.Service
- messages message.Service
- disableAutoSummarize bool
- isYolo bool
- notify pubsub.Publisher[notify.Notification]
- messageQueue *csync.Map[string, []SessionAgentCall]
- activeRequests *csync.Map[string, context.CancelFunc]
- }
- type SessionAgentOptions struct {
- LargeModel Model
- SmallModel Model
- SystemPromptPrefix string
- SystemPrompt string
- IsSubAgent bool
- DisableAutoSummarize bool
- IsYolo bool
- Sessions session.Service
- Messages message.Service
- Tools []fantasy.AgentTool
- Notify pubsub.Publisher[notify.Notification]
- }
- func NewSessionAgent(
- opts SessionAgentOptions,
- ) SessionAgent {
- return &sessionAgent{
- largeModel: csync.NewValue(opts.LargeModel),
- smallModel: csync.NewValue(opts.SmallModel),
- systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
- systemPrompt: csync.NewValue(opts.SystemPrompt),
- isSubAgent: opts.IsSubAgent,
- sessions: opts.Sessions,
- messages: opts.Messages,
- disableAutoSummarize: opts.DisableAutoSummarize,
- tools: csync.NewSliceFrom(opts.Tools),
- isYolo: opts.IsYolo,
- notify: opts.Notify,
- messageQueue: csync.NewMap[string, []SessionAgentCall](),
- activeRequests: csync.NewMap[string, context.CancelFunc](),
- }
- }
// Run handles one prompt for a session. It validates the call, queues it when
// the session already has an active request, and otherwise stores the user
// message and streams a response from the large model, persisting assistant
// and tool messages as the stream callbacks fire.
//
// It returns (nil, nil) when the call was queued. When the stream fails, the
// assistant message is finalized with a finish reason describing the failure
// and the original error is returned. After a successful stream Run may
// summarize the session and then recursively runs the next queued call for the
// session, returning that call's result instead of its own.
- func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
// A call needs either prompt text or a text attachment, and a session ID.
- if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
- return nil, ErrEmptyPrompt
- }
- if call.SessionID == "" {
- return nil, ErrSessionMissing
- }
- // Queue the message if busy
- if a.IsSessionBusy(call.SessionID) {
- existing, ok := a.messageQueue.Get(call.SessionID)
- if !ok {
- existing = []SessionAgentCall{}
- }
- existing = append(existing, call)
- a.messageQueue.Set(call.SessionID, existing)
- return nil, nil
- }
- // Copy mutable fields under lock to avoid races with SetTools/SetModels.
- agentTools := a.tools.Copy()
- largeModel := a.largeModel.Get()
- systemPrompt := a.systemPrompt.Get()
- promptPrefix := a.systemPromptPrefix.Get()
// Collect the instructions reported by every connected MCP server and append
// them to the system prompt inside an <mcp-instructions> element.
- var instructions strings.Builder
- for _, server := range mcp.GetStates() {
- if server.State != mcp.StateConnected {
- continue
- }
- if s := server.Client.InitializeResult().Instructions; s != "" {
- instructions.WriteString(s)
- instructions.WriteString("\n\n")
- }
- }
- if s := instructions.String(); s != "" {
- systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
- }
- if len(agentTools) > 0 {
- // Add Anthropic caching to the last tool.
- agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
- }
- agent := fantasy.NewAgent(
- largeModel.Model,
- fantasy.WithSystemPrompt(systemPrompt),
- fantasy.WithTools(agentTools...),
- fantasy.WithUserAgent(userAgent),
- )
// sessionLock guards the load/update/save of the session in OnStepFinish.
- sessionLock := sync.Mutex{}
- currentSession, err := a.sessions.Get(ctx, call.SessionID)
- if err != nil {
- return nil, fmt.Errorf("failed to get session: %w", err)
- }
- msgs, err := a.getSessionMessages(ctx, currentSession)
- if err != nil {
- return nil, fmt.Errorf("failed to get session messages: %w", err)
- }
- var wg sync.WaitGroup
- // Generate title if first message.
- if len(msgs) == 0 {
- titleCtx := ctx // Copy to avoid race with ctx reassignment below.
- wg.Go(func() {
- a.generateTitle(titleCtx, call.SessionID, call.Prompt)
- })
- }
// Registered first, so it runs last: Run does not return until the title
// goroutine (if one was started) has finished.
- defer wg.Wait()
- // Add the user message to the session.
- _, err = a.createUserMessage(ctx, call)
- if err != nil {
- return nil, err
- }
- // Add the session to the context.
- ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
// genCtx is the cancellable context used for streaming. Its cancel function is
// stored in activeRequests under the session ID for the lifetime of the run.
- genCtx, cancel := context.WithCancel(ctx)
- a.activeRequests.Set(call.SessionID, cancel)
- defer cancel()
- defer a.activeRequests.Del(call.SessionID)
- history, files := a.preparePrompt(msgs, call.Attachments...)
- startTime := time.Now()
- a.eventPromptSent(call.SessionID)
// currentAssistant is assigned in PrepareStep and filled in by the streaming
// callbacks; shouldSummarize is set by the first StopWhen condition.
- var currentAssistant *message.Message
- var shouldSummarize bool
- // Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
- var maxOutputTokens *int64
- if call.MaxOutputTokens > 0 {
- maxOutputTokens = &call.MaxOutputTokens
- }
- result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
- Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
- Files: files,
- Messages: history,
- ProviderOptions: call.ProviderOptions,
- MaxOutputTokens: maxOutputTokens,
- TopP: call.TopP,
- Temperature: call.Temperature,
- PresencePenalty: call.PresencePenalty,
- TopK: call.TopK,
- FrequencyPenalty: call.FrequencyPenalty,
// PrepareStep refreshes the tool list, folds queued prompts into the
// conversation, applies cache-control markers, and creates the assistant
// message that the callbacks below write into.
- PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
- prepared.Messages = options.Messages
// Drop provider options on all messages; cache markers are re-applied below.
- for i := range prepared.Messages {
- prepared.Messages[i].ProviderOptions = nil
- }
- // Use latest tools (updated by SetTools when MCP tools change).
- prepared.Tools = a.tools.Copy()
// Drain the session's queue: each queued prompt is stored as a user message
// and appended to the messages for this step.
- queuedCalls, _ := a.messageQueue.Get(call.SessionID)
- a.messageQueue.Del(call.SessionID)
- for _, queued := range queuedCalls {
- userMessage, createErr := a.createUserMessage(callContext, queued)
- if createErr != nil {
- return callContext, prepared, createErr
- }
- prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
- }
- prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)
- lastSystemRoleInx := 0
- systemMessageUpdated := false
- for i, msg := range prepared.Messages {
- // Mark the last system message seen before the first non-system message for caching (index 0 if none was seen).
- if msg.Role == fantasy.MessageRoleSystem {
- lastSystemRoleInx = i
- } else if !systemMessageUpdated {
- prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
- systemMessageUpdated = true
- }
- // Then add cache control to the last 2 messages.
- if i > len(prepared.Messages)-3 {
- prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
- }
- }
// The prompt prefix, when set, is sent as the very first (system) message.
- if promptPrefix != "" {
- prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
- }
- var assistantMsg message.Message
- assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
- Role: message.Assistant,
- Parts: []message.ContentPart{},
- Model: largeModel.ModelCfg.Model,
- Provider: largeModel.ModelCfg.Provider,
- })
- if err != nil {
- return callContext, prepared, err
- }
// Expose the assistant message ID, image support and model name through the
// step context.
- callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
- callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
- callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
- currentAssistant = &assistantMsg
- return callContext, prepared, err
- },
- OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
- currentAssistant.AppendReasoningContent(reasoning.Text)
- return a.messages.Update(genCtx, *currentAssistant)
- },
- OnReasoningDelta: func(id string, text string) error {
- currentAssistant.AppendReasoningContent(text)
- return a.messages.Update(genCtx, *currentAssistant)
- },
- OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
- // handle anthropic signature
- if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
- if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
- currentAssistant.AppendReasoningSignature(reasoning.Signature)
- }
- }
// Google reasoning metadata carries a thought signature and a tool ID.
- if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
- if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
- currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
- }
- }
// OpenAI Responses reasoning metadata is stored on the message as a whole.
- if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
- if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
- currentAssistant.SetReasoningResponsesData(reasoning)
- }
- }
- currentAssistant.FinishThinking()
- return a.messages.Update(genCtx, *currentAssistant)
- },
- OnTextDelta: func(id string, text string) error {
- // Strip leading newline from initial text content. This is
- // particularly important in non-interactive mode where leading
- // newlines are very visible.
- if len(currentAssistant.Parts) == 0 {
- text = strings.TrimPrefix(text, "\n")
- }
- currentAssistant.AppendContent(text)
- return a.messages.Update(genCtx, *currentAssistant)
- },
- OnToolInputStart: func(id string, toolName string) error {
- toolCall := message.ToolCall{
- ID: id,
- Name: toolName,
- ProviderExecuted: false,
- Finished: false,
- }
- currentAssistant.AddToolCall(toolCall)
- // Use parent ctx instead of genCtx to ensure the update succeeds
- // even if the request is canceled mid-stream
- return a.messages.Update(ctx, *currentAssistant)
- },
- OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
- // TODO: implement
- },
- OnToolCall: func(tc fantasy.ToolCallContent) error {
- toolCall := message.ToolCall{
- ID: tc.ToolCallID,
- Name: tc.ToolName,
- Input: tc.Input,
- ProviderExecuted: false,
- Finished: true,
- }
- currentAssistant.AddToolCall(toolCall)
- // Use parent ctx instead of genCtx to ensure the update succeeds
- // even if the request is canceled mid-stream
- return a.messages.Update(ctx, *currentAssistant)
- },
- OnToolResult: func(result fantasy.ToolResultContent) error {
- toolResult := a.convertToToolResult(result)
- // Use parent ctx instead of genCtx to ensure the message is created
- // even if the request is canceled mid-stream
- _, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
- Role: message.Tool,
- Parts: []message.ContentPart{
- toolResult,
- },
- })
- return createMsgErr
- },
// OnStepFinish records the step's finish reason on the assistant message, then
// reloads the session, adds the step's usage, and saves it.
- OnStepFinish: func(stepResult fantasy.StepResult) error {
- finishReason := message.FinishReasonUnknown
- switch stepResult.FinishReason {
- case fantasy.FinishReasonLength:
- finishReason = message.FinishReasonMaxTokens
- case fantasy.FinishReasonStop:
- finishReason = message.FinishReasonEndTurn
- case fantasy.FinishReasonToolCalls:
- finishReason = message.FinishReasonToolUse
- }
- currentAssistant.AddFinish(finishReason, "", "")
- sessionLock.Lock()
- defer sessionLock.Unlock()
- updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
- if getSessionErr != nil {
- return getSessionErr
- }
- a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
- _, sessionErr := a.sessions.Save(ctx, updatedSession)
- if sessionErr != nil {
- return sessionErr
- }
// Keep the local copy current; the StopWhen condition below reads its token
// counts.
- currentSession = updatedSession
- return a.messages.Update(genCtx, *currentAssistant)
- },
- StopWhen: []fantasy.StopCondition{
// Stop (and flag the run for summarization) once the remaining context window
// is at or below the threshold, unless auto-summarize is disabled.
- func(_ []fantasy.StepResult) bool {
- cw := int64(largeModel.CatwalkCfg.ContextWindow)
- // If context window is unknown (0), skip auto-summarize
- // to avoid immediately truncating custom/local models.
- if cw == 0 {
- return false
- }
- tokens := currentSession.CompletionTokens + currentSession.PromptTokens
- remaining := cw - tokens
- var threshold int64
- if cw > largeContextWindowThreshold {
- threshold = largeContextWindowBuffer
- } else {
- threshold = int64(float64(cw) * smallContextWindowRatio)
- }
- if (remaining <= threshold) && !a.disableAutoSummarize {
- shouldSummarize = true
- return true
- }
- return false
- },
// Stop when hasRepeatedToolCalls (defined outside this view) reports a loop.
- func(steps []fantasy.StepResult) bool {
- return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
- },
- },
- })
- a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
// Failure path: finalize the assistant message so the stored conversation
// stays consistent, then return the original error.
- if err != nil {
- isCancelErr := errors.Is(err, context.Canceled)
- isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
// No assistant message was created yet, so there is nothing to finalize.
- if currentAssistant == nil {
- return result, err
- }
- // Ensure we finish thinking on error to close the reasoning state.
- currentAssistant.FinishThinking()
- toolCalls := currentAssistant.ToolCalls()
- // INFO: we use the parent context here because the genCtx has been cancelled.
- msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
- if createErr != nil {
- return nil, createErr
- }
// Every tool call must end up finished and paired with a tool result.
- for _, tc := range toolCalls {
// An unfinished call is closed with an empty JSON object as its input.
- if !tc.Finished {
- tc.Finished = true
- tc.Input = "{}"
- currentAssistant.AddToolCall(tc)
- updateErr := a.messages.Update(ctx, *currentAssistant)
- if updateErr != nil {
- return nil, updateErr
- }
- }
// Look for an already stored result for this tool call.
- found := false
- for _, msg := range msgs {
- if msg.Role == message.Tool {
- for _, tr := range msg.ToolResults() {
- if tr.ToolCallID == tc.ID {
- found = true
- break
- }
- }
- }
- if found {
- break
- }
- }
- if found {
- continue
- }
// No result stored: create a synthetic error result explaining why.
- content := "There was an error while executing the tool"
- if isCancelErr {
- content = "Error: user cancelled assistant tool calling"
- } else if isPermissionErr {
- content = "User denied permission"
- }
- toolResult := message.ToolResult{
- ToolCallID: tc.ID,
- Name: tc.Name,
- Content: content,
- IsError: true,
- }
- _, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
- Role: message.Tool,
- Parts: []message.ContentPart{
- toolResult,
- },
- })
- if createErr != nil {
- return nil, createErr
- }
- }
// Choose the finish reason, title and details from the kind of error.
- var fantasyErr *fantasy.Error
- var providerErr *fantasy.ProviderError
- const defaultTitle = "Provider Error"
- linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
- if isCancelErr {
- currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
- } else if isPermissionErr {
- currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
- } else if errors.Is(err, hyper.ErrNoCredits) {
- url := hyper.BaseURL()
- link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
- currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
- } else if errors.As(err, &providerErr) {
// This exact provider message is treated as "model not enabled in Copilot".
- if providerErr.Message == "The requested model is not supported." {
- url := "https://github.com/settings/copilot/features"
- link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
- currentAssistant.AddFinish(
- message.FinishReasonError,
- "Copilot model not enabled",
- fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
- )
- } else {
- currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
- }
- } else if errors.As(err, &fantasyErr) {
- currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
- } else {
- currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
- }
- // Note: we use the parent context here because the genCtx has been
- // cancelled.
- updateErr := a.messages.Update(ctx, *currentAssistant)
- if updateErr != nil {
- return nil, updateErr
- }
- return nil, err
- }
- // Send notification that agent has finished its turn (skip for
- // nested/non-interactive sessions).
- if !call.NonInteractive && a.notify != nil {
- a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
- SessionID: call.SessionID,
- SessionTitle: currentSession.Title,
- Type: notify.TypeAgentFinished,
- })
- }
// The context-window stop condition fired: release the session first, because
// Summarize refuses to run on a busy session, then summarize.
- if shouldSummarize {
- a.activeRequests.Del(call.SessionID)
- if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
- return nil, summarizeErr
- }
- // If the agent wasn't done...
// ...re-queue the original request, wrapped in a note explaining that the
// session was interrupted, so it is picked up again below.
- if len(currentAssistant.ToolCalls()) > 0 {
- existing, ok := a.messageQueue.Get(call.SessionID)
- if !ok {
- existing = []SessionAgentCall{}
- }
- call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
- existing = append(existing, call)
- a.messageQueue.Set(call.SessionID, existing)
- }
- }
- // Release active request before processing queued messages.
- a.activeRequests.Del(call.SessionID)
- cancel()
- queuedMessages, ok := a.messageQueue.Get(call.SessionID)
- if !ok || len(queuedMessages) == 0 {
- return result, err
- }
- // There are queued messages; restart the loop.
// Pop the first queued call and run it recursively; its result replaces this
// run's result.
- firstQueuedMessage := queuedMessages[0]
- a.messageQueue.Set(call.SessionID, queuedMessages[1:])
- return a.Run(ctx, firstQueuedMessage)
- }
- func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
- if a.IsSessionBusy(sessionID) {
- return ErrSessionBusy
- }
- // Copy mutable fields under lock to avoid races with SetModels.
- largeModel := a.largeModel.Get()
- systemPromptPrefix := a.systemPromptPrefix.Get()
- currentSession, err := a.sessions.Get(ctx, sessionID)
- if err != nil {
- return fmt.Errorf("failed to get session: %w", err)
- }
- msgs, err := a.getSessionMessages(ctx, currentSession)
- if err != nil {
- return err
- }
- if len(msgs) == 0 {
- // Nothing to summarize.
- return nil
- }
- aiMsgs, _ := a.preparePrompt(msgs)
- genCtx, cancel := context.WithCancel(ctx)
- a.activeRequests.Set(sessionID, cancel)
- defer a.activeRequests.Del(sessionID)
- defer cancel()
- agent := fantasy.NewAgent(largeModel.Model,
- fantasy.WithSystemPrompt(string(summaryPrompt)),
- fantasy.WithUserAgent(userAgent),
- )
- summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
- Role: message.Assistant,
- Model: largeModel.Model.Model(),
- Provider: largeModel.Model.Provider(),
- IsSummaryMessage: true,
- })
- if err != nil {
- return err
- }
- summaryPromptText := buildSummaryPrompt(currentSession.Todos)
- resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
- Prompt: summaryPromptText,
- Messages: aiMsgs,
- ProviderOptions: opts,
- PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
- prepared.Messages = options.Messages
- if systemPromptPrefix != "" {
- prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
- }
- return callContext, prepared, nil
- },
- OnReasoningDelta: func(id string, text string) error {
- summaryMessage.AppendReasoningContent(text)
- return a.messages.Update(genCtx, summaryMessage)
- },
- OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
- // Handle anthropic signature.
- if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
- if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
- summaryMessage.AppendReasoningSignature(signature.Signature)
- }
- }
- summaryMessage.FinishThinking()
- return a.messages.Update(genCtx, summaryMessage)
- },
- OnTextDelta: func(id, text string) error {
- summaryMessage.AppendContent(text)
- return a.messages.Update(genCtx, summaryMessage)
- },
- })
- if err != nil {
- isCancelErr := errors.Is(err, context.Canceled)
- if isCancelErr {
- // User cancelled summarize we need to remove the summary message.
- deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
- return deleteErr
- }
- return err
- }
- summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
- err = a.messages.Update(genCtx, summaryMessage)
- if err != nil {
- return err
- }
- var openrouterCost *float64
- for _, step := range resp.Steps {
- stepCost := a.openrouterCost(step.ProviderMetadata)
- if stepCost != nil {
- newCost := *stepCost
- if openrouterCost != nil {
- newCost += *openrouterCost
- }
- openrouterCost = &newCost
- }
- }
- a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
- // Just in case, get just the last usage info.
- usage := resp.Response.Usage
- currentSession.SummaryMessageID = summaryMessage.ID
- currentSession.CompletionTokens = usage.OutputTokens
- currentSession.PromptTokens = 0
- _, err = a.sessions.Save(genCtx, currentSession)
- return err
- }
- func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
- if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
- return fantasy.ProviderOptions{}
- }
- return fantasy.ProviderOptions{
- anthropic.Name: &anthropic.ProviderCacheControlOptions{
- CacheControl: anthropic.CacheControl{Type: "ephemeral"},
- },
- bedrock.Name: &anthropic.ProviderCacheControlOptions{
- CacheControl: anthropic.CacheControl{Type: "ephemeral"},
- },
- vercel.Name: &anthropic.ProviderCacheControlOptions{
- CacheControl: anthropic.CacheControl{Type: "ephemeral"},
- },
- }
- }
- func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
- parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
- var attachmentParts []message.ContentPart
- for _, attachment := range call.Attachments {
- attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
- }
- parts = append(parts, attachmentParts...)
- msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
- Role: message.User,
- Parts: parts,
- })
- if err != nil {
- return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
- }
- return msg, nil
- }
// preparePrompt converts stored session messages into the message history sent
// to the model and turns non-text attachments into file parts.
//
// For agents that are not sub-agents the history starts with a user message
// holding a <system_reminder> about the todo list. Messages without parts are
// skipped, as are assistant messages that have no tool calls, no text and no
// reasoning content.
- func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
- var history []fantasy.Message
// The reminder is added unconditionally for non-sub-agents; this function does
// not look at the session's actual todos.
- if !a.isSubAgent {
- history = append(history, fantasy.NewUserMessage(
- fmt.Sprintf("<system_reminder>%s</system_reminder>",
- `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
- If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
- If not, please feel free to ignore. Again do not mention this message to the user.`,
- ),
- ))
- }
- for _, m := range msgs {
- if len(m.Parts) == 0 {
- continue
- }
- // Assistant message without content or tool calls (cancelled before it
- // returned anything).
- if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
- continue
- }
- history = append(history, m.ToAIMessage()...)
- }
// Text attachments are left out of the file list; Run merges them into the
// prompt text instead.
- var files []fantasy.FilePart
- for _, attachment := range attachments {
- if attachment.IsText() {
- continue
- }
- files = append(files, fantasy.FilePart{
- Filename: attachment.FileName,
- Data: attachment.Content,
- MediaType: attachment.MimeType,
- })
- }
- return history, files
- }
- func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
- msgs, err := a.messages.List(ctx, session.ID)
- if err != nil {
- return nil, fmt.Errorf("failed to list messages: %w", err)
- }
- if session.SummaryMessageID != "" {
- summaryMsgIndex := -1
- for i, msg := range msgs {
- if msg.ID == session.SummaryMessageID {
- summaryMsgIndex = i
- break
- }
- }
- if summaryMsgIndex != -1 {
- msgs = msgs[summaryMsgIndex:]
- msgs[0].Role = message.User
- }
- }
- return msgs, nil
- }
- // generateTitle generates a session titled based on the initial prompt.
- func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
- if userPrompt == "" {
- return
- }
- smallModel := a.smallModel.Get()
- largeModel := a.largeModel.Get()
- systemPromptPrefix := a.systemPromptPrefix.Get()
- var maxOutputTokens int64 = 40
- if smallModel.CatwalkCfg.CanReason {
- maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
- }
- newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
- return fantasy.NewAgent(m,
- fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
- fantasy.WithMaxOutputTokens(tok),
- fantasy.WithUserAgent(userAgent),
- )
- }
- streamCall := fantasy.AgentStreamCall{
- Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
- PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
- prepared.Messages = opts.Messages
- if systemPromptPrefix != "" {
- prepared.Messages = append([]fantasy.Message{
- fantasy.NewSystemMessage(systemPromptPrefix),
- }, prepared.Messages...)
- }
- return callCtx, prepared, nil
- },
- }
- // Use the small model to generate the title.
- model := smallModel
- agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
- resp, err := agent.Stream(ctx, streamCall)
- if err == nil {
- // We successfully generated a title with the small model.
- slog.Debug("Generated title with small model")
- } else {
- // It didn't work. Let's try with the big model.
- slog.Error("Error generating title with small model; trying big model", "err", err)
- model = largeModel
- agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
- resp, err = agent.Stream(ctx, streamCall)
- if err == nil {
- slog.Debug("Generated title with large model")
- } else {
- // Welp, the large model didn't work either. Use the default
- // session name and return.
- slog.Error("Error generating title with large model", "err", err)
- saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
- if saveErr != nil {
- slog.Error("Failed to save session title", "error", saveErr)
- }
- return
- }
- }
- if resp == nil {
- // Actually, we didn't get a response so we can't. Use the default
- // session name and return.
- slog.Error("Response is nil; can't generate title")
- saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
- if saveErr != nil {
- slog.Error("Failed to save session title", "error", saveErr)
- }
- return
- }
- // Clean up title.
- var title string
- title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
- // Remove thinking tags if present.
- title = thinkTagRegex.ReplaceAllString(title, "")
- title = strings.TrimSpace(title)
- title = cmp.Or(title, DefaultSessionName)
- // Calculate usage and cost.
- var openrouterCost *float64
- for _, step := range resp.Steps {
- stepCost := a.openrouterCost(step.ProviderMetadata)
- if stepCost != nil {
- newCost := *stepCost
- if openrouterCost != nil {
- newCost += *openrouterCost
- }
- openrouterCost = &newCost
- }
- }
- modelConfig := model.CatwalkCfg
- cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
- modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
- modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
- modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
- // Use override cost if available (e.g., from OpenRouter).
- if openrouterCost != nil {
- cost = *openrouterCost
- }
- promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
- completionTokens := resp.TotalUsage.OutputTokens
- // Atomically update only title and usage fields to avoid overriding other
- // concurrent session updates.
- saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
- if saveErr != nil {
- slog.Error("Failed to save session title and usage", "error", saveErr)
- return
- }
- }
- func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
- openrouterMetadata, ok := metadata[openrouter.Name]
- if !ok {
- return nil
- }
- opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
- if !ok {
- return nil
- }
- return &opts.Usage.Cost
- }
- func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
- modelConfig := model.CatwalkCfg
- cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
- modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
- modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
- modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
- a.eventTokensUsed(session.ID, model, usage, cost)
- if overrideCost != nil {
- session.Cost += *overrideCost
- } else {
- session.Cost += cost
- }
- session.CompletionTokens = usage.OutputTokens
- session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
- }
- func (a *sessionAgent) Cancel(sessionID string) {
- // Cancel regular requests. Don't use Take() here - we need the entry to
- // remain in activeRequests so IsBusy() returns true until the goroutine
- // fully completes (including error handling that may access the DB).
- // The defer in processRequest will clean up the entry.
- if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
- slog.Debug("Request cancellation initiated", "session_id", sessionID)
- cancel()
- }
- // Also check for summarize requests.
- if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
- slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
- cancel()
- }
- if a.QueuedPrompts(sessionID) > 0 {
- slog.Debug("Clearing queued prompts", "session_id", sessionID)
- a.messageQueue.Del(sessionID)
- }
- }
- func (a *sessionAgent) ClearQueue(sessionID string) {
- if a.QueuedPrompts(sessionID) > 0 {
- slog.Debug("Clearing queued prompts", "session_id", sessionID)
- a.messageQueue.Del(sessionID)
- }
- }
- func (a *sessionAgent) CancelAll() {
- if !a.IsBusy() {
- return
- }
- for key := range a.activeRequests.Seq2() {
- a.Cancel(key) // key is sessionID
- }
- timeout := time.After(5 * time.Second)
- for a.IsBusy() {
- select {
- case <-timeout:
- return
- default:
- time.Sleep(200 * time.Millisecond)
- }
- }
- }
- func (a *sessionAgent) IsBusy() bool {
- var busy bool
- for cancelFunc := range a.activeRequests.Seq() {
- if cancelFunc != nil {
- busy = true
- break
- }
- }
- return busy
- }
- func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
- _, busy := a.activeRequests.Get(sessionID)
- return busy
- }
- func (a *sessionAgent) QueuedPrompts(sessionID string) int {
- l, ok := a.messageQueue.Get(sessionID)
- if !ok {
- return 0
- }
- return len(l)
- }
- func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
- l, ok := a.messageQueue.Get(sessionID)
- if !ok {
- return nil
- }
- prompts := make([]string, len(l))
- for i, call := range l {
- prompts[i] = call.Prompt
- }
- return prompts
- }
- func (a *sessionAgent) SetModels(large Model, small Model) {
- a.largeModel.Set(large)
- a.smallModel.Set(small)
- }
- func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
- a.tools.SetSlice(tools)
- }
- func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
- a.systemPrompt.Set(systemPrompt)
- }
- func (a *sessionAgent) Model() Model {
- return a.largeModel.Get()
- }
- // convertToToolResult converts a fantasy tool result to a message tool result.
- func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
- baseResult := message.ToolResult{
- ToolCallID: result.ToolCallID,
- Name: result.ToolName,
- Metadata: result.ClientMetadata,
- }
- switch result.Result.GetType() {
- case fantasy.ToolResultContentTypeText:
- if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
- baseResult.Content = r.Text
- }
- case fantasy.ToolResultContentTypeError:
- if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
- baseResult.Content = r.Error.Error()
- baseResult.IsError = true
- }
- case fantasy.ToolResultContentTypeMedia:
- if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
- content := r.Text
- if content == "" {
- content = fmt.Sprintf("Loaded %s content", r.MediaType)
- }
- baseResult.Content = content
- baseResult.Data = r.Data
- baseResult.MIMEType = r.MediaType
- }
- }
- return baseResult
- }
- // workaroundProviderMediaLimitations converts media content in tool results to
- // user messages for providers that don't natively support images in tool results.
- //
- // Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
- // don't support sending images/media in tool result messages - they only accept
- // text in tool results. However, they DO support images in user messages.
- //
- // If we send media in tool results to these providers, the API returns an error.
- //
- // Solution: For these providers, we:
- // 1. Replace the media in the tool result with a text placeholder
- // 2. Inject a user message immediately after with the image as a file attachment
- // 3. This maintains the tool execution flow while working around API limitations
- //
- // Anthropic and Bedrock support images natively in tool results, so we skip
- // this workaround for them.
- //
- // Example transformation:
- //
- // BEFORE: [tool result: image data]
- // AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
- func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
- providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
- largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
- if providerSupportsMedia {
- return messages
- }
- convertedMessages := make([]fantasy.Message, 0, len(messages))
- for _, msg := range messages {
- if msg.Role != fantasy.MessageRoleTool {
- convertedMessages = append(convertedMessages, msg)
- continue
- }
- textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
- var mediaFiles []fantasy.FilePart
- for _, part := range msg.Content {
- toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
- if !ok {
- textParts = append(textParts, part)
- continue
- }
- if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
- decoded, err := base64.StdEncoding.DecodeString(media.Data)
- if err != nil {
- slog.Warn("Failed to decode media data", "error", err)
- textParts = append(textParts, part)
- continue
- }
- mediaFiles = append(mediaFiles, fantasy.FilePart{
- Data: decoded,
- MediaType: media.MediaType,
- Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
- })
- textParts = append(textParts, fantasy.ToolResultPart{
- ToolCallID: toolResult.ToolCallID,
- Output: fantasy.ToolResultOutputContentText{
- Text: "[Image/media content loaded - see attached file]",
- },
- ProviderOptions: toolResult.ProviderOptions,
- })
- } else {
- textParts = append(textParts, part)
- }
- }
- convertedMessages = append(convertedMessages, fantasy.Message{
- Role: fantasy.MessageRoleTool,
- Content: textParts,
- })
- if len(mediaFiles) > 0 {
- convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
- "Here is the media content from the tool result:",
- mediaFiles...,
- ))
- }
- }
- return convertedMessages
- }
- // buildSummaryPrompt constructs the prompt text for session summarization.
- func buildSummaryPrompt(todos []session.Todo) string {
- var sb strings.Builder
- sb.WriteString("Provide a detailed summary of our conversation above.")
- if len(todos) > 0 {
- sb.WriteString("\n\n## Current Todo List\n\n")
- for _, t := range todos {
- fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
- }
- sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
- sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
- }
- return sb.String()
- }
|