agent.go 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158
  1. // Package agent is the core orchestration layer for Crush AI agents.
  2. //
  3. // It provides session-based AI agent functionality for managing
  4. // conversations, tool execution, and message handling. It coordinates
  5. // interactions between language models, messages, sessions, and tools while
  6. // handling features like automatic summarization, queuing, and token
  7. // management.
  8. package agent
  9. import (
  10. "cmp"
  11. "context"
  12. _ "embed"
  13. "encoding/base64"
  14. "errors"
  15. "fmt"
  16. "log/slog"
  17. "os"
  18. "regexp"
  19. "strconv"
  20. "strings"
  21. "sync"
  22. "time"
  23. "charm.land/catwalk/pkg/catwalk"
  24. "charm.land/fantasy"
  25. "charm.land/fantasy/providers/anthropic"
  26. "charm.land/fantasy/providers/bedrock"
  27. "charm.land/fantasy/providers/google"
  28. "charm.land/fantasy/providers/openai"
  29. "charm.land/fantasy/providers/openrouter"
  30. "charm.land/fantasy/providers/vercel"
  31. "charm.land/lipgloss/v2"
  32. "github.com/charmbracelet/crush/internal/agent/hyper"
  33. "github.com/charmbracelet/crush/internal/agent/tools"
  34. "github.com/charmbracelet/crush/internal/agent/tools/mcp"
  35. "github.com/charmbracelet/crush/internal/config"
  36. "github.com/charmbracelet/crush/internal/csync"
  37. "github.com/charmbracelet/crush/internal/message"
  38. "github.com/charmbracelet/crush/internal/permission"
  39. "github.com/charmbracelet/crush/internal/session"
  40. "github.com/charmbracelet/crush/internal/stringext"
  41. "github.com/charmbracelet/x/exp/charmtone"
  42. )
  43. const (
  44. defaultSessionName = "Untitled Session"
  45. // Constants for auto-summarization thresholds
  46. largeContextWindowThreshold = 200_000
  47. largeContextWindowBuffer = 20_000
  48. smallContextWindowRatio = 0.2
  49. )
// titlePrompt is the system prompt generateTitle gives to the model that
// writes the session title (with "\n /no_think" appended).
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt Summarize gives to the large model.
//
//go:embed templates/summary.md
var summaryPrompt []byte
  54. // Used to remove <think> tags from generated titles.
  55. var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
// SessionAgentCall describes a single prompt to run against a session.
type SessionAgentCall struct {
	// SessionID identifies the target session; Run rejects an empty value
	// with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty prompt with
	// ErrEmptyPrompt unless a text attachment is present.
	Prompt string
	// ProviderOptions are passed to the stream call, and to Summarize when
	// Run triggers auto-summarization.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message. Run combines them with
	// the prompt via message.PromptWithTextAttachments and sends non-text
	// attachments as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens caps the completion length. Run always passes a
	// non-nil pointer to this field, so zero is forwarded as an explicit 0.
	// NOTE(review): confirm providers treat 0 as "unset".
	MaxOutputTokens int64
	// Optional sampling parameters, passed to the stream call unchanged.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  68. type SessionAgent interface {
  69. Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
  70. SetModels(large Model, small Model)
  71. SetTools(tools []fantasy.AgentTool)
  72. SetSystemPrompt(systemPrompt string)
  73. Cancel(sessionID string)
  74. CancelAll()
  75. IsSessionBusy(sessionID string) bool
  76. IsBusy() bool
  77. QueuedPrompts(sessionID string) int
  78. QueuedPromptsList(sessionID string) []string
  79. ClearQueue(sessionID string)
  80. Summarize(context.Context, string, fantasy.ProviderOptions) error
  81. Model() Model
  82. }
// Model bundles a language model client with its catalog metadata and the
// selected-model configuration.
type Model struct {
	// Model is the client used to stream completions.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model; this file reads its
	// ContextWindow, SupportsImages, Name, CanReason and DefaultMaxTokens.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// are recorded on assistant messages created by Run.
	ModelCfg config.SelectedModel
}
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent. The csync-wrapped fields can be replaced concurrently
// (e.g. by SetTools/SetModels), so Run and Summarize copy them at the start
// of each call.
type sessionAgent struct {
	// largeModel drives Run and Summarize.
	largeModel *csync.Value[Model]
	// smallModel is tried first by generateTitle.
	smallModel *csync.Value[Model]
	// systemPromptPrefix, when non-empty, is prepended as an extra system message.
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]
	// isSubAgent omits the todo-list reminder in preparePrompt.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize disables the context-window stop condition in Run.
	disableAutoSummarize bool
	isYolo               bool
	// messageQueue holds calls submitted while a session was busy, per session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel func of the in-flight request, per session ID.
	activeRequests *csync.Map[string, context.CancelFunc]
}
// SessionAgentOptions configures NewSessionAgent. Each field seeds the
// sessionAgent field of the same (lower-camel-case) name.
type SessionAgentOptions struct {
	// LargeModel is used for prompts and summaries.
	LargeModel Model
	// SmallModel is tried first when generating session titles.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra leading system message.
	SystemPromptPrefix string
	// SystemPrompt is the agent's main system prompt.
	SystemPrompt string
	// IsSubAgent omits the todo-list reminder from the prompt history.
	IsSubAgent bool
	// DisableAutoSummarize turns off summarization near the context limit.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	// Tools are the tools initially offered to the model.
	Tools []fantasy.AgentTool
}
  114. func NewSessionAgent(
  115. opts SessionAgentOptions,
  116. ) SessionAgent {
  117. return &sessionAgent{
  118. largeModel: csync.NewValue(opts.LargeModel),
  119. smallModel: csync.NewValue(opts.SmallModel),
  120. systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
  121. systemPrompt: csync.NewValue(opts.SystemPrompt),
  122. isSubAgent: opts.IsSubAgent,
  123. sessions: opts.Sessions,
  124. messages: opts.Messages,
  125. disableAutoSummarize: opts.DisableAutoSummarize,
  126. tools: csync.NewSliceFrom(opts.Tools),
  127. isYolo: opts.IsYolo,
  128. messageQueue: csync.NewMap[string, []SessionAgentCall](),
  129. activeRequests: csync.NewMap[string, context.CancelFunc](),
  130. }
  131. }
  132. func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
  133. if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
  134. return nil, ErrEmptyPrompt
  135. }
  136. if call.SessionID == "" {
  137. return nil, ErrSessionMissing
  138. }
  139. // Queue the message if busy
  140. if a.IsSessionBusy(call.SessionID) {
  141. existing, ok := a.messageQueue.Get(call.SessionID)
  142. if !ok {
  143. existing = []SessionAgentCall{}
  144. }
  145. existing = append(existing, call)
  146. a.messageQueue.Set(call.SessionID, existing)
  147. return nil, nil
  148. }
  149. // Copy mutable fields under lock to avoid races with SetTools/SetModels.
  150. agentTools := a.tools.Copy()
  151. largeModel := a.largeModel.Get()
  152. systemPrompt := a.systemPrompt.Get()
  153. promptPrefix := a.systemPromptPrefix.Get()
  154. var instructions strings.Builder
  155. for _, server := range mcp.GetStates() {
  156. if server.State != mcp.StateConnected {
  157. continue
  158. }
  159. if s := server.Client.InitializeResult().Instructions; s != "" {
  160. instructions.WriteString(s)
  161. instructions.WriteString("\n\n")
  162. }
  163. }
  164. if s := instructions.String(); s != "" {
  165. systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
  166. }
  167. if len(agentTools) > 0 {
  168. // Add Anthropic caching to the last tool.
  169. agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
  170. }
  171. agent := fantasy.NewAgent(
  172. largeModel.Model,
  173. fantasy.WithSystemPrompt(systemPrompt),
  174. fantasy.WithTools(agentTools...),
  175. )
  176. sessionLock := sync.Mutex{}
  177. currentSession, err := a.sessions.Get(ctx, call.SessionID)
  178. if err != nil {
  179. return nil, fmt.Errorf("failed to get session: %w", err)
  180. }
  181. msgs, err := a.getSessionMessages(ctx, currentSession)
  182. if err != nil {
  183. return nil, fmt.Errorf("failed to get session messages: %w", err)
  184. }
  185. var wg sync.WaitGroup
  186. // Generate title if first message.
  187. if len(msgs) == 0 {
  188. titleCtx := ctx // Copy to avoid race with ctx reassignment below.
  189. wg.Go(func() {
  190. a.generateTitle(titleCtx, call.SessionID, call.Prompt)
  191. })
  192. }
  193. defer wg.Wait()
  194. // Add the user message to the session.
  195. _, err = a.createUserMessage(ctx, call)
  196. if err != nil {
  197. return nil, err
  198. }
  199. // Add the session to the context.
  200. ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
  201. genCtx, cancel := context.WithCancel(ctx)
  202. a.activeRequests.Set(call.SessionID, cancel)
  203. defer cancel()
  204. defer a.activeRequests.Del(call.SessionID)
  205. history, files := a.preparePrompt(msgs, call.Attachments...)
  206. startTime := time.Now()
  207. a.eventPromptSent(call.SessionID)
  208. var currentAssistant *message.Message
  209. var shouldSummarize bool
  210. result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  211. Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
  212. Files: files,
  213. Messages: history,
  214. ProviderOptions: call.ProviderOptions,
  215. MaxOutputTokens: &call.MaxOutputTokens,
  216. TopP: call.TopP,
  217. Temperature: call.Temperature,
  218. PresencePenalty: call.PresencePenalty,
  219. TopK: call.TopK,
  220. FrequencyPenalty: call.FrequencyPenalty,
  221. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  222. prepared.Messages = options.Messages
  223. for i := range prepared.Messages {
  224. prepared.Messages[i].ProviderOptions = nil
  225. }
  226. queuedCalls, _ := a.messageQueue.Get(call.SessionID)
  227. a.messageQueue.Del(call.SessionID)
  228. for _, queued := range queuedCalls {
  229. userMessage, createErr := a.createUserMessage(callContext, queued)
  230. if createErr != nil {
  231. return callContext, prepared, createErr
  232. }
  233. prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
  234. }
  235. prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)
  236. lastSystemRoleInx := 0
  237. systemMessageUpdated := false
  238. for i, msg := range prepared.Messages {
  239. // Only add cache control to the last message.
  240. if msg.Role == fantasy.MessageRoleSystem {
  241. lastSystemRoleInx = i
  242. } else if !systemMessageUpdated {
  243. prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
  244. systemMessageUpdated = true
  245. }
  246. // Than add cache control to the last 2 messages.
  247. if i > len(prepared.Messages)-3 {
  248. prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
  249. }
  250. }
  251. if promptPrefix != "" {
  252. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
  253. }
  254. var assistantMsg message.Message
  255. assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
  256. Role: message.Assistant,
  257. Parts: []message.ContentPart{},
  258. Model: largeModel.ModelCfg.Model,
  259. Provider: largeModel.ModelCfg.Provider,
  260. })
  261. if err != nil {
  262. return callContext, prepared, err
  263. }
  264. callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
  265. callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
  266. callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
  267. currentAssistant = &assistantMsg
  268. return callContext, prepared, err
  269. },
  270. OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
  271. currentAssistant.AppendReasoningContent(reasoning.Text)
  272. return a.messages.Update(genCtx, *currentAssistant)
  273. },
  274. OnReasoningDelta: func(id string, text string) error {
  275. currentAssistant.AppendReasoningContent(text)
  276. return a.messages.Update(genCtx, *currentAssistant)
  277. },
  278. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  279. // handle anthropic signature
  280. if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
  281. if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
  282. currentAssistant.AppendReasoningSignature(reasoning.Signature)
  283. }
  284. }
  285. if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
  286. if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
  287. currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
  288. }
  289. }
  290. if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
  291. if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
  292. currentAssistant.SetReasoningResponsesData(reasoning)
  293. }
  294. }
  295. currentAssistant.FinishThinking()
  296. return a.messages.Update(genCtx, *currentAssistant)
  297. },
  298. OnTextDelta: func(id string, text string) error {
  299. // Strip leading newline from initial text content. This is is
  300. // particularly important in non-interactive mode where leading
  301. // newlines are very visible.
  302. if len(currentAssistant.Parts) == 0 {
  303. text = strings.TrimPrefix(text, "\n")
  304. }
  305. currentAssistant.AppendContent(text)
  306. return a.messages.Update(genCtx, *currentAssistant)
  307. },
  308. OnToolInputStart: func(id string, toolName string) error {
  309. toolCall := message.ToolCall{
  310. ID: id,
  311. Name: toolName,
  312. ProviderExecuted: false,
  313. Finished: false,
  314. }
  315. currentAssistant.AddToolCall(toolCall)
  316. return a.messages.Update(genCtx, *currentAssistant)
  317. },
  318. OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
  319. // TODO: implement
  320. },
  321. OnToolCall: func(tc fantasy.ToolCallContent) error {
  322. toolCall := message.ToolCall{
  323. ID: tc.ToolCallID,
  324. Name: tc.ToolName,
  325. Input: tc.Input,
  326. ProviderExecuted: false,
  327. Finished: true,
  328. }
  329. currentAssistant.AddToolCall(toolCall)
  330. return a.messages.Update(genCtx, *currentAssistant)
  331. },
  332. OnToolResult: func(result fantasy.ToolResultContent) error {
  333. toolResult := a.convertToToolResult(result)
  334. _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
  335. Role: message.Tool,
  336. Parts: []message.ContentPart{
  337. toolResult,
  338. },
  339. })
  340. return createMsgErr
  341. },
  342. OnStepFinish: func(stepResult fantasy.StepResult) error {
  343. finishReason := message.FinishReasonUnknown
  344. switch stepResult.FinishReason {
  345. case fantasy.FinishReasonLength:
  346. finishReason = message.FinishReasonMaxTokens
  347. case fantasy.FinishReasonStop:
  348. finishReason = message.FinishReasonEndTurn
  349. case fantasy.FinishReasonToolCalls:
  350. finishReason = message.FinishReasonToolUse
  351. }
  352. currentAssistant.AddFinish(finishReason, "", "")
  353. sessionLock.Lock()
  354. defer sessionLock.Unlock()
  355. updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
  356. if getSessionErr != nil {
  357. return getSessionErr
  358. }
  359. a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
  360. _, sessionErr := a.sessions.Save(ctx, updatedSession)
  361. if sessionErr != nil {
  362. return sessionErr
  363. }
  364. currentSession = updatedSession
  365. return a.messages.Update(genCtx, *currentAssistant)
  366. },
  367. StopWhen: []fantasy.StopCondition{
  368. func(_ []fantasy.StepResult) bool {
  369. cw := int64(largeModel.CatwalkCfg.ContextWindow)
  370. tokens := currentSession.CompletionTokens + currentSession.PromptTokens
  371. remaining := cw - tokens
  372. var threshold int64
  373. if cw > largeContextWindowThreshold {
  374. threshold = largeContextWindowBuffer
  375. } else {
  376. threshold = int64(float64(cw) * smallContextWindowRatio)
  377. }
  378. if (remaining <= threshold) && !a.disableAutoSummarize {
  379. shouldSummarize = true
  380. return true
  381. }
  382. return false
  383. },
  384. func(steps []fantasy.StepResult) bool {
  385. return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
  386. },
  387. },
  388. })
  389. a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
  390. if err != nil {
  391. isCancelErr := errors.Is(err, context.Canceled)
  392. isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
  393. if currentAssistant == nil {
  394. return result, err
  395. }
  396. // Ensure we finish thinking on error to close the reasoning state.
  397. currentAssistant.FinishThinking()
  398. toolCalls := currentAssistant.ToolCalls()
  399. // INFO: we use the parent context here because the genCtx has been cancelled.
  400. msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
  401. if createErr != nil {
  402. return nil, createErr
  403. }
  404. for _, tc := range toolCalls {
  405. if !tc.Finished {
  406. tc.Finished = true
  407. tc.Input = "{}"
  408. currentAssistant.AddToolCall(tc)
  409. updateErr := a.messages.Update(ctx, *currentAssistant)
  410. if updateErr != nil {
  411. return nil, updateErr
  412. }
  413. }
  414. found := false
  415. for _, msg := range msgs {
  416. if msg.Role == message.Tool {
  417. for _, tr := range msg.ToolResults() {
  418. if tr.ToolCallID == tc.ID {
  419. found = true
  420. break
  421. }
  422. }
  423. }
  424. if found {
  425. break
  426. }
  427. }
  428. if found {
  429. continue
  430. }
  431. content := "There was an error while executing the tool"
  432. if isCancelErr {
  433. content = "Tool execution canceled by user"
  434. } else if isPermissionErr {
  435. content = "User denied permission"
  436. }
  437. toolResult := message.ToolResult{
  438. ToolCallID: tc.ID,
  439. Name: tc.Name,
  440. Content: content,
  441. IsError: true,
  442. }
  443. _, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
  444. Role: message.Tool,
  445. Parts: []message.ContentPart{
  446. toolResult,
  447. },
  448. })
  449. if createErr != nil {
  450. return nil, createErr
  451. }
  452. }
  453. var fantasyErr *fantasy.Error
  454. var providerErr *fantasy.ProviderError
  455. const defaultTitle = "Provider Error"
  456. linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
  457. if isCancelErr {
  458. currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
  459. } else if isPermissionErr {
  460. currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
  461. } else if errors.Is(err, hyper.ErrNoCredits) {
  462. url := hyper.BaseURL()
  463. link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
  464. currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
  465. } else if errors.As(err, &providerErr) {
  466. if providerErr.Message == "The requested model is not supported." {
  467. url := "https://github.com/settings/copilot/features"
  468. link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
  469. currentAssistant.AddFinish(
  470. message.FinishReasonError,
  471. "Copilot model not enabled",
  472. fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
  473. )
  474. } else {
  475. currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
  476. }
  477. } else if errors.As(err, &fantasyErr) {
  478. currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
  479. } else {
  480. currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
  481. }
  482. // Note: we use the parent context here because the genCtx has been
  483. // cancelled.
  484. updateErr := a.messages.Update(ctx, *currentAssistant)
  485. if updateErr != nil {
  486. return nil, updateErr
  487. }
  488. return nil, err
  489. }
  490. if shouldSummarize {
  491. a.activeRequests.Del(call.SessionID)
  492. if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
  493. return nil, summarizeErr
  494. }
  495. // If the agent wasn't done...
  496. if len(currentAssistant.ToolCalls()) > 0 {
  497. existing, ok := a.messageQueue.Get(call.SessionID)
  498. if !ok {
  499. existing = []SessionAgentCall{}
  500. }
  501. call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
  502. existing = append(existing, call)
  503. a.messageQueue.Set(call.SessionID, existing)
  504. }
  505. }
  506. // Release active request before processing queued messages.
  507. a.activeRequests.Del(call.SessionID)
  508. cancel()
  509. queuedMessages, ok := a.messageQueue.Get(call.SessionID)
  510. if !ok || len(queuedMessages) == 0 {
  511. return result, err
  512. }
  513. // There are queued messages restart the loop.
  514. firstQueuedMessage := queuedMessages[0]
  515. a.messageQueue.Set(call.SessionID, queuedMessages[1:])
  516. return a.Run(ctx, firstQueuedMessage)
  517. }
  518. func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
  519. if a.IsSessionBusy(sessionID) {
  520. return ErrSessionBusy
  521. }
  522. // Copy mutable fields under lock to avoid races with SetModels.
  523. largeModel := a.largeModel.Get()
  524. systemPromptPrefix := a.systemPromptPrefix.Get()
  525. currentSession, err := a.sessions.Get(ctx, sessionID)
  526. if err != nil {
  527. return fmt.Errorf("failed to get session: %w", err)
  528. }
  529. msgs, err := a.getSessionMessages(ctx, currentSession)
  530. if err != nil {
  531. return err
  532. }
  533. if len(msgs) == 0 {
  534. // Nothing to summarize.
  535. return nil
  536. }
  537. aiMsgs, _ := a.preparePrompt(msgs)
  538. genCtx, cancel := context.WithCancel(ctx)
  539. a.activeRequests.Set(sessionID, cancel)
  540. defer a.activeRequests.Del(sessionID)
  541. defer cancel()
  542. agent := fantasy.NewAgent(largeModel.Model,
  543. fantasy.WithSystemPrompt(string(summaryPrompt)),
  544. )
  545. summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  546. Role: message.Assistant,
  547. Model: largeModel.Model.Model(),
  548. Provider: largeModel.Model.Provider(),
  549. IsSummaryMessage: true,
  550. })
  551. if err != nil {
  552. return err
  553. }
  554. summaryPromptText := buildSummaryPrompt(currentSession.Todos)
  555. resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  556. Prompt: summaryPromptText,
  557. Messages: aiMsgs,
  558. ProviderOptions: opts,
  559. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  560. prepared.Messages = options.Messages
  561. if systemPromptPrefix != "" {
  562. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
  563. }
  564. return callContext, prepared, nil
  565. },
  566. OnReasoningDelta: func(id string, text string) error {
  567. summaryMessage.AppendReasoningContent(text)
  568. return a.messages.Update(genCtx, summaryMessage)
  569. },
  570. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  571. // Handle anthropic signature.
  572. if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
  573. if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
  574. summaryMessage.AppendReasoningSignature(signature.Signature)
  575. }
  576. }
  577. summaryMessage.FinishThinking()
  578. return a.messages.Update(genCtx, summaryMessage)
  579. },
  580. OnTextDelta: func(id, text string) error {
  581. summaryMessage.AppendContent(text)
  582. return a.messages.Update(genCtx, summaryMessage)
  583. },
  584. })
  585. if err != nil {
  586. isCancelErr := errors.Is(err, context.Canceled)
  587. if isCancelErr {
  588. // User cancelled summarize we need to remove the summary message.
  589. deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
  590. return deleteErr
  591. }
  592. return err
  593. }
  594. summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
  595. err = a.messages.Update(genCtx, summaryMessage)
  596. if err != nil {
  597. return err
  598. }
  599. var openrouterCost *float64
  600. for _, step := range resp.Steps {
  601. stepCost := a.openrouterCost(step.ProviderMetadata)
  602. if stepCost != nil {
  603. newCost := *stepCost
  604. if openrouterCost != nil {
  605. newCost += *openrouterCost
  606. }
  607. openrouterCost = &newCost
  608. }
  609. }
  610. a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost)
  611. // Just in case, get just the last usage info.
  612. usage := resp.Response.Usage
  613. currentSession.SummaryMessageID = summaryMessage.ID
  614. currentSession.CompletionTokens = usage.OutputTokens
  615. currentSession.PromptTokens = 0
  616. _, err = a.sessions.Save(genCtx, currentSession)
  617. return err
  618. }
  619. func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
  620. if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
  621. return fantasy.ProviderOptions{}
  622. }
  623. return fantasy.ProviderOptions{
  624. anthropic.Name: &anthropic.ProviderCacheControlOptions{
  625. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  626. },
  627. bedrock.Name: &anthropic.ProviderCacheControlOptions{
  628. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  629. },
  630. vercel.Name: &anthropic.ProviderCacheControlOptions{
  631. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  632. },
  633. }
  634. }
  635. func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
  636. parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
  637. var attachmentParts []message.ContentPart
  638. for _, attachment := range call.Attachments {
  639. attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
  640. }
  641. parts = append(parts, attachmentParts...)
  642. msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
  643. Role: message.User,
  644. Parts: parts,
  645. })
  646. if err != nil {
  647. return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
  648. }
  649. return msg, nil
  650. }
  651. func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
  652. var history []fantasy.Message
  653. if !a.isSubAgent {
  654. history = append(history, fantasy.NewUserMessage(
  655. fmt.Sprintf("<system_reminder>%s</system_reminder>",
  656. `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
  657. If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
  658. If not, please feel free to ignore. Again do not mention this message to the user.`,
  659. ),
  660. ))
  661. }
  662. for _, m := range msgs {
  663. if len(m.Parts) == 0 {
  664. continue
  665. }
  666. // Assistant message without content or tool calls (cancelled before it
  667. // returned anything).
  668. if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
  669. continue
  670. }
  671. history = append(history, m.ToAIMessage()...)
  672. }
  673. var files []fantasy.FilePart
  674. for _, attachment := range attachments {
  675. if attachment.IsText() {
  676. continue
  677. }
  678. files = append(files, fantasy.FilePart{
  679. Filename: attachment.FileName,
  680. Data: attachment.Content,
  681. MediaType: attachment.MimeType,
  682. })
  683. }
  684. return history, files
  685. }
  686. func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
  687. msgs, err := a.messages.List(ctx, session.ID)
  688. if err != nil {
  689. return nil, fmt.Errorf("failed to list messages: %w", err)
  690. }
  691. if session.SummaryMessageID != "" {
  692. summaryMsgIndex := -1
  693. for i, msg := range msgs {
  694. if msg.ID == session.SummaryMessageID {
  695. summaryMsgIndex = i
  696. break
  697. }
  698. }
  699. if summaryMsgIndex != -1 {
  700. msgs = msgs[summaryMsgIndex:]
  701. msgs[0].Role = message.User
  702. }
  703. }
  704. return msgs, nil
  705. }
  706. // generateTitle generates a session titled based on the initial prompt.
  707. func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
  708. if userPrompt == "" {
  709. return
  710. }
  711. smallModel := a.smallModel.Get()
  712. largeModel := a.largeModel.Get()
  713. systemPromptPrefix := a.systemPromptPrefix.Get()
  714. var maxOutputTokens int64 = 40
  715. if smallModel.CatwalkCfg.CanReason {
  716. maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
  717. }
  718. newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
  719. return fantasy.NewAgent(m,
  720. fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
  721. fantasy.WithMaxOutputTokens(tok),
  722. )
  723. }
  724. streamCall := fantasy.AgentStreamCall{
  725. Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
  726. PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  727. prepared.Messages = opts.Messages
  728. if systemPromptPrefix != "" {
  729. prepared.Messages = append([]fantasy.Message{
  730. fantasy.NewSystemMessage(systemPromptPrefix),
  731. }, prepared.Messages...)
  732. }
  733. return callCtx, prepared, nil
  734. },
  735. }
  736. // Use the small model to generate the title.
  737. model := smallModel
  738. agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
  739. resp, err := agent.Stream(ctx, streamCall)
  740. if err == nil {
  741. // We successfully generated a title with the small model.
  742. slog.Debug("Generated title with small model")
  743. } else {
  744. // It didn't work. Let's try with the big model.
  745. slog.Error("Error generating title with small model; trying big model", "err", err)
  746. model = largeModel
  747. agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
  748. resp, err = agent.Stream(ctx, streamCall)
  749. if err == nil {
  750. slog.Debug("Generated title with large model")
  751. } else {
  752. // Welp, the large model didn't work either. Use the default
  753. // session name and return.
  754. slog.Error("Error generating title with large model", "err", err)
  755. saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
  756. if saveErr != nil {
  757. slog.Error("Failed to save session title and usage", "error", saveErr)
  758. }
  759. return
  760. }
  761. }
  762. if resp == nil {
  763. // Actually, we didn't get a response so we can't. Use the default
  764. // session name and return.
  765. slog.Error("Response is nil; can't generate title")
  766. saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
  767. if saveErr != nil {
  768. slog.Error("Failed to save session title and usage", "error", saveErr)
  769. }
  770. return
  771. }
  772. // Clean up title.
  773. var title string
  774. title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
  775. // Remove thinking tags if present.
  776. title = thinkTagRegex.ReplaceAllString(title, "")
  777. title = strings.TrimSpace(title)
  778. title = cmp.Or(title, defaultSessionName)
  779. // Calculate usage and cost.
  780. var openrouterCost *float64
  781. for _, step := range resp.Steps {
  782. stepCost := a.openrouterCost(step.ProviderMetadata)
  783. if stepCost != nil {
  784. newCost := *stepCost
  785. if openrouterCost != nil {
  786. newCost += *openrouterCost
  787. }
  788. openrouterCost = &newCost
  789. }
  790. }
  791. modelConfig := model.CatwalkCfg
  792. cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
  793. modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
  794. modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
  795. modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
  796. // Use override cost if available (e.g., from OpenRouter).
  797. if openrouterCost != nil {
  798. cost = *openrouterCost
  799. }
  800. promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
  801. completionTokens := resp.TotalUsage.OutputTokens
  802. // Atomically update only title and usage fields to avoid overriding other
  803. // concurrent session updates.
  804. saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
  805. if saveErr != nil {
  806. slog.Error("Failed to save session title and usage", "error", saveErr)
  807. return
  808. }
  809. }
  810. func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
  811. openrouterMetadata, ok := metadata[openrouter.Name]
  812. if !ok {
  813. return nil
  814. }
  815. opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
  816. if !ok {
  817. return nil
  818. }
  819. return &opts.Usage.Cost
  820. }
  821. func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
  822. modelConfig := model.CatwalkCfg
  823. cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  824. modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  825. modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
  826. modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
  827. a.eventTokensUsed(session.ID, model, usage, cost)
  828. if overrideCost != nil {
  829. session.Cost += *overrideCost
  830. } else {
  831. session.Cost += cost
  832. }
  833. session.CompletionTokens = usage.OutputTokens
  834. session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
  835. }
  836. func (a *sessionAgent) Cancel(sessionID string) {
  837. // Cancel regular requests. Don't use Take() here - we need the entry to
  838. // remain in activeRequests so IsBusy() returns true until the goroutine
  839. // fully completes (including error handling that may access the DB).
  840. // The defer in processRequest will clean up the entry.
  841. if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
  842. slog.Debug("Request cancellation initiated", "session_id", sessionID)
  843. cancel()
  844. }
  845. // Also check for summarize requests.
  846. if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
  847. slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
  848. cancel()
  849. }
  850. if a.QueuedPrompts(sessionID) > 0 {
  851. slog.Debug("Clearing queued prompts", "session_id", sessionID)
  852. a.messageQueue.Del(sessionID)
  853. }
  854. }
  855. func (a *sessionAgent) ClearQueue(sessionID string) {
  856. if a.QueuedPrompts(sessionID) > 0 {
  857. slog.Debug("Clearing queued prompts", "session_id", sessionID)
  858. a.messageQueue.Del(sessionID)
  859. }
  860. }
  861. func (a *sessionAgent) CancelAll() {
  862. if !a.IsBusy() {
  863. return
  864. }
  865. for key := range a.activeRequests.Seq2() {
  866. a.Cancel(key) // key is sessionID
  867. }
  868. timeout := time.After(5 * time.Second)
  869. for a.IsBusy() {
  870. select {
  871. case <-timeout:
  872. return
  873. default:
  874. time.Sleep(200 * time.Millisecond)
  875. }
  876. }
  877. }
  878. func (a *sessionAgent) IsBusy() bool {
  879. var busy bool
  880. for cancelFunc := range a.activeRequests.Seq() {
  881. if cancelFunc != nil {
  882. busy = true
  883. break
  884. }
  885. }
  886. return busy
  887. }
  888. func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
  889. _, busy := a.activeRequests.Get(sessionID)
  890. return busy
  891. }
  892. func (a *sessionAgent) QueuedPrompts(sessionID string) int {
  893. l, ok := a.messageQueue.Get(sessionID)
  894. if !ok {
  895. return 0
  896. }
  897. return len(l)
  898. }
  899. func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
  900. l, ok := a.messageQueue.Get(sessionID)
  901. if !ok {
  902. return nil
  903. }
  904. prompts := make([]string, len(l))
  905. for i, call := range l {
  906. prompts[i] = call.Prompt
  907. }
  908. return prompts
  909. }
  910. func (a *sessionAgent) SetModels(large Model, small Model) {
  911. a.largeModel.Set(large)
  912. a.smallModel.Set(small)
  913. }
  914. func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
  915. a.tools.SetSlice(tools)
  916. }
  917. func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
  918. a.systemPrompt.Set(systemPrompt)
  919. }
  920. func (a *sessionAgent) Model() Model {
  921. return a.largeModel.Get()
  922. }
  923. // convertToToolResult converts a fantasy tool result to a message tool result.
  924. func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
  925. baseResult := message.ToolResult{
  926. ToolCallID: result.ToolCallID,
  927. Name: result.ToolName,
  928. Metadata: result.ClientMetadata,
  929. }
  930. switch result.Result.GetType() {
  931. case fantasy.ToolResultContentTypeText:
  932. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
  933. baseResult.Content = r.Text
  934. }
  935. case fantasy.ToolResultContentTypeError:
  936. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
  937. baseResult.Content = r.Error.Error()
  938. baseResult.IsError = true
  939. }
  940. case fantasy.ToolResultContentTypeMedia:
  941. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
  942. content := r.Text
  943. if content == "" {
  944. content = fmt.Sprintf("Loaded %s content", r.MediaType)
  945. }
  946. baseResult.Content = content
  947. baseResult.Data = r.Data
  948. baseResult.MIMEType = r.MediaType
  949. }
  950. }
  951. return baseResult
  952. }
  953. // workaroundProviderMediaLimitations converts media content in tool results to
  954. // user messages for providers that don't natively support images in tool results.
  955. //
  956. // Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
  957. // don't support sending images/media in tool result messages - they only accept
  958. // text in tool results. However, they DO support images in user messages.
  959. //
  960. // If we send media in tool results to these providers, the API returns an error.
  961. //
  962. // Solution: For these providers, we:
  963. // 1. Replace the media in the tool result with a text placeholder
  964. // 2. Inject a user message immediately after with the image as a file attachment
  965. // 3. This maintains the tool execution flow while working around API limitations
  966. //
  967. // Anthropic and Bedrock support images natively in tool results, so we skip
  968. // this workaround for them.
  969. //
  970. // Example transformation:
  971. //
  972. // BEFORE: [tool result: image data]
  973. // AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
  974. func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
  975. providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
  976. largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
  977. if providerSupportsMedia {
  978. return messages
  979. }
  980. convertedMessages := make([]fantasy.Message, 0, len(messages))
  981. for _, msg := range messages {
  982. if msg.Role != fantasy.MessageRoleTool {
  983. convertedMessages = append(convertedMessages, msg)
  984. continue
  985. }
  986. textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
  987. var mediaFiles []fantasy.FilePart
  988. for _, part := range msg.Content {
  989. toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
  990. if !ok {
  991. textParts = append(textParts, part)
  992. continue
  993. }
  994. if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
  995. decoded, err := base64.StdEncoding.DecodeString(media.Data)
  996. if err != nil {
  997. slog.Warn("Failed to decode media data", "error", err)
  998. textParts = append(textParts, part)
  999. continue
  1000. }
  1001. mediaFiles = append(mediaFiles, fantasy.FilePart{
  1002. Data: decoded,
  1003. MediaType: media.MediaType,
  1004. Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
  1005. })
  1006. textParts = append(textParts, fantasy.ToolResultPart{
  1007. ToolCallID: toolResult.ToolCallID,
  1008. Output: fantasy.ToolResultOutputContentText{
  1009. Text: "[Image/media content loaded - see attached file]",
  1010. },
  1011. ProviderOptions: toolResult.ProviderOptions,
  1012. })
  1013. } else {
  1014. textParts = append(textParts, part)
  1015. }
  1016. }
  1017. convertedMessages = append(convertedMessages, fantasy.Message{
  1018. Role: fantasy.MessageRoleTool,
  1019. Content: textParts,
  1020. })
  1021. if len(mediaFiles) > 0 {
  1022. convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
  1023. "Here is the media content from the tool result:",
  1024. mediaFiles...,
  1025. ))
  1026. }
  1027. }
  1028. return convertedMessages
  1029. }
  1030. // buildSummaryPrompt constructs the prompt text for session summarization.
  1031. func buildSummaryPrompt(todos []session.Todo) string {
  1032. var sb strings.Builder
  1033. sb.WriteString("Provide a detailed summary of our conversation above.")
  1034. if len(todos) > 0 {
  1035. sb.WriteString("\n\n## Current Todo List\n\n")
  1036. for _, t := range todos {
  1037. fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
  1038. }
  1039. sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
  1040. sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
  1041. }
  1042. return sb.String()
  1043. }