agent.go 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056
  1. // Package agent is the core orchestration layer for Crush AI agents.
  2. //
  3. // It provides session-based AI agent functionality for managing
  4. // conversations, tool execution, and message handling. It coordinates
  5. // interactions between language models, messages, sessions, and tools while
  6. // handling features like automatic summarization, queuing, and token
  7. // management.
  8. package agent
  9. import (
  10. "cmp"
  11. "context"
  12. _ "embed"
  13. "encoding/base64"
  14. "errors"
  15. "fmt"
  16. "log/slog"
  17. "os"
  18. "strconv"
  19. "strings"
  20. "sync"
  21. "time"
  22. "charm.land/fantasy"
  23. "charm.land/fantasy/providers/anthropic"
  24. "charm.land/fantasy/providers/bedrock"
  25. "charm.land/fantasy/providers/google"
  26. "charm.land/fantasy/providers/openai"
  27. "charm.land/fantasy/providers/openrouter"
  28. "charm.land/lipgloss/v2"
  29. "github.com/charmbracelet/catwalk/pkg/catwalk"
  30. "github.com/charmbracelet/crush/internal/agent/hyper"
  31. "github.com/charmbracelet/crush/internal/agent/tools"
  32. "github.com/charmbracelet/crush/internal/config"
  33. "github.com/charmbracelet/crush/internal/csync"
  34. "github.com/charmbracelet/crush/internal/message"
  35. "github.com/charmbracelet/crush/internal/permission"
  36. "github.com/charmbracelet/crush/internal/session"
  37. "github.com/charmbracelet/crush/internal/stringext"
  38. )
  39. //go:embed templates/title.md
  40. var titlePrompt []byte
  41. //go:embed templates/summary.md
  42. var summaryPrompt []byte
  43. type SessionAgentCall struct {
  44. SessionID string
  45. Prompt string
  46. ProviderOptions fantasy.ProviderOptions
  47. Attachments []message.Attachment
  48. MaxOutputTokens int64
  49. Temperature *float64
  50. TopP *float64
  51. TopK *int64
  52. FrequencyPenalty *float64
  53. PresencePenalty *float64
  54. }
  55. type SessionAgent interface {
  56. Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
  57. SetModels(large Model, small Model)
  58. SetTools(tools []fantasy.AgentTool)
  59. Cancel(sessionID string)
  60. CancelAll()
  61. IsSessionBusy(sessionID string) bool
  62. IsBusy() bool
  63. QueuedPrompts(sessionID string) int
  64. QueuedPromptsList(sessionID string) []string
  65. ClearQueue(sessionID string)
  66. Summarize(context.Context, string, fantasy.ProviderOptions) error
  67. Model() Model
  68. }
  69. type Model struct {
  70. Model fantasy.LanguageModel
  71. CatwalkCfg catwalk.Model
  72. ModelCfg config.SelectedModel
  73. }
  74. type sessionAgent struct {
  75. largeModel Model
  76. smallModel Model
  77. systemPromptPrefix string
  78. systemPrompt string
  79. isSubAgent bool
  80. tools []fantasy.AgentTool
  81. sessions session.Service
  82. messages message.Service
  83. disableAutoSummarize bool
  84. isYolo bool
  85. messageQueue *csync.Map[string, []SessionAgentCall]
  86. activeRequests *csync.Map[string, context.CancelFunc]
  87. }
  88. type SessionAgentOptions struct {
  89. LargeModel Model
  90. SmallModel Model
  91. SystemPromptPrefix string
  92. SystemPrompt string
  93. IsSubAgent bool
  94. DisableAutoSummarize bool
  95. IsYolo bool
  96. Sessions session.Service
  97. Messages message.Service
  98. Tools []fantasy.AgentTool
  99. }
  100. func NewSessionAgent(
  101. opts SessionAgentOptions,
  102. ) SessionAgent {
  103. return &sessionAgent{
  104. largeModel: opts.LargeModel,
  105. smallModel: opts.SmallModel,
  106. systemPromptPrefix: opts.SystemPromptPrefix,
  107. systemPrompt: opts.SystemPrompt,
  108. isSubAgent: opts.IsSubAgent,
  109. sessions: opts.Sessions,
  110. messages: opts.Messages,
  111. disableAutoSummarize: opts.DisableAutoSummarize,
  112. tools: opts.Tools,
  113. isYolo: opts.IsYolo,
  114. messageQueue: csync.NewMap[string, []SessionAgentCall](),
  115. activeRequests: csync.NewMap[string, context.CancelFunc](),
  116. }
  117. }
  118. func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
  119. if call.Prompt == "" {
  120. return nil, ErrEmptyPrompt
  121. }
  122. if call.SessionID == "" {
  123. return nil, ErrSessionMissing
  124. }
  125. // Queue the message if busy
  126. if a.IsSessionBusy(call.SessionID) {
  127. existing, ok := a.messageQueue.Get(call.SessionID)
  128. if !ok {
  129. existing = []SessionAgentCall{}
  130. }
  131. existing = append(existing, call)
  132. a.messageQueue.Set(call.SessionID, existing)
  133. return nil, nil
  134. }
  135. if len(a.tools) > 0 {
  136. // Add Anthropic caching to the last tool.
  137. a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
  138. }
  139. agent := fantasy.NewAgent(
  140. a.largeModel.Model,
  141. fantasy.WithSystemPrompt(a.systemPrompt),
  142. fantasy.WithTools(a.tools...),
  143. )
  144. sessionLock := sync.Mutex{}
  145. currentSession, err := a.sessions.Get(ctx, call.SessionID)
  146. if err != nil {
  147. return nil, fmt.Errorf("failed to get session: %w", err)
  148. }
  149. msgs, err := a.getSessionMessages(ctx, currentSession)
  150. if err != nil {
  151. return nil, fmt.Errorf("failed to get session messages: %w", err)
  152. }
  153. var wg sync.WaitGroup
  154. // Generate title if first message.
  155. if len(msgs) == 0 {
  156. wg.Go(func() {
  157. sessionLock.Lock()
  158. a.generateTitle(ctx, &currentSession, call.Prompt)
  159. sessionLock.Unlock()
  160. })
  161. }
  162. // Add the user message to the session.
  163. _, err = a.createUserMessage(ctx, call)
  164. if err != nil {
  165. return nil, err
  166. }
  167. // Add the session to the context.
  168. ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
  169. genCtx, cancel := context.WithCancel(ctx)
  170. a.activeRequests.Set(call.SessionID, cancel)
  171. defer cancel()
  172. defer a.activeRequests.Del(call.SessionID)
  173. history, files := a.preparePrompt(msgs, call.Attachments...)
  174. startTime := time.Now()
  175. a.eventPromptSent(call.SessionID)
  176. var currentAssistant *message.Message
  177. var shouldSummarize bool
  178. result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  179. Prompt: call.Prompt,
  180. Files: files,
  181. Messages: history,
  182. ProviderOptions: call.ProviderOptions,
  183. MaxOutputTokens: &call.MaxOutputTokens,
  184. TopP: call.TopP,
  185. Temperature: call.Temperature,
  186. PresencePenalty: call.PresencePenalty,
  187. TopK: call.TopK,
  188. FrequencyPenalty: call.FrequencyPenalty,
  189. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  190. prepared.Messages = options.Messages
  191. for i := range prepared.Messages {
  192. prepared.Messages[i].ProviderOptions = nil
  193. }
  194. queuedCalls, _ := a.messageQueue.Get(call.SessionID)
  195. a.messageQueue.Del(call.SessionID)
  196. for _, queued := range queuedCalls {
  197. userMessage, createErr := a.createUserMessage(callContext, queued)
  198. if createErr != nil {
  199. return callContext, prepared, createErr
  200. }
  201. prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
  202. }
  203. prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
  204. lastSystemRoleInx := 0
  205. systemMessageUpdated := false
  206. for i, msg := range prepared.Messages {
  207. // Only add cache control to the last message.
  208. if msg.Role == fantasy.MessageRoleSystem {
  209. lastSystemRoleInx = i
  210. } else if !systemMessageUpdated {
  211. prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
  212. systemMessageUpdated = true
  213. }
  214. // Than add cache control to the last 2 messages.
  215. if i > len(prepared.Messages)-3 {
  216. prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
  217. }
  218. }
  219. if promptPrefix := a.promptPrefix(); promptPrefix != "" {
  220. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
  221. }
  222. var assistantMsg message.Message
  223. assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
  224. Role: message.Assistant,
  225. Parts: []message.ContentPart{},
  226. Model: a.largeModel.ModelCfg.Model,
  227. Provider: a.largeModel.ModelCfg.Provider,
  228. })
  229. if err != nil {
  230. return callContext, prepared, err
  231. }
  232. callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
  233. callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
  234. callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
  235. currentAssistant = &assistantMsg
  236. return callContext, prepared, err
  237. },
  238. OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
  239. currentAssistant.AppendReasoningContent(reasoning.Text)
  240. return a.messages.Update(genCtx, *currentAssistant)
  241. },
  242. OnReasoningDelta: func(id string, text string) error {
  243. currentAssistant.AppendReasoningContent(text)
  244. return a.messages.Update(genCtx, *currentAssistant)
  245. },
  246. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  247. // handle anthropic signature
  248. if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
  249. if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
  250. currentAssistant.AppendReasoningSignature(reasoning.Signature)
  251. }
  252. }
  253. if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
  254. if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
  255. currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
  256. }
  257. }
  258. if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
  259. if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
  260. currentAssistant.SetReasoningResponsesData(reasoning)
  261. }
  262. }
  263. currentAssistant.FinishThinking()
  264. return a.messages.Update(genCtx, *currentAssistant)
  265. },
  266. OnTextDelta: func(id string, text string) error {
  267. // Strip leading newline from initial text content. This is is
  268. // particularly important in non-interactive mode where leading
  269. // newlines are very visible.
  270. if len(currentAssistant.Parts) == 0 {
  271. text = strings.TrimPrefix(text, "\n")
  272. }
  273. currentAssistant.AppendContent(text)
  274. return a.messages.Update(genCtx, *currentAssistant)
  275. },
  276. OnToolInputStart: func(id string, toolName string) error {
  277. toolCall := message.ToolCall{
  278. ID: id,
  279. Name: toolName,
  280. ProviderExecuted: false,
  281. Finished: false,
  282. }
  283. currentAssistant.AddToolCall(toolCall)
  284. return a.messages.Update(genCtx, *currentAssistant)
  285. },
  286. OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
  287. // TODO: implement
  288. },
  289. OnToolCall: func(tc fantasy.ToolCallContent) error {
  290. toolCall := message.ToolCall{
  291. ID: tc.ToolCallID,
  292. Name: tc.ToolName,
  293. Input: tc.Input,
  294. ProviderExecuted: false,
  295. Finished: true,
  296. }
  297. currentAssistant.AddToolCall(toolCall)
  298. return a.messages.Update(genCtx, *currentAssistant)
  299. },
  300. OnToolResult: func(result fantasy.ToolResultContent) error {
  301. toolResult := a.convertToToolResult(result)
  302. _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
  303. Role: message.Tool,
  304. Parts: []message.ContentPart{
  305. toolResult,
  306. },
  307. })
  308. return createMsgErr
  309. },
  310. OnStepFinish: func(stepResult fantasy.StepResult) error {
  311. finishReason := message.FinishReasonUnknown
  312. switch stepResult.FinishReason {
  313. case fantasy.FinishReasonLength:
  314. finishReason = message.FinishReasonMaxTokens
  315. case fantasy.FinishReasonStop:
  316. finishReason = message.FinishReasonEndTurn
  317. case fantasy.FinishReasonToolCalls:
  318. finishReason = message.FinishReasonToolUse
  319. }
  320. currentAssistant.AddFinish(finishReason, "", "")
  321. sessionLock.Lock()
  322. updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
  323. if getSessionErr != nil {
  324. sessionLock.Unlock()
  325. return getSessionErr
  326. }
  327. a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
  328. _, sessionErr := a.sessions.Save(genCtx, updatedSession)
  329. sessionLock.Unlock()
  330. if sessionErr != nil {
  331. return sessionErr
  332. }
  333. return a.messages.Update(genCtx, *currentAssistant)
  334. },
  335. StopWhen: []fantasy.StopCondition{
  336. func(_ []fantasy.StepResult) bool {
  337. cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
  338. tokens := currentSession.CompletionTokens + currentSession.PromptTokens
  339. remaining := cw - tokens
  340. var threshold int64
  341. if cw > 200_000 {
  342. threshold = 20_000
  343. } else {
  344. threshold = int64(float64(cw) * 0.2)
  345. }
  346. if (remaining <= threshold) && !a.disableAutoSummarize {
  347. shouldSummarize = true
  348. return true
  349. }
  350. return false
  351. },
  352. },
  353. })
  354. a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
  355. if err != nil {
  356. isCancelErr := errors.Is(err, context.Canceled)
  357. isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
  358. if currentAssistant == nil {
  359. return result, err
  360. }
  361. // Ensure we finish thinking on error to close the reasoning state.
  362. currentAssistant.FinishThinking()
  363. toolCalls := currentAssistant.ToolCalls()
  364. // INFO: we use the parent context here because the genCtx has been cancelled.
  365. msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
  366. if createErr != nil {
  367. return nil, createErr
  368. }
  369. for _, tc := range toolCalls {
  370. if !tc.Finished {
  371. tc.Finished = true
  372. tc.Input = "{}"
  373. currentAssistant.AddToolCall(tc)
  374. updateErr := a.messages.Update(ctx, *currentAssistant)
  375. if updateErr != nil {
  376. return nil, updateErr
  377. }
  378. }
  379. found := false
  380. for _, msg := range msgs {
  381. if msg.Role == message.Tool {
  382. for _, tr := range msg.ToolResults() {
  383. if tr.ToolCallID == tc.ID {
  384. found = true
  385. break
  386. }
  387. }
  388. }
  389. if found {
  390. break
  391. }
  392. }
  393. if found {
  394. continue
  395. }
  396. content := "There was an error while executing the tool"
  397. if isCancelErr {
  398. content = "Tool execution canceled by user"
  399. } else if isPermissionErr {
  400. content = "User denied permission"
  401. }
  402. toolResult := message.ToolResult{
  403. ToolCallID: tc.ID,
  404. Name: tc.Name,
  405. Content: content,
  406. IsError: true,
  407. }
  408. _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
  409. Role: message.Tool,
  410. Parts: []message.ContentPart{
  411. toolResult,
  412. },
  413. })
  414. if createErr != nil {
  415. return nil, createErr
  416. }
  417. }
  418. var fantasyErr *fantasy.Error
  419. var providerErr *fantasy.ProviderError
  420. const defaultTitle = "Provider Error"
  421. if isCancelErr {
  422. currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
  423. } else if isPermissionErr {
  424. currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
  425. } else if errors.Is(err, hyper.ErrNoCredits) {
  426. url := hyper.BaseURL()
  427. link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
  428. currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
  429. } else if errors.As(err, &providerErr) {
  430. if providerErr.Message == "The requested model is not supported." {
  431. url := "https://github.com/settings/copilot/features"
  432. link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
  433. currentAssistant.AddFinish(
  434. message.FinishReasonError,
  435. "Copilot model not enabled",
  436. fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
  437. )
  438. } else {
  439. currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
  440. }
  441. } else if errors.As(err, &fantasyErr) {
  442. currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
  443. } else {
  444. currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
  445. }
  446. // Note: we use the parent context here because the genCtx has been
  447. // cancelled.
  448. updateErr := a.messages.Update(ctx, *currentAssistant)
  449. if updateErr != nil {
  450. return nil, updateErr
  451. }
  452. return nil, err
  453. }
  454. wg.Wait()
  455. if shouldSummarize {
  456. a.activeRequests.Del(call.SessionID)
  457. if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
  458. return nil, summarizeErr
  459. }
  460. // If the agent wasn't done...
  461. if len(currentAssistant.ToolCalls()) > 0 {
  462. existing, ok := a.messageQueue.Get(call.SessionID)
  463. if !ok {
  464. existing = []SessionAgentCall{}
  465. }
  466. call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
  467. existing = append(existing, call)
  468. a.messageQueue.Set(call.SessionID, existing)
  469. }
  470. }
  471. // Release active request before processing queued messages.
  472. a.activeRequests.Del(call.SessionID)
  473. cancel()
  474. queuedMessages, ok := a.messageQueue.Get(call.SessionID)
  475. if !ok || len(queuedMessages) == 0 {
  476. return result, err
  477. }
  478. // There are queued messages restart the loop.
  479. firstQueuedMessage := queuedMessages[0]
  480. a.messageQueue.Set(call.SessionID, queuedMessages[1:])
  481. return a.Run(ctx, firstQueuedMessage)
  482. }
  483. func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
  484. if a.IsSessionBusy(sessionID) {
  485. return ErrSessionBusy
  486. }
  487. currentSession, err := a.sessions.Get(ctx, sessionID)
  488. if err != nil {
  489. return fmt.Errorf("failed to get session: %w", err)
  490. }
  491. msgs, err := a.getSessionMessages(ctx, currentSession)
  492. if err != nil {
  493. return err
  494. }
  495. if len(msgs) == 0 {
  496. // Nothing to summarize.
  497. return nil
  498. }
  499. aiMsgs, _ := a.preparePrompt(msgs)
  500. genCtx, cancel := context.WithCancel(ctx)
  501. a.activeRequests.Set(sessionID, cancel)
  502. defer a.activeRequests.Del(sessionID)
  503. defer cancel()
  504. agent := fantasy.NewAgent(a.largeModel.Model,
  505. fantasy.WithSystemPrompt(string(summaryPrompt)),
  506. )
  507. summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  508. Role: message.Assistant,
  509. Model: a.largeModel.Model.Model(),
  510. Provider: a.largeModel.Model.Provider(),
  511. IsSummaryMessage: true,
  512. })
  513. if err != nil {
  514. return err
  515. }
  516. summaryPromptText := "Provide a detailed summary of our conversation above."
  517. if len(currentSession.Todos) > 0 {
  518. summaryPromptText += "\n\n## Current Todo List\n\n"
  519. for _, t := range currentSession.Todos {
  520. summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
  521. }
  522. summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
  523. summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
  524. }
  525. resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  526. Prompt: summaryPromptText,
  527. Messages: aiMsgs,
  528. ProviderOptions: opts,
  529. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  530. prepared.Messages = options.Messages
  531. if a.systemPromptPrefix != "" {
  532. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
  533. }
  534. return callContext, prepared, nil
  535. },
  536. OnReasoningDelta: func(id string, text string) error {
  537. summaryMessage.AppendReasoningContent(text)
  538. return a.messages.Update(genCtx, summaryMessage)
  539. },
  540. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  541. // Handle anthropic signature.
  542. if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
  543. if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
  544. summaryMessage.AppendReasoningSignature(signature.Signature)
  545. }
  546. }
  547. summaryMessage.FinishThinking()
  548. return a.messages.Update(genCtx, summaryMessage)
  549. },
  550. OnTextDelta: func(id, text string) error {
  551. summaryMessage.AppendContent(text)
  552. return a.messages.Update(genCtx, summaryMessage)
  553. },
  554. })
  555. if err != nil {
  556. isCancelErr := errors.Is(err, context.Canceled)
  557. if isCancelErr {
  558. // User cancelled summarize we need to remove the summary message.
  559. deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
  560. return deleteErr
  561. }
  562. return err
  563. }
  564. summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
  565. err = a.messages.Update(genCtx, summaryMessage)
  566. if err != nil {
  567. return err
  568. }
  569. var openrouterCost *float64
  570. for _, step := range resp.Steps {
  571. stepCost := a.openrouterCost(step.ProviderMetadata)
  572. if stepCost != nil {
  573. newCost := *stepCost
  574. if openrouterCost != nil {
  575. newCost += *openrouterCost
  576. }
  577. openrouterCost = &newCost
  578. }
  579. }
  580. a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
  581. // Just in case, get just the last usage info.
  582. usage := resp.Response.Usage
  583. currentSession.SummaryMessageID = summaryMessage.ID
  584. currentSession.CompletionTokens = usage.OutputTokens
  585. currentSession.PromptTokens = 0
  586. _, err = a.sessions.Save(genCtx, currentSession)
  587. return err
  588. }
  589. func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
  590. if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
  591. return fantasy.ProviderOptions{}
  592. }
  593. return fantasy.ProviderOptions{
  594. anthropic.Name: &anthropic.ProviderCacheControlOptions{
  595. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  596. },
  597. bedrock.Name: &anthropic.ProviderCacheControlOptions{
  598. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  599. },
  600. }
  601. }
  602. func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
  603. var attachmentParts []message.ContentPart
  604. for _, attachment := range call.Attachments {
  605. attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
  606. }
  607. parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
  608. parts = append(parts, attachmentParts...)
  609. msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
  610. Role: message.User,
  611. Parts: parts,
  612. })
  613. if err != nil {
  614. return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
  615. }
  616. return msg, nil
  617. }
  618. func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
  619. var history []fantasy.Message
  620. if !a.isSubAgent {
  621. history = append(history, fantasy.NewUserMessage(
  622. fmt.Sprintf("<system_reminder>%s</system_reminder>",
  623. `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
  624. If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
  625. If not, please feel free to ignore. Again do not mention this message to the user.`,
  626. ),
  627. ))
  628. }
  629. for _, m := range msgs {
  630. if len(m.Parts) == 0 {
  631. continue
  632. }
  633. // Assistant message without content or tool calls (cancelled before it
  634. // returned anything).
  635. if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
  636. continue
  637. }
  638. history = append(history, m.ToAIMessage()...)
  639. }
  640. var files []fantasy.FilePart
  641. for _, attachment := range attachments {
  642. files = append(files, fantasy.FilePart{
  643. Filename: attachment.FileName,
  644. Data: attachment.Content,
  645. MediaType: attachment.MimeType,
  646. })
  647. }
  648. return history, files
  649. }
  650. func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
  651. msgs, err := a.messages.List(ctx, session.ID)
  652. if err != nil {
  653. return nil, fmt.Errorf("failed to list messages: %w", err)
  654. }
  655. if session.SummaryMessageID != "" {
  656. summaryMsgInex := -1
  657. for i, msg := range msgs {
  658. if msg.ID == session.SummaryMessageID {
  659. summaryMsgInex = i
  660. break
  661. }
  662. }
  663. if summaryMsgInex != -1 {
  664. msgs = msgs[summaryMsgInex:]
  665. msgs[0].Role = message.User
  666. }
  667. }
  668. return msgs, nil
  669. }
  670. func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
  671. if prompt == "" {
  672. return
  673. }
  674. var maxOutput int64 = 40
  675. if a.smallModel.CatwalkCfg.CanReason {
  676. maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
  677. }
  678. agent := fantasy.NewAgent(a.smallModel.Model,
  679. fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
  680. fantasy.WithMaxOutputTokens(maxOutput),
  681. )
  682. resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
  683. Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
  684. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  685. prepared.Messages = options.Messages
  686. if a.systemPromptPrefix != "" {
  687. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
  688. }
  689. return callContext, prepared, nil
  690. },
  691. })
  692. if err != nil {
  693. slog.Error("error generating title", "err", err)
  694. return
  695. }
  696. title := resp.Response.Content.Text()
  697. title = strings.ReplaceAll(title, "\n", " ")
  698. // Remove thinking tags if present.
  699. if idx := strings.Index(title, "</think>"); idx > 0 {
  700. title = title[idx+len("</think>"):]
  701. }
  702. title = strings.TrimSpace(title)
  703. if title == "" {
  704. slog.Warn("failed to generate title", "warn", "empty title")
  705. return
  706. }
  707. session.Title = title
  708. var openrouterCost *float64
  709. for _, step := range resp.Steps {
  710. stepCost := a.openrouterCost(step.ProviderMetadata)
  711. if stepCost != nil {
  712. newCost := *stepCost
  713. if openrouterCost != nil {
  714. newCost += *openrouterCost
  715. }
  716. openrouterCost = &newCost
  717. }
  718. }
  719. a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
  720. _, saveErr := a.sessions.Save(ctx, *session)
  721. if saveErr != nil {
  722. slog.Error("failed to save session title & usage", "error", saveErr)
  723. return
  724. }
  725. }
  726. func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
  727. openrouterMetadata, ok := metadata[openrouter.Name]
  728. if !ok {
  729. return nil
  730. }
  731. opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
  732. if !ok {
  733. return nil
  734. }
  735. return &opts.Usage.Cost
  736. }
  737. func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
  738. modelConfig := model.CatwalkCfg
  739. cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  740. modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  741. modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
  742. modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
  743. if a.isClaudeCode() {
  744. cost = 0
  745. }
  746. a.eventTokensUsed(session.ID, model, usage, cost)
  747. if overrideCost != nil {
  748. session.Cost += *overrideCost
  749. } else {
  750. session.Cost += cost
  751. }
  752. session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
  753. session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
  754. }
  755. func (a *sessionAgent) Cancel(sessionID string) {
  756. // Cancel regular requests.
  757. if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
  758. slog.Info("Request cancellation initiated", "session_id", sessionID)
  759. cancel()
  760. }
  761. // Also check for summarize requests.
  762. if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
  763. slog.Info("Summarize cancellation initiated", "session_id", sessionID)
  764. cancel()
  765. }
  766. if a.QueuedPrompts(sessionID) > 0 {
  767. slog.Info("Clearing queued prompts", "session_id", sessionID)
  768. a.messageQueue.Del(sessionID)
  769. }
  770. }
  771. func (a *sessionAgent) ClearQueue(sessionID string) {
  772. if a.QueuedPrompts(sessionID) > 0 {
  773. slog.Info("Clearing queued prompts", "session_id", sessionID)
  774. a.messageQueue.Del(sessionID)
  775. }
  776. }
  777. func (a *sessionAgent) CancelAll() {
  778. if !a.IsBusy() {
  779. return
  780. }
  781. for key := range a.activeRequests.Seq2() {
  782. a.Cancel(key) // key is sessionID
  783. }
  784. timeout := time.After(5 * time.Second)
  785. for a.IsBusy() {
  786. select {
  787. case <-timeout:
  788. return
  789. default:
  790. time.Sleep(200 * time.Millisecond)
  791. }
  792. }
  793. }
  794. func (a *sessionAgent) IsBusy() bool {
  795. var busy bool
  796. for cancelFunc := range a.activeRequests.Seq() {
  797. if cancelFunc != nil {
  798. busy = true
  799. break
  800. }
  801. }
  802. return busy
  803. }
  804. func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
  805. _, busy := a.activeRequests.Get(sessionID)
  806. return busy
  807. }
  808. func (a *sessionAgent) QueuedPrompts(sessionID string) int {
  809. l, ok := a.messageQueue.Get(sessionID)
  810. if !ok {
  811. return 0
  812. }
  813. return len(l)
  814. }
  815. func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
  816. l, ok := a.messageQueue.Get(sessionID)
  817. if !ok {
  818. return nil
  819. }
  820. prompts := make([]string, len(l))
  821. for i, call := range l {
  822. prompts[i] = call.Prompt
  823. }
  824. return prompts
  825. }
  826. func (a *sessionAgent) SetModels(large Model, small Model) {
  827. a.largeModel = large
  828. a.smallModel = small
  829. }
  830. func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
  831. a.tools = tools
  832. }
  833. func (a *sessionAgent) Model() Model {
  834. return a.largeModel
  835. }
  836. func (a *sessionAgent) promptPrefix() string {
  837. if a.isClaudeCode() {
  838. return "You are Claude Code, Anthropic's official CLI for Claude."
  839. }
  840. return a.systemPromptPrefix
  841. }
  842. func (a *sessionAgent) isClaudeCode() bool {
  843. cfg := config.Get()
  844. pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
  845. return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
  846. }
  847. // convertToToolResult converts a fantasy tool result to a message tool result.
  848. func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
  849. baseResult := message.ToolResult{
  850. ToolCallID: result.ToolCallID,
  851. Name: result.ToolName,
  852. Metadata: result.ClientMetadata,
  853. }
  854. switch result.Result.GetType() {
  855. case fantasy.ToolResultContentTypeText:
  856. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
  857. baseResult.Content = r.Text
  858. }
  859. case fantasy.ToolResultContentTypeError:
  860. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
  861. baseResult.Content = r.Error.Error()
  862. baseResult.IsError = true
  863. }
  864. case fantasy.ToolResultContentTypeMedia:
  865. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
  866. content := r.Text
  867. if content == "" {
  868. content = fmt.Sprintf("Loaded %s content", r.MediaType)
  869. }
  870. baseResult.Content = content
  871. baseResult.Data = r.Data
  872. baseResult.MIMEType = r.MediaType
  873. }
  874. }
  875. return baseResult
  876. }
  877. // workaroundProviderMediaLimitations converts media content in tool results to
  878. // user messages for providers that don't natively support images in tool results.
  879. //
  880. // Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
  881. // don't support sending images/media in tool result messages - they only accept
  882. // text in tool results. However, they DO support images in user messages.
  883. //
  884. // If we send media in tool results to these providers, the API returns an error.
  885. //
  886. // Solution: For these providers, we:
  887. // 1. Replace the media in the tool result with a text placeholder
  888. // 2. Inject a user message immediately after with the image as a file attachment
  889. // 3. This maintains the tool execution flow while working around API limitations
  890. //
  891. // Anthropic and Bedrock support images natively in tool results, so we skip
  892. // this workaround for them.
  893. //
  894. // Example transformation:
  895. //
  896. // BEFORE: [tool result: image data]
  897. // AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
  898. func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
  899. providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
  900. a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
  901. if providerSupportsMedia {
  902. return messages
  903. }
  904. convertedMessages := make([]fantasy.Message, 0, len(messages))
  905. for _, msg := range messages {
  906. if msg.Role != fantasy.MessageRoleTool {
  907. convertedMessages = append(convertedMessages, msg)
  908. continue
  909. }
  910. textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
  911. var mediaFiles []fantasy.FilePart
  912. for _, part := range msg.Content {
  913. toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
  914. if !ok {
  915. textParts = append(textParts, part)
  916. continue
  917. }
  918. if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
  919. decoded, err := base64.StdEncoding.DecodeString(media.Data)
  920. if err != nil {
  921. slog.Warn("failed to decode media data", "error", err)
  922. textParts = append(textParts, part)
  923. continue
  924. }
  925. mediaFiles = append(mediaFiles, fantasy.FilePart{
  926. Data: decoded,
  927. MediaType: media.MediaType,
  928. Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
  929. })
  930. textParts = append(textParts, fantasy.ToolResultPart{
  931. ToolCallID: toolResult.ToolCallID,
  932. Output: fantasy.ToolResultOutputContentText{
  933. Text: "[Image/media content loaded - see attached file]",
  934. },
  935. ProviderOptions: toolResult.ProviderOptions,
  936. })
  937. } else {
  938. textParts = append(textParts, part)
  939. }
  940. }
  941. convertedMessages = append(convertedMessages, fantasy.Message{
  942. Role: fantasy.MessageRoleTool,
  943. Content: textParts,
  944. })
  945. if len(mediaFiles) > 0 {
  946. convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
  947. "Here is the media content from the tool result:",
  948. mediaFiles...,
  949. ))
  950. }
  951. }
  952. return convertedMessages
  953. }