agent.go 36 KB


  1. // Package agent is the core orchestration layer for Crush AI agents.
  2. //
  3. // It provides session-based AI agent functionality for managing
  4. // conversations, tool execution, and message handling. It coordinates
  5. // interactions between language models, messages, sessions, and tools while
  6. // handling features like automatic summarization, queuing, and token
  7. // management.
  8. package agent
  9. import (
  10. "cmp"
  11. "context"
  12. _ "embed"
  13. "encoding/base64"
  14. "errors"
  15. "fmt"
  16. "log/slog"
  17. "os"
  18. "regexp"
  19. "strconv"
  20. "strings"
  21. "sync"
  22. "time"
  23. "charm.land/fantasy"
  24. "charm.land/fantasy/providers/anthropic"
  25. "charm.land/fantasy/providers/bedrock"
  26. "charm.land/fantasy/providers/google"
  27. "charm.land/fantasy/providers/openai"
  28. "charm.land/fantasy/providers/openrouter"
  29. "charm.land/lipgloss/v2"
  30. "github.com/charmbracelet/catwalk/pkg/catwalk"
  31. "github.com/charmbracelet/crush/internal/agent/hyper"
  32. "github.com/charmbracelet/crush/internal/agent/tools"
  33. "github.com/charmbracelet/crush/internal/config"
  34. "github.com/charmbracelet/crush/internal/csync"
  35. "github.com/charmbracelet/crush/internal/message"
  36. "github.com/charmbracelet/crush/internal/permission"
  37. "github.com/charmbracelet/crush/internal/session"
  38. "github.com/charmbracelet/crush/internal/stringext"
  39. "github.com/charmbracelet/x/exp/charmtone"
  40. )
const (
	// defaultSessionName is the title saved when no title could be generated
	// (both models failed, the response was nil, or the cleaned title was
	// empty).
	defaultSessionName = "Untitled Session"
	// Constants for auto-summarization thresholds. For models whose context
	// window is larger than largeContextWindowThreshold tokens, Run stops and
	// summarizes once at most largeContextWindowBuffer tokens remain; for
	// smaller windows it does so once at most smallContextWindowRatio of the
	// window remains.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)

// titlePrompt is the system prompt used by generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles. The pattern is
// non-greedy and, without the (?s) flag, "." does not match newlines, so
// callers flatten newlines before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  54. type SessionAgentCall struct {
  55. SessionID string
  56. Prompt string
  57. ProviderOptions fantasy.ProviderOptions
  58. Attachments []message.Attachment
  59. MaxOutputTokens int64
  60. Temperature *float64
  61. TopP *float64
  62. TopK *int64
  63. FrequencyPenalty *float64
  64. PresencePenalty *float64
  65. }
  66. type SessionAgent interface {
  67. Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
  68. SetModels(large Model, small Model)
  69. SetTools(tools []fantasy.AgentTool)
  70. SetSystemPrompt(systemPrompt string)
  71. Cancel(sessionID string)
  72. CancelAll()
  73. IsSessionBusy(sessionID string) bool
  74. IsBusy() bool
  75. QueuedPrompts(sessionID string) int
  76. QueuedPromptsList(sessionID string) []string
  77. ClearQueue(sessionID string)
  78. Summarize(context.Context, string, fantasy.ProviderOptions) error
  79. Model() Model
  80. }
// Model bundles a language-model client with its catwalk metadata and the
// model selection from the user's config.
type Model struct {
	// Model is the client used to issue requests.
	Model fantasy.LanguageModel
	// CatwalkCfg holds model metadata read by this package, such as the
	// context window, pricing, image support, and reasoning capability.
	CatwalkCfg catwalk.Model
	// ModelCfg is the configured selection; Run records its Model and
	// Provider on the assistant messages it creates.
	ModelCfg config.SelectedModel
}
  86. type sessionAgent struct {
  87. largeModel *csync.Value[Model]
  88. smallModel *csync.Value[Model]
  89. systemPromptPrefix *csync.Value[string]
  90. systemPrompt *csync.Value[string]
  91. tools *csync.Slice[fantasy.AgentTool]
  92. isSubAgent bool
  93. sessions session.Service
  94. messages message.Service
  95. disableAutoSummarize bool
  96. isYolo bool
  97. messageQueue *csync.Map[string, []SessionAgentCall]
  98. activeRequests *csync.Map[string, context.CancelFunc]
  99. }
  100. type SessionAgentOptions struct {
  101. LargeModel Model
  102. SmallModel Model
  103. SystemPromptPrefix string
  104. SystemPrompt string
  105. IsSubAgent bool
  106. DisableAutoSummarize bool
  107. IsYolo bool
  108. Sessions session.Service
  109. Messages message.Service
  110. Tools []fantasy.AgentTool
  111. }
  112. func NewSessionAgent(
  113. opts SessionAgentOptions,
  114. ) SessionAgent {
  115. return &sessionAgent{
  116. largeModel: csync.NewValue(opts.LargeModel),
  117. smallModel: csync.NewValue(opts.SmallModel),
  118. systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
  119. systemPrompt: csync.NewValue(opts.SystemPrompt),
  120. isSubAgent: opts.IsSubAgent,
  121. sessions: opts.Sessions,
  122. messages: opts.Messages,
  123. disableAutoSummarize: opts.DisableAutoSummarize,
  124. tools: csync.NewSliceFrom(opts.Tools),
  125. isYolo: opts.IsYolo,
  126. messageQueue: csync.NewMap[string, []SessionAgentCall](),
  127. activeRequests: csync.NewMap[string, context.CancelFunc](),
  128. }
  129. }
  130. func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
  131. if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
  132. return nil, ErrEmptyPrompt
  133. }
  134. if call.SessionID == "" {
  135. return nil, ErrSessionMissing
  136. }
  137. // Queue the message if busy
  138. if a.IsSessionBusy(call.SessionID) {
  139. existing, ok := a.messageQueue.Get(call.SessionID)
  140. if !ok {
  141. existing = []SessionAgentCall{}
  142. }
  143. existing = append(existing, call)
  144. a.messageQueue.Set(call.SessionID, existing)
  145. return nil, nil
  146. }
  147. // Copy mutable fields under lock to avoid races with SetTools/SetModels.
  148. agentTools := a.tools.Copy()
  149. largeModel := a.largeModel.Get()
  150. systemPrompt := a.systemPrompt.Get()
  151. promptPrefix := a.systemPromptPrefix.Get()
  152. if len(agentTools) > 0 {
  153. // Add Anthropic caching to the last tool.
  154. agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
  155. }
  156. agent := fantasy.NewAgent(
  157. largeModel.Model,
  158. fantasy.WithSystemPrompt(systemPrompt),
  159. fantasy.WithTools(agentTools...),
  160. )
  161. sessionLock := sync.Mutex{}
  162. currentSession, err := a.sessions.Get(ctx, call.SessionID)
  163. if err != nil {
  164. return nil, fmt.Errorf("failed to get session: %w", err)
  165. }
  166. msgs, err := a.getSessionMessages(ctx, currentSession)
  167. if err != nil {
  168. return nil, fmt.Errorf("failed to get session messages: %w", err)
  169. }
  170. var wg sync.WaitGroup
  171. // Generate title if first message.
  172. if len(msgs) == 0 {
  173. titleCtx := ctx // Copy to avoid race with ctx reassignment below.
  174. wg.Go(func() {
  175. a.generateTitle(titleCtx, call.SessionID, call.Prompt)
  176. })
  177. }
  178. defer wg.Wait()
  179. // Add the user message to the session.
  180. _, err = a.createUserMessage(ctx, call)
  181. if err != nil {
  182. return nil, err
  183. }
  184. // Add the session to the context.
  185. ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
  186. genCtx, cancel := context.WithCancel(ctx)
  187. a.activeRequests.Set(call.SessionID, cancel)
  188. defer cancel()
  189. defer a.activeRequests.Del(call.SessionID)
  190. history, files := a.preparePrompt(msgs, call.Attachments...)
  191. startTime := time.Now()
  192. a.eventPromptSent(call.SessionID)
  193. var currentAssistant *message.Message
  194. var shouldSummarize bool
  195. result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  196. Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
  197. Files: files,
  198. Messages: history,
  199. ProviderOptions: call.ProviderOptions,
  200. MaxOutputTokens: &call.MaxOutputTokens,
  201. TopP: call.TopP,
  202. Temperature: call.Temperature,
  203. PresencePenalty: call.PresencePenalty,
  204. TopK: call.TopK,
  205. FrequencyPenalty: call.FrequencyPenalty,
  206. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  207. prepared.Messages = options.Messages
  208. for i := range prepared.Messages {
  209. prepared.Messages[i].ProviderOptions = nil
  210. }
  211. queuedCalls, _ := a.messageQueue.Get(call.SessionID)
  212. a.messageQueue.Del(call.SessionID)
  213. for _, queued := range queuedCalls {
  214. userMessage, createErr := a.createUserMessage(callContext, queued)
  215. if createErr != nil {
  216. return callContext, prepared, createErr
  217. }
  218. prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
  219. }
  220. prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)
  221. lastSystemRoleInx := 0
  222. systemMessageUpdated := false
  223. for i, msg := range prepared.Messages {
  224. // Only add cache control to the last message.
  225. if msg.Role == fantasy.MessageRoleSystem {
  226. lastSystemRoleInx = i
  227. } else if !systemMessageUpdated {
  228. prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
  229. systemMessageUpdated = true
  230. }
  231. // Than add cache control to the last 2 messages.
  232. if i > len(prepared.Messages)-3 {
  233. prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
  234. }
  235. }
  236. if promptPrefix != "" {
  237. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
  238. }
  239. var assistantMsg message.Message
  240. assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
  241. Role: message.Assistant,
  242. Parts: []message.ContentPart{},
  243. Model: largeModel.ModelCfg.Model,
  244. Provider: largeModel.ModelCfg.Provider,
  245. })
  246. if err != nil {
  247. return callContext, prepared, err
  248. }
  249. callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
  250. callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
  251. callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
  252. currentAssistant = &assistantMsg
  253. return callContext, prepared, err
  254. },
  255. OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
  256. currentAssistant.AppendReasoningContent(reasoning.Text)
  257. return a.messages.Update(genCtx, *currentAssistant)
  258. },
  259. OnReasoningDelta: func(id string, text string) error {
  260. currentAssistant.AppendReasoningContent(text)
  261. return a.messages.Update(genCtx, *currentAssistant)
  262. },
  263. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  264. // handle anthropic signature
  265. if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
  266. if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
  267. currentAssistant.AppendReasoningSignature(reasoning.Signature)
  268. }
  269. }
  270. if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
  271. if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
  272. currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
  273. }
  274. }
  275. if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
  276. if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
  277. currentAssistant.SetReasoningResponsesData(reasoning)
  278. }
  279. }
  280. currentAssistant.FinishThinking()
  281. return a.messages.Update(genCtx, *currentAssistant)
  282. },
  283. OnTextDelta: func(id string, text string) error {
  284. // Strip leading newline from initial text content. This is is
  285. // particularly important in non-interactive mode where leading
  286. // newlines are very visible.
  287. if len(currentAssistant.Parts) == 0 {
  288. text = strings.TrimPrefix(text, "\n")
  289. }
  290. currentAssistant.AppendContent(text)
  291. return a.messages.Update(genCtx, *currentAssistant)
  292. },
  293. OnToolInputStart: func(id string, toolName string) error {
  294. toolCall := message.ToolCall{
  295. ID: id,
  296. Name: toolName,
  297. ProviderExecuted: false,
  298. Finished: false,
  299. }
  300. currentAssistant.AddToolCall(toolCall)
  301. return a.messages.Update(genCtx, *currentAssistant)
  302. },
  303. OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
  304. // TODO: implement
  305. },
  306. OnToolCall: func(tc fantasy.ToolCallContent) error {
  307. toolCall := message.ToolCall{
  308. ID: tc.ToolCallID,
  309. Name: tc.ToolName,
  310. Input: tc.Input,
  311. ProviderExecuted: false,
  312. Finished: true,
  313. }
  314. currentAssistant.AddToolCall(toolCall)
  315. return a.messages.Update(genCtx, *currentAssistant)
  316. },
  317. OnToolResult: func(result fantasy.ToolResultContent) error {
  318. toolResult := a.convertToToolResult(result)
  319. _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
  320. Role: message.Tool,
  321. Parts: []message.ContentPart{
  322. toolResult,
  323. },
  324. })
  325. return createMsgErr
  326. },
  327. OnStepFinish: func(stepResult fantasy.StepResult) error {
  328. finishReason := message.FinishReasonUnknown
  329. switch stepResult.FinishReason {
  330. case fantasy.FinishReasonLength:
  331. finishReason = message.FinishReasonMaxTokens
  332. case fantasy.FinishReasonStop:
  333. finishReason = message.FinishReasonEndTurn
  334. case fantasy.FinishReasonToolCalls:
  335. finishReason = message.FinishReasonToolUse
  336. }
  337. currentAssistant.AddFinish(finishReason, "", "")
  338. sessionLock.Lock()
  339. updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
  340. if getSessionErr != nil {
  341. sessionLock.Unlock()
  342. return getSessionErr
  343. }
  344. a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
  345. _, sessionErr := a.sessions.Save(genCtx, updatedSession)
  346. sessionLock.Unlock()
  347. if sessionErr != nil {
  348. return sessionErr
  349. }
  350. return a.messages.Update(genCtx, *currentAssistant)
  351. },
  352. StopWhen: []fantasy.StopCondition{
  353. func(_ []fantasy.StepResult) bool {
  354. cw := int64(largeModel.CatwalkCfg.ContextWindow)
  355. tokens := currentSession.CompletionTokens + currentSession.PromptTokens
  356. remaining := cw - tokens
  357. var threshold int64
  358. if cw > largeContextWindowThreshold {
  359. threshold = largeContextWindowBuffer
  360. } else {
  361. threshold = int64(float64(cw) * smallContextWindowRatio)
  362. }
  363. if (remaining <= threshold) && !a.disableAutoSummarize {
  364. shouldSummarize = true
  365. return true
  366. }
  367. return false
  368. },
  369. },
  370. })
  371. a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
  372. if err != nil {
  373. isCancelErr := errors.Is(err, context.Canceled)
  374. isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
  375. if currentAssistant == nil {
  376. return result, err
  377. }
  378. // Ensure we finish thinking on error to close the reasoning state.
  379. currentAssistant.FinishThinking()
  380. toolCalls := currentAssistant.ToolCalls()
  381. // INFO: we use the parent context here because the genCtx has been cancelled.
  382. msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
  383. if createErr != nil {
  384. return nil, createErr
  385. }
  386. for _, tc := range toolCalls {
  387. if !tc.Finished {
  388. tc.Finished = true
  389. tc.Input = "{}"
  390. currentAssistant.AddToolCall(tc)
  391. updateErr := a.messages.Update(ctx, *currentAssistant)
  392. if updateErr != nil {
  393. return nil, updateErr
  394. }
  395. }
  396. found := false
  397. for _, msg := range msgs {
  398. if msg.Role == message.Tool {
  399. for _, tr := range msg.ToolResults() {
  400. if tr.ToolCallID == tc.ID {
  401. found = true
  402. break
  403. }
  404. }
  405. }
  406. if found {
  407. break
  408. }
  409. }
  410. if found {
  411. continue
  412. }
  413. content := "There was an error while executing the tool"
  414. if isCancelErr {
  415. content = "Tool execution canceled by user"
  416. } else if isPermissionErr {
  417. content = "User denied permission"
  418. }
  419. toolResult := message.ToolResult{
  420. ToolCallID: tc.ID,
  421. Name: tc.Name,
  422. Content: content,
  423. IsError: true,
  424. }
  425. _, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
  426. Role: message.Tool,
  427. Parts: []message.ContentPart{
  428. toolResult,
  429. },
  430. })
  431. if createErr != nil {
  432. return nil, createErr
  433. }
  434. }
  435. var fantasyErr *fantasy.Error
  436. var providerErr *fantasy.ProviderError
  437. const defaultTitle = "Provider Error"
  438. linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
  439. if isCancelErr {
  440. currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
  441. } else if isPermissionErr {
  442. currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
  443. } else if errors.Is(err, hyper.ErrNoCredits) {
  444. url := hyper.BaseURL()
  445. link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
  446. currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
  447. } else if errors.As(err, &providerErr) {
  448. if providerErr.Message == "The requested model is not supported." {
  449. url := "https://github.com/settings/copilot/features"
  450. link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
  451. currentAssistant.AddFinish(
  452. message.FinishReasonError,
  453. "Copilot model not enabled",
  454. fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
  455. )
  456. } else {
  457. currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
  458. }
  459. } else if errors.As(err, &fantasyErr) {
  460. currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
  461. } else {
  462. currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
  463. }
  464. // Note: we use the parent context here because the genCtx has been
  465. // cancelled.
  466. updateErr := a.messages.Update(ctx, *currentAssistant)
  467. if updateErr != nil {
  468. return nil, updateErr
  469. }
  470. return nil, err
  471. }
  472. if shouldSummarize {
  473. a.activeRequests.Del(call.SessionID)
  474. if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
  475. return nil, summarizeErr
  476. }
  477. // If the agent wasn't done...
  478. if len(currentAssistant.ToolCalls()) > 0 {
  479. existing, ok := a.messageQueue.Get(call.SessionID)
  480. if !ok {
  481. existing = []SessionAgentCall{}
  482. }
  483. call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
  484. existing = append(existing, call)
  485. a.messageQueue.Set(call.SessionID, existing)
  486. }
  487. }
  488. // Release active request before processing queued messages.
  489. a.activeRequests.Del(call.SessionID)
  490. cancel()
  491. queuedMessages, ok := a.messageQueue.Get(call.SessionID)
  492. if !ok || len(queuedMessages) == 0 {
  493. return result, err
  494. }
  495. // There are queued messages restart the loop.
  496. firstQueuedMessage := queuedMessages[0]
  497. a.messageQueue.Set(call.SessionID, queuedMessages[1:])
  498. return a.Run(ctx, firstQueuedMessage)
  499. }
  500. func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
  501. if a.IsSessionBusy(sessionID) {
  502. return ErrSessionBusy
  503. }
  504. // Copy mutable fields under lock to avoid races with SetModels.
  505. largeModel := a.largeModel.Get()
  506. systemPromptPrefix := a.systemPromptPrefix.Get()
  507. currentSession, err := a.sessions.Get(ctx, sessionID)
  508. if err != nil {
  509. return fmt.Errorf("failed to get session: %w", err)
  510. }
  511. msgs, err := a.getSessionMessages(ctx, currentSession)
  512. if err != nil {
  513. return err
  514. }
  515. if len(msgs) == 0 {
  516. // Nothing to summarize.
  517. return nil
  518. }
  519. aiMsgs, _ := a.preparePrompt(msgs)
  520. genCtx, cancel := context.WithCancel(ctx)
  521. a.activeRequests.Set(sessionID, cancel)
  522. defer a.activeRequests.Del(sessionID)
  523. defer cancel()
  524. agent := fantasy.NewAgent(largeModel.Model,
  525. fantasy.WithSystemPrompt(string(summaryPrompt)),
  526. )
  527. summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  528. Role: message.Assistant,
  529. Model: largeModel.Model.Model(),
  530. Provider: largeModel.Model.Provider(),
  531. IsSummaryMessage: true,
  532. })
  533. if err != nil {
  534. return err
  535. }
  536. summaryPromptText := buildSummaryPrompt(currentSession.Todos)
  537. resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  538. Prompt: summaryPromptText,
  539. Messages: aiMsgs,
  540. ProviderOptions: opts,
  541. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  542. prepared.Messages = options.Messages
  543. if systemPromptPrefix != "" {
  544. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
  545. }
  546. return callContext, prepared, nil
  547. },
  548. OnReasoningDelta: func(id string, text string) error {
  549. summaryMessage.AppendReasoningContent(text)
  550. return a.messages.Update(genCtx, summaryMessage)
  551. },
  552. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  553. // Handle anthropic signature.
  554. if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
  555. if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
  556. summaryMessage.AppendReasoningSignature(signature.Signature)
  557. }
  558. }
  559. summaryMessage.FinishThinking()
  560. return a.messages.Update(genCtx, summaryMessage)
  561. },
  562. OnTextDelta: func(id, text string) error {
  563. summaryMessage.AppendContent(text)
  564. return a.messages.Update(genCtx, summaryMessage)
  565. },
  566. })
  567. if err != nil {
  568. isCancelErr := errors.Is(err, context.Canceled)
  569. if isCancelErr {
  570. // User cancelled summarize we need to remove the summary message.
  571. deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
  572. return deleteErr
  573. }
  574. return err
  575. }
  576. summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
  577. err = a.messages.Update(genCtx, summaryMessage)
  578. if err != nil {
  579. return err
  580. }
  581. var openrouterCost *float64
  582. for _, step := range resp.Steps {
  583. stepCost := a.openrouterCost(step.ProviderMetadata)
  584. if stepCost != nil {
  585. newCost := *stepCost
  586. if openrouterCost != nil {
  587. newCost += *openrouterCost
  588. }
  589. openrouterCost = &newCost
  590. }
  591. }
  592. a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost)
  593. // Just in case, get just the last usage info.
  594. usage := resp.Response.Usage
  595. currentSession.SummaryMessageID = summaryMessage.ID
  596. currentSession.CompletionTokens = usage.OutputTokens
  597. currentSession.PromptTokens = 0
  598. _, err = a.sessions.Save(genCtx, currentSession)
  599. return err
  600. }
  601. func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
  602. if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
  603. return fantasy.ProviderOptions{}
  604. }
  605. return fantasy.ProviderOptions{
  606. anthropic.Name: &anthropic.ProviderCacheControlOptions{
  607. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  608. },
  609. bedrock.Name: &anthropic.ProviderCacheControlOptions{
  610. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  611. },
  612. }
  613. }
  614. func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
  615. parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
  616. var attachmentParts []message.ContentPart
  617. for _, attachment := range call.Attachments {
  618. attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
  619. }
  620. parts = append(parts, attachmentParts...)
  621. msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
  622. Role: message.User,
  623. Parts: parts,
  624. })
  625. if err != nil {
  626. return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
  627. }
  628. return msg, nil
  629. }
  630. func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
  631. var history []fantasy.Message
  632. if !a.isSubAgent {
  633. history = append(history, fantasy.NewUserMessage(
  634. fmt.Sprintf("<system_reminder>%s</system_reminder>",
  635. `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
  636. If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
  637. If not, please feel free to ignore. Again do not mention this message to the user.`,
  638. ),
  639. ))
  640. }
  641. for _, m := range msgs {
  642. if len(m.Parts) == 0 {
  643. continue
  644. }
  645. // Assistant message without content or tool calls (cancelled before it
  646. // returned anything).
  647. if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
  648. continue
  649. }
  650. history = append(history, m.ToAIMessage()...)
  651. }
  652. var files []fantasy.FilePart
  653. for _, attachment := range attachments {
  654. if attachment.IsText() {
  655. continue
  656. }
  657. files = append(files, fantasy.FilePart{
  658. Filename: attachment.FileName,
  659. Data: attachment.Content,
  660. MediaType: attachment.MimeType,
  661. })
  662. }
  663. return history, files
  664. }
  665. func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
  666. msgs, err := a.messages.List(ctx, session.ID)
  667. if err != nil {
  668. return nil, fmt.Errorf("failed to list messages: %w", err)
  669. }
  670. if session.SummaryMessageID != "" {
  671. summaryMsgIndex := -1
  672. for i, msg := range msgs {
  673. if msg.ID == session.SummaryMessageID {
  674. summaryMsgIndex = i
  675. break
  676. }
  677. }
  678. if summaryMsgIndex != -1 {
  679. msgs = msgs[summaryMsgIndex:]
  680. msgs[0].Role = message.User
  681. }
  682. }
  683. return msgs, nil
  684. }
  685. // generateTitle generates a session titled based on the initial prompt.
  686. func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
  687. if userPrompt == "" {
  688. return
  689. }
  690. smallModel := a.smallModel.Get()
  691. largeModel := a.largeModel.Get()
  692. systemPromptPrefix := a.systemPromptPrefix.Get()
  693. var maxOutputTokens int64 = 40
  694. if smallModel.CatwalkCfg.CanReason {
  695. maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
  696. }
  697. newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
  698. return fantasy.NewAgent(m,
  699. fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
  700. fantasy.WithMaxOutputTokens(tok),
  701. )
  702. }
  703. streamCall := fantasy.AgentStreamCall{
  704. Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
  705. PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  706. prepared.Messages = opts.Messages
  707. if systemPromptPrefix != "" {
  708. prepared.Messages = append([]fantasy.Message{
  709. fantasy.NewSystemMessage(systemPromptPrefix),
  710. }, prepared.Messages...)
  711. }
  712. return callCtx, prepared, nil
  713. },
  714. }
  715. // Use the small model to generate the title.
  716. model := smallModel
  717. agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
  718. resp, err := agent.Stream(ctx, streamCall)
  719. if err == nil {
  720. // We successfully generated a title with the small model.
  721. slog.Info("generated title with small model")
  722. } else {
  723. // It didn't work. Let's try with the big model.
  724. slog.Error("error generating title with small model; trying big model", "err", err)
  725. model = largeModel
  726. agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
  727. resp, err = agent.Stream(ctx, streamCall)
  728. if err == nil {
  729. slog.Info("generated title with large model")
  730. } else {
  731. // Welp, the large model didn't work either. Use the default
  732. // session name and return.
  733. slog.Error("error generating title with large model", "err", err)
  734. saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
  735. if saveErr != nil {
  736. slog.Error("failed to save session title and usage", "error", saveErr)
  737. }
  738. return
  739. }
  740. }
  741. if resp == nil {
  742. // Actually, we didn't get a response so we can't. Use the default
  743. // session name and return.
  744. slog.Error("response is nil; can't generate title")
  745. saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
  746. if saveErr != nil {
  747. slog.Error("failed to save session title and usage", "error", saveErr)
  748. }
  749. return
  750. }
  751. // Clean up title.
  752. var title string
  753. title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
  754. // Remove thinking tags if present.
  755. title = thinkTagRegex.ReplaceAllString(title, "")
  756. title = strings.TrimSpace(title)
  757. if title == "" {
  758. slog.Warn("empty title; using fallback")
  759. title = defaultSessionName
  760. }
  761. // Calculate usage and cost.
  762. var openrouterCost *float64
  763. for _, step := range resp.Steps {
  764. stepCost := a.openrouterCost(step.ProviderMetadata)
  765. if stepCost != nil {
  766. newCost := *stepCost
  767. if openrouterCost != nil {
  768. newCost += *openrouterCost
  769. }
  770. openrouterCost = &newCost
  771. }
  772. }
  773. modelConfig := model.CatwalkCfg
  774. cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
  775. modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
  776. modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
  777. modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
  778. // Use override cost if available (e.g., from OpenRouter).
  779. if openrouterCost != nil {
  780. cost = *openrouterCost
  781. }
  782. promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
  783. completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
  784. // Atomically update only title and usage fields to avoid overriding other
  785. // concurrent session updates.
  786. saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
  787. if saveErr != nil {
  788. slog.Error("failed to save session title and usage", "error", saveErr)
  789. return
  790. }
  791. }
  792. func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
  793. openrouterMetadata, ok := metadata[openrouter.Name]
  794. if !ok {
  795. return nil
  796. }
  797. opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
  798. if !ok {
  799. return nil
  800. }
  801. return &opts.Usage.Cost
  802. }
  803. func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
  804. modelConfig := model.CatwalkCfg
  805. cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  806. modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  807. modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
  808. modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
  809. a.eventTokensUsed(session.ID, model, usage, cost)
  810. if overrideCost != nil {
  811. session.Cost += *overrideCost
  812. } else {
  813. session.Cost += cost
  814. }
  815. session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
  816. session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
  817. }
  818. func (a *sessionAgent) Cancel(sessionID string) {
  819. // Cancel regular requests. Don't use Take() here - we need the entry to
  820. // remain in activeRequests so IsBusy() returns true until the goroutine
  821. // fully completes (including error handling that may access the DB).
  822. // The defer in processRequest will clean up the entry.
  823. if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
  824. slog.Info("Request cancellation initiated", "session_id", sessionID)
  825. cancel()
  826. }
  827. // Also check for summarize requests.
  828. if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
  829. slog.Info("Summarize cancellation initiated", "session_id", sessionID)
  830. cancel()
  831. }
  832. if a.QueuedPrompts(sessionID) > 0 {
  833. slog.Info("Clearing queued prompts", "session_id", sessionID)
  834. a.messageQueue.Del(sessionID)
  835. }
  836. }
  837. func (a *sessionAgent) ClearQueue(sessionID string) {
  838. if a.QueuedPrompts(sessionID) > 0 {
  839. slog.Info("Clearing queued prompts", "session_id", sessionID)
  840. a.messageQueue.Del(sessionID)
  841. }
  842. }
  843. func (a *sessionAgent) CancelAll() {
  844. if !a.IsBusy() {
  845. return
  846. }
  847. for key := range a.activeRequests.Seq2() {
  848. a.Cancel(key) // key is sessionID
  849. }
  850. timeout := time.After(5 * time.Second)
  851. for a.IsBusy() {
  852. select {
  853. case <-timeout:
  854. return
  855. default:
  856. time.Sleep(200 * time.Millisecond)
  857. }
  858. }
  859. }
  860. func (a *sessionAgent) IsBusy() bool {
  861. var busy bool
  862. for cancelFunc := range a.activeRequests.Seq() {
  863. if cancelFunc != nil {
  864. busy = true
  865. break
  866. }
  867. }
  868. return busy
  869. }
  870. func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
  871. _, busy := a.activeRequests.Get(sessionID)
  872. return busy
  873. }
  874. func (a *sessionAgent) QueuedPrompts(sessionID string) int {
  875. l, ok := a.messageQueue.Get(sessionID)
  876. if !ok {
  877. return 0
  878. }
  879. return len(l)
  880. }
  881. func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
  882. l, ok := a.messageQueue.Get(sessionID)
  883. if !ok {
  884. return nil
  885. }
  886. prompts := make([]string, len(l))
  887. for i, call := range l {
  888. prompts[i] = call.Prompt
  889. }
  890. return prompts
  891. }
  892. func (a *sessionAgent) SetModels(large Model, small Model) {
  893. a.largeModel.Set(large)
  894. a.smallModel.Set(small)
  895. }
  896. func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
  897. a.tools.SetSlice(tools)
  898. }
  899. func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
  900. a.systemPrompt.Set(systemPrompt)
  901. }
  902. func (a *sessionAgent) Model() Model {
  903. return a.largeModel.Get()
  904. }
  905. // convertToToolResult converts a fantasy tool result to a message tool result.
  906. func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
  907. baseResult := message.ToolResult{
  908. ToolCallID: result.ToolCallID,
  909. Name: result.ToolName,
  910. Metadata: result.ClientMetadata,
  911. }
  912. switch result.Result.GetType() {
  913. case fantasy.ToolResultContentTypeText:
  914. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
  915. baseResult.Content = r.Text
  916. }
  917. case fantasy.ToolResultContentTypeError:
  918. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
  919. baseResult.Content = r.Error.Error()
  920. baseResult.IsError = true
  921. }
  922. case fantasy.ToolResultContentTypeMedia:
  923. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
  924. content := r.Text
  925. if content == "" {
  926. content = fmt.Sprintf("Loaded %s content", r.MediaType)
  927. }
  928. baseResult.Content = content
  929. baseResult.Data = r.Data
  930. baseResult.MIMEType = r.MediaType
  931. }
  932. }
  933. return baseResult
  934. }
  935. // workaroundProviderMediaLimitations converts media content in tool results to
  936. // user messages for providers that don't natively support images in tool results.
  937. //
  938. // Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
  939. // don't support sending images/media in tool result messages - they only accept
  940. // text in tool results. However, they DO support images in user messages.
  941. //
  942. // If we send media in tool results to these providers, the API returns an error.
  943. //
  944. // Solution: For these providers, we:
  945. // 1. Replace the media in the tool result with a text placeholder
  946. // 2. Inject a user message immediately after with the image as a file attachment
  947. // 3. This maintains the tool execution flow while working around API limitations
  948. //
  949. // Anthropic and Bedrock support images natively in tool results, so we skip
  950. // this workaround for them.
  951. //
  952. // Example transformation:
  953. //
  954. // BEFORE: [tool result: image data]
  955. // AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
  956. func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
  957. providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
  958. largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
  959. if providerSupportsMedia {
  960. return messages
  961. }
  962. convertedMessages := make([]fantasy.Message, 0, len(messages))
  963. for _, msg := range messages {
  964. if msg.Role != fantasy.MessageRoleTool {
  965. convertedMessages = append(convertedMessages, msg)
  966. continue
  967. }
  968. textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
  969. var mediaFiles []fantasy.FilePart
  970. for _, part := range msg.Content {
  971. toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
  972. if !ok {
  973. textParts = append(textParts, part)
  974. continue
  975. }
  976. if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
  977. decoded, err := base64.StdEncoding.DecodeString(media.Data)
  978. if err != nil {
  979. slog.Warn("failed to decode media data", "error", err)
  980. textParts = append(textParts, part)
  981. continue
  982. }
  983. mediaFiles = append(mediaFiles, fantasy.FilePart{
  984. Data: decoded,
  985. MediaType: media.MediaType,
  986. Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
  987. })
  988. textParts = append(textParts, fantasy.ToolResultPart{
  989. ToolCallID: toolResult.ToolCallID,
  990. Output: fantasy.ToolResultOutputContentText{
  991. Text: "[Image/media content loaded - see attached file]",
  992. },
  993. ProviderOptions: toolResult.ProviderOptions,
  994. })
  995. } else {
  996. textParts = append(textParts, part)
  997. }
  998. }
  999. convertedMessages = append(convertedMessages, fantasy.Message{
  1000. Role: fantasy.MessageRoleTool,
  1001. Content: textParts,
  1002. })
  1003. if len(mediaFiles) > 0 {
  1004. convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
  1005. "Here is the media content from the tool result:",
  1006. mediaFiles...,
  1007. ))
  1008. }
  1009. }
  1010. return convertedMessages
  1011. }
  1012. // buildSummaryPrompt constructs the prompt text for session summarization.
  1013. func buildSummaryPrompt(todos []session.Todo) string {
  1014. var sb strings.Builder
  1015. sb.WriteString("Provide a detailed summary of our conversation above.")
  1016. if len(todos) > 0 {
  1017. sb.WriteString("\n\n## Current Todo List\n\n")
  1018. for _, t := range todos {
  1019. fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
  1020. }
  1021. sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
  1022. sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
  1023. }
  1024. return sb.String()
  1025. }