agent.go 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040
  1. // Package agent is the core orchestration layer for Crush AI agents.
  2. //
  3. // It provides session-based AI agent functionality for managing
  4. // conversations, tool execution, and message handling. It coordinates
  5. // interactions between language models, messages, sessions, and tools while
  6. // handling features like automatic summarization, queuing, and token
  7. // management.
  8. package agent
  9. import (
  10. "cmp"
  11. "context"
  12. _ "embed"
  13. "encoding/base64"
  14. "errors"
  15. "fmt"
  16. "log/slog"
  17. "os"
  18. "strconv"
  19. "strings"
  20. "sync"
  21. "time"
  22. "charm.land/fantasy"
  23. "charm.land/fantasy/providers/anthropic"
  24. "charm.land/fantasy/providers/bedrock"
  25. "charm.land/fantasy/providers/google"
  26. "charm.land/fantasy/providers/openai"
  27. "charm.land/fantasy/providers/openrouter"
  28. "github.com/charmbracelet/catwalk/pkg/catwalk"
  29. "github.com/charmbracelet/crush/internal/agent/tools"
  30. "github.com/charmbracelet/crush/internal/config"
  31. "github.com/charmbracelet/crush/internal/csync"
  32. "github.com/charmbracelet/crush/internal/message"
  33. "github.com/charmbracelet/crush/internal/permission"
  34. "github.com/charmbracelet/crush/internal/session"
  35. "github.com/charmbracelet/crush/internal/stringext"
  36. )
  37. //go:embed templates/title.md
  38. var titlePrompt []byte
  39. //go:embed templates/summary.md
  40. var summaryPrompt []byte
  41. type SessionAgentCall struct {
  42. SessionID string
  43. Prompt string
  44. ProviderOptions fantasy.ProviderOptions
  45. Attachments []message.Attachment
  46. MaxOutputTokens int64
  47. Temperature *float64
  48. TopP *float64
  49. TopK *int64
  50. FrequencyPenalty *float64
  51. PresencePenalty *float64
  52. }
  53. type SessionAgent interface {
  54. Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
  55. SetModels(large Model, small Model)
  56. SetTools(tools []fantasy.AgentTool)
  57. Cancel(sessionID string)
  58. CancelAll()
  59. IsSessionBusy(sessionID string) bool
  60. IsBusy() bool
  61. QueuedPrompts(sessionID string) int
  62. QueuedPromptsList(sessionID string) []string
  63. ClearQueue(sessionID string)
  64. Summarize(context.Context, string, fantasy.ProviderOptions) error
  65. Model() Model
  66. }
  67. type Model struct {
  68. Model fantasy.LanguageModel
  69. CatwalkCfg catwalk.Model
  70. ModelCfg config.SelectedModel
  71. }
  72. type sessionAgent struct {
  73. largeModel Model
  74. smallModel Model
  75. systemPromptPrefix string
  76. systemPrompt string
  77. isSubAgent bool
  78. tools []fantasy.AgentTool
  79. sessions session.Service
  80. messages message.Service
  81. disableAutoSummarize bool
  82. isYolo bool
  83. messageQueue *csync.Map[string, []SessionAgentCall]
  84. activeRequests *csync.Map[string, context.CancelFunc]
  85. }
  86. type SessionAgentOptions struct {
  87. LargeModel Model
  88. SmallModel Model
  89. SystemPromptPrefix string
  90. SystemPrompt string
  91. IsSubAgent bool
  92. DisableAutoSummarize bool
  93. IsYolo bool
  94. Sessions session.Service
  95. Messages message.Service
  96. Tools []fantasy.AgentTool
  97. }
  98. func NewSessionAgent(
  99. opts SessionAgentOptions,
  100. ) SessionAgent {
  101. return &sessionAgent{
  102. largeModel: opts.LargeModel,
  103. smallModel: opts.SmallModel,
  104. systemPromptPrefix: opts.SystemPromptPrefix,
  105. systemPrompt: opts.SystemPrompt,
  106. isSubAgent: opts.IsSubAgent,
  107. sessions: opts.Sessions,
  108. messages: opts.Messages,
  109. disableAutoSummarize: opts.DisableAutoSummarize,
  110. tools: opts.Tools,
  111. isYolo: opts.IsYolo,
  112. messageQueue: csync.NewMap[string, []SessionAgentCall](),
  113. activeRequests: csync.NewMap[string, context.CancelFunc](),
  114. }
  115. }
  116. func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
  117. if call.Prompt == "" {
  118. return nil, ErrEmptyPrompt
  119. }
  120. if call.SessionID == "" {
  121. return nil, ErrSessionMissing
  122. }
  123. // Queue the message if busy
  124. if a.IsSessionBusy(call.SessionID) {
  125. existing, ok := a.messageQueue.Get(call.SessionID)
  126. if !ok {
  127. existing = []SessionAgentCall{}
  128. }
  129. existing = append(existing, call)
  130. a.messageQueue.Set(call.SessionID, existing)
  131. return nil, nil
  132. }
  133. if len(a.tools) > 0 {
  134. // Add Anthropic caching to the last tool.
  135. a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
  136. }
  137. agent := fantasy.NewAgent(
  138. a.largeModel.Model,
  139. fantasy.WithSystemPrompt(a.systemPrompt),
  140. fantasy.WithTools(a.tools...),
  141. )
  142. sessionLock := sync.Mutex{}
  143. currentSession, err := a.sessions.Get(ctx, call.SessionID)
  144. if err != nil {
  145. return nil, fmt.Errorf("failed to get session: %w", err)
  146. }
  147. msgs, err := a.getSessionMessages(ctx, currentSession)
  148. if err != nil {
  149. return nil, fmt.Errorf("failed to get session messages: %w", err)
  150. }
  151. var wg sync.WaitGroup
  152. // Generate title if first message.
  153. if len(msgs) == 0 {
  154. wg.Go(func() {
  155. sessionLock.Lock()
  156. a.generateTitle(ctx, &currentSession, call.Prompt)
  157. sessionLock.Unlock()
  158. })
  159. }
  160. // Add the user message to the session.
  161. _, err = a.createUserMessage(ctx, call)
  162. if err != nil {
  163. return nil, err
  164. }
  165. // Add the session to the context.
  166. ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
  167. genCtx, cancel := context.WithCancel(ctx)
  168. a.activeRequests.Set(call.SessionID, cancel)
  169. defer cancel()
  170. defer a.activeRequests.Del(call.SessionID)
  171. history, files := a.preparePrompt(msgs, call.Attachments...)
  172. startTime := time.Now()
  173. a.eventPromptSent(call.SessionID)
  174. var currentAssistant *message.Message
  175. var shouldSummarize bool
  176. result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  177. Prompt: call.Prompt,
  178. Files: files,
  179. Messages: history,
  180. ProviderOptions: call.ProviderOptions,
  181. MaxOutputTokens: &call.MaxOutputTokens,
  182. TopP: call.TopP,
  183. Temperature: call.Temperature,
  184. PresencePenalty: call.PresencePenalty,
  185. TopK: call.TopK,
  186. FrequencyPenalty: call.FrequencyPenalty,
  187. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  188. prepared.Messages = options.Messages
  189. for i := range prepared.Messages {
  190. prepared.Messages[i].ProviderOptions = nil
  191. }
  192. queuedCalls, _ := a.messageQueue.Get(call.SessionID)
  193. a.messageQueue.Del(call.SessionID)
  194. for _, queued := range queuedCalls {
  195. userMessage, createErr := a.createUserMessage(callContext, queued)
  196. if createErr != nil {
  197. return callContext, prepared, createErr
  198. }
  199. prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
  200. }
  201. prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
  202. lastSystemRoleInx := 0
  203. systemMessageUpdated := false
  204. for i, msg := range prepared.Messages {
  205. // Only add cache control to the last message.
  206. if msg.Role == fantasy.MessageRoleSystem {
  207. lastSystemRoleInx = i
  208. } else if !systemMessageUpdated {
  209. prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
  210. systemMessageUpdated = true
  211. }
  212. // Than add cache control to the last 2 messages.
  213. if i > len(prepared.Messages)-3 {
  214. prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
  215. }
  216. }
  217. if promptPrefix := a.promptPrefix(); promptPrefix != "" {
  218. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
  219. }
  220. var assistantMsg message.Message
  221. assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
  222. Role: message.Assistant,
  223. Parts: []message.ContentPart{},
  224. Model: a.largeModel.ModelCfg.Model,
  225. Provider: a.largeModel.ModelCfg.Provider,
  226. })
  227. if err != nil {
  228. return callContext, prepared, err
  229. }
  230. callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
  231. callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
  232. callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
  233. currentAssistant = &assistantMsg
  234. return callContext, prepared, err
  235. },
  236. OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
  237. currentAssistant.AppendReasoningContent(reasoning.Text)
  238. return a.messages.Update(genCtx, *currentAssistant)
  239. },
  240. OnReasoningDelta: func(id string, text string) error {
  241. currentAssistant.AppendReasoningContent(text)
  242. return a.messages.Update(genCtx, *currentAssistant)
  243. },
  244. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  245. // handle anthropic signature
  246. if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
  247. if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
  248. currentAssistant.AppendReasoningSignature(reasoning.Signature)
  249. }
  250. }
  251. if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
  252. if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
  253. currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
  254. }
  255. }
  256. if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
  257. if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
  258. currentAssistant.SetReasoningResponsesData(reasoning)
  259. }
  260. }
  261. currentAssistant.FinishThinking()
  262. return a.messages.Update(genCtx, *currentAssistant)
  263. },
  264. OnTextDelta: func(id string, text string) error {
  265. // Strip leading newline from initial text content. This is is
  266. // particularly important in non-interactive mode where leading
  267. // newlines are very visible.
  268. if len(currentAssistant.Parts) == 0 {
  269. text = strings.TrimPrefix(text, "\n")
  270. }
  271. currentAssistant.AppendContent(text)
  272. return a.messages.Update(genCtx, *currentAssistant)
  273. },
  274. OnToolInputStart: func(id string, toolName string) error {
  275. toolCall := message.ToolCall{
  276. ID: id,
  277. Name: toolName,
  278. ProviderExecuted: false,
  279. Finished: false,
  280. }
  281. currentAssistant.AddToolCall(toolCall)
  282. return a.messages.Update(genCtx, *currentAssistant)
  283. },
  284. OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
  285. // TODO: implement
  286. },
  287. OnToolCall: func(tc fantasy.ToolCallContent) error {
  288. toolCall := message.ToolCall{
  289. ID: tc.ToolCallID,
  290. Name: tc.ToolName,
  291. Input: tc.Input,
  292. ProviderExecuted: false,
  293. Finished: true,
  294. }
  295. currentAssistant.AddToolCall(toolCall)
  296. return a.messages.Update(genCtx, *currentAssistant)
  297. },
  298. OnToolResult: func(result fantasy.ToolResultContent) error {
  299. toolResult := a.convertToToolResult(result)
  300. _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
  301. Role: message.Tool,
  302. Parts: []message.ContentPart{
  303. toolResult,
  304. },
  305. })
  306. return createMsgErr
  307. },
  308. OnStepFinish: func(stepResult fantasy.StepResult) error {
  309. finishReason := message.FinishReasonUnknown
  310. switch stepResult.FinishReason {
  311. case fantasy.FinishReasonLength:
  312. finishReason = message.FinishReasonMaxTokens
  313. case fantasy.FinishReasonStop:
  314. finishReason = message.FinishReasonEndTurn
  315. case fantasy.FinishReasonToolCalls:
  316. finishReason = message.FinishReasonToolUse
  317. }
  318. currentAssistant.AddFinish(finishReason, "", "")
  319. sessionLock.Lock()
  320. updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
  321. if getSessionErr != nil {
  322. sessionLock.Unlock()
  323. return getSessionErr
  324. }
  325. a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
  326. _, sessionErr := a.sessions.Save(genCtx, updatedSession)
  327. sessionLock.Unlock()
  328. if sessionErr != nil {
  329. return sessionErr
  330. }
  331. return a.messages.Update(genCtx, *currentAssistant)
  332. },
  333. StopWhen: []fantasy.StopCondition{
  334. func(_ []fantasy.StepResult) bool {
  335. cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
  336. tokens := currentSession.CompletionTokens + currentSession.PromptTokens
  337. remaining := cw - tokens
  338. var threshold int64
  339. if cw > 200_000 {
  340. threshold = 20_000
  341. } else {
  342. threshold = int64(float64(cw) * 0.2)
  343. }
  344. if (remaining <= threshold) && !a.disableAutoSummarize {
  345. shouldSummarize = true
  346. return true
  347. }
  348. return false
  349. },
  350. },
  351. })
  352. a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
  353. if err != nil {
  354. isCancelErr := errors.Is(err, context.Canceled)
  355. isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
  356. if currentAssistant == nil {
  357. return result, err
  358. }
  359. // Ensure we finish thinking on error to close the reasoning state.
  360. currentAssistant.FinishThinking()
  361. toolCalls := currentAssistant.ToolCalls()
  362. // INFO: we use the parent context here because the genCtx has been cancelled.
  363. msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
  364. if createErr != nil {
  365. return nil, createErr
  366. }
  367. for _, tc := range toolCalls {
  368. if !tc.Finished {
  369. tc.Finished = true
  370. tc.Input = "{}"
  371. currentAssistant.AddToolCall(tc)
  372. updateErr := a.messages.Update(ctx, *currentAssistant)
  373. if updateErr != nil {
  374. return nil, updateErr
  375. }
  376. }
  377. found := false
  378. for _, msg := range msgs {
  379. if msg.Role == message.Tool {
  380. for _, tr := range msg.ToolResults() {
  381. if tr.ToolCallID == tc.ID {
  382. found = true
  383. break
  384. }
  385. }
  386. }
  387. if found {
  388. break
  389. }
  390. }
  391. if found {
  392. continue
  393. }
  394. content := "There was an error while executing the tool"
  395. if isCancelErr {
  396. content = "Tool execution canceled by user"
  397. } else if isPermissionErr {
  398. content = "User denied permission"
  399. }
  400. toolResult := message.ToolResult{
  401. ToolCallID: tc.ID,
  402. Name: tc.Name,
  403. Content: content,
  404. IsError: true,
  405. }
  406. _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
  407. Role: message.Tool,
  408. Parts: []message.ContentPart{
  409. toolResult,
  410. },
  411. })
  412. if createErr != nil {
  413. return nil, createErr
  414. }
  415. }
  416. var fantasyErr *fantasy.Error
  417. var providerErr *fantasy.ProviderError
  418. const defaultTitle = "Provider Error"
  419. if isCancelErr {
  420. currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
  421. } else if isPermissionErr {
  422. currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
  423. } else if errors.As(err, &providerErr) {
  424. currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
  425. } else if errors.As(err, &fantasyErr) {
  426. currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
  427. } else {
  428. currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
  429. }
  430. // Note: we use the parent context here because the genCtx has been
  431. // cancelled.
  432. updateErr := a.messages.Update(ctx, *currentAssistant)
  433. if updateErr != nil {
  434. return nil, updateErr
  435. }
  436. return nil, err
  437. }
  438. wg.Wait()
  439. if shouldSummarize {
  440. a.activeRequests.Del(call.SessionID)
  441. if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
  442. return nil, summarizeErr
  443. }
  444. // If the agent wasn't done...
  445. if len(currentAssistant.ToolCalls()) > 0 {
  446. existing, ok := a.messageQueue.Get(call.SessionID)
  447. if !ok {
  448. existing = []SessionAgentCall{}
  449. }
  450. call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
  451. existing = append(existing, call)
  452. a.messageQueue.Set(call.SessionID, existing)
  453. }
  454. }
  455. // Release active request before processing queued messages.
  456. a.activeRequests.Del(call.SessionID)
  457. cancel()
  458. queuedMessages, ok := a.messageQueue.Get(call.SessionID)
  459. if !ok || len(queuedMessages) == 0 {
  460. return result, err
  461. }
  462. // There are queued messages restart the loop.
  463. firstQueuedMessage := queuedMessages[0]
  464. a.messageQueue.Set(call.SessionID, queuedMessages[1:])
  465. return a.Run(ctx, firstQueuedMessage)
  466. }
  467. func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
  468. if a.IsSessionBusy(sessionID) {
  469. return ErrSessionBusy
  470. }
  471. currentSession, err := a.sessions.Get(ctx, sessionID)
  472. if err != nil {
  473. return fmt.Errorf("failed to get session: %w", err)
  474. }
  475. msgs, err := a.getSessionMessages(ctx, currentSession)
  476. if err != nil {
  477. return err
  478. }
  479. if len(msgs) == 0 {
  480. // Nothing to summarize.
  481. return nil
  482. }
  483. aiMsgs, _ := a.preparePrompt(msgs)
  484. genCtx, cancel := context.WithCancel(ctx)
  485. a.activeRequests.Set(sessionID, cancel)
  486. defer a.activeRequests.Del(sessionID)
  487. defer cancel()
  488. agent := fantasy.NewAgent(a.largeModel.Model,
  489. fantasy.WithSystemPrompt(string(summaryPrompt)),
  490. )
  491. summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  492. Role: message.Assistant,
  493. Model: a.largeModel.Model.Model(),
  494. Provider: a.largeModel.Model.Provider(),
  495. IsSummaryMessage: true,
  496. })
  497. if err != nil {
  498. return err
  499. }
  500. summaryPromptText := "Provide a detailed summary of our conversation above."
  501. if len(currentSession.Todos) > 0 {
  502. summaryPromptText += "\n\n## Current Todo List\n\n"
  503. for _, t := range currentSession.Todos {
  504. summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
  505. }
  506. summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
  507. summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
  508. }
  509. resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  510. Prompt: summaryPromptText,
  511. Messages: aiMsgs,
  512. ProviderOptions: opts,
  513. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  514. prepared.Messages = options.Messages
  515. if a.systemPromptPrefix != "" {
  516. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
  517. }
  518. return callContext, prepared, nil
  519. },
  520. OnReasoningDelta: func(id string, text string) error {
  521. summaryMessage.AppendReasoningContent(text)
  522. return a.messages.Update(genCtx, summaryMessage)
  523. },
  524. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  525. // Handle anthropic signature.
  526. if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
  527. if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
  528. summaryMessage.AppendReasoningSignature(signature.Signature)
  529. }
  530. }
  531. summaryMessage.FinishThinking()
  532. return a.messages.Update(genCtx, summaryMessage)
  533. },
  534. OnTextDelta: func(id, text string) error {
  535. summaryMessage.AppendContent(text)
  536. return a.messages.Update(genCtx, summaryMessage)
  537. },
  538. })
  539. if err != nil {
  540. isCancelErr := errors.Is(err, context.Canceled)
  541. if isCancelErr {
  542. // User cancelled summarize we need to remove the summary message.
  543. deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
  544. return deleteErr
  545. }
  546. return err
  547. }
  548. summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
  549. err = a.messages.Update(genCtx, summaryMessage)
  550. if err != nil {
  551. return err
  552. }
  553. var openrouterCost *float64
  554. for _, step := range resp.Steps {
  555. stepCost := a.openrouterCost(step.ProviderMetadata)
  556. if stepCost != nil {
  557. newCost := *stepCost
  558. if openrouterCost != nil {
  559. newCost += *openrouterCost
  560. }
  561. openrouterCost = &newCost
  562. }
  563. }
  564. a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
  565. // Just in case, get just the last usage info.
  566. usage := resp.Response.Usage
  567. currentSession.SummaryMessageID = summaryMessage.ID
  568. currentSession.CompletionTokens = usage.OutputTokens
  569. currentSession.PromptTokens = 0
  570. _, err = a.sessions.Save(genCtx, currentSession)
  571. return err
  572. }
  573. func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
  574. if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
  575. return fantasy.ProviderOptions{}
  576. }
  577. return fantasy.ProviderOptions{
  578. anthropic.Name: &anthropic.ProviderCacheControlOptions{
  579. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  580. },
  581. bedrock.Name: &anthropic.ProviderCacheControlOptions{
  582. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  583. },
  584. }
  585. }
  586. func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
  587. var attachmentParts []message.ContentPart
  588. for _, attachment := range call.Attachments {
  589. attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
  590. }
  591. parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
  592. parts = append(parts, attachmentParts...)
  593. msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
  594. Role: message.User,
  595. Parts: parts,
  596. })
  597. if err != nil {
  598. return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
  599. }
  600. return msg, nil
  601. }
  602. func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
  603. var history []fantasy.Message
  604. if !a.isSubAgent {
  605. history = append(history, fantasy.NewUserMessage(
  606. fmt.Sprintf("<system_reminder>%s</system_reminder>",
  607. `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
  608. If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
  609. If not, please feel free to ignore. Again do not mention this message to the user.`,
  610. ),
  611. ))
  612. }
  613. for _, m := range msgs {
  614. if len(m.Parts) == 0 {
  615. continue
  616. }
  617. // Assistant message without content or tool calls (cancelled before it
  618. // returned anything).
  619. if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
  620. continue
  621. }
  622. history = append(history, m.ToAIMessage()...)
  623. }
  624. var files []fantasy.FilePart
  625. for _, attachment := range attachments {
  626. files = append(files, fantasy.FilePart{
  627. Filename: attachment.FileName,
  628. Data: attachment.Content,
  629. MediaType: attachment.MimeType,
  630. })
  631. }
  632. return history, files
  633. }
  634. func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
  635. msgs, err := a.messages.List(ctx, session.ID)
  636. if err != nil {
  637. return nil, fmt.Errorf("failed to list messages: %w", err)
  638. }
  639. if session.SummaryMessageID != "" {
  640. summaryMsgInex := -1
  641. for i, msg := range msgs {
  642. if msg.ID == session.SummaryMessageID {
  643. summaryMsgInex = i
  644. break
  645. }
  646. }
  647. if summaryMsgInex != -1 {
  648. msgs = msgs[summaryMsgInex:]
  649. msgs[0].Role = message.User
  650. }
  651. }
  652. return msgs, nil
  653. }
  654. func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
  655. if prompt == "" {
  656. return
  657. }
  658. var maxOutput int64 = 40
  659. if a.smallModel.CatwalkCfg.CanReason {
  660. maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
  661. }
  662. agent := fantasy.NewAgent(a.smallModel.Model,
  663. fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
  664. fantasy.WithMaxOutputTokens(maxOutput),
  665. )
  666. resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
  667. Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
  668. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  669. prepared.Messages = options.Messages
  670. if a.systemPromptPrefix != "" {
  671. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
  672. }
  673. return callContext, prepared, nil
  674. },
  675. })
  676. if err != nil {
  677. slog.Error("error generating title", "err", err)
  678. return
  679. }
  680. title := resp.Response.Content.Text()
  681. title = strings.ReplaceAll(title, "\n", " ")
  682. // Remove thinking tags if present.
  683. if idx := strings.Index(title, "</think>"); idx > 0 {
  684. title = title[idx+len("</think>"):]
  685. }
  686. title = strings.TrimSpace(title)
  687. if title == "" {
  688. slog.Warn("failed to generate title", "warn", "empty title")
  689. return
  690. }
  691. session.Title = title
  692. var openrouterCost *float64
  693. for _, step := range resp.Steps {
  694. stepCost := a.openrouterCost(step.ProviderMetadata)
  695. if stepCost != nil {
  696. newCost := *stepCost
  697. if openrouterCost != nil {
  698. newCost += *openrouterCost
  699. }
  700. openrouterCost = &newCost
  701. }
  702. }
  703. a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
  704. _, saveErr := a.sessions.Save(ctx, *session)
  705. if saveErr != nil {
  706. slog.Error("failed to save session title & usage", "error", saveErr)
  707. return
  708. }
  709. }
  710. func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
  711. openrouterMetadata, ok := metadata[openrouter.Name]
  712. if !ok {
  713. return nil
  714. }
  715. opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
  716. if !ok {
  717. return nil
  718. }
  719. return &opts.Usage.Cost
  720. }
  721. func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
  722. modelConfig := model.CatwalkCfg
  723. cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  724. modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  725. modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
  726. modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
  727. if a.isClaudeCode() {
  728. cost = 0
  729. }
  730. a.eventTokensUsed(session.ID, model, usage, cost)
  731. if overrideCost != nil {
  732. session.Cost += *overrideCost
  733. } else {
  734. session.Cost += cost
  735. }
  736. session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
  737. session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
  738. }
  739. func (a *sessionAgent) Cancel(sessionID string) {
  740. // Cancel regular requests.
  741. if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
  742. slog.Info("Request cancellation initiated", "session_id", sessionID)
  743. cancel()
  744. }
  745. // Also check for summarize requests.
  746. if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
  747. slog.Info("Summarize cancellation initiated", "session_id", sessionID)
  748. cancel()
  749. }
  750. if a.QueuedPrompts(sessionID) > 0 {
  751. slog.Info("Clearing queued prompts", "session_id", sessionID)
  752. a.messageQueue.Del(sessionID)
  753. }
  754. }
  755. func (a *sessionAgent) ClearQueue(sessionID string) {
  756. if a.QueuedPrompts(sessionID) > 0 {
  757. slog.Info("Clearing queued prompts", "session_id", sessionID)
  758. a.messageQueue.Del(sessionID)
  759. }
  760. }
  761. func (a *sessionAgent) CancelAll() {
  762. if !a.IsBusy() {
  763. return
  764. }
  765. for key := range a.activeRequests.Seq2() {
  766. a.Cancel(key) // key is sessionID
  767. }
  768. timeout := time.After(5 * time.Second)
  769. for a.IsBusy() {
  770. select {
  771. case <-timeout:
  772. return
  773. default:
  774. time.Sleep(200 * time.Millisecond)
  775. }
  776. }
  777. }
  778. func (a *sessionAgent) IsBusy() bool {
  779. var busy bool
  780. for cancelFunc := range a.activeRequests.Seq() {
  781. if cancelFunc != nil {
  782. busy = true
  783. break
  784. }
  785. }
  786. return busy
  787. }
  788. func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
  789. _, busy := a.activeRequests.Get(sessionID)
  790. return busy
  791. }
  792. func (a *sessionAgent) QueuedPrompts(sessionID string) int {
  793. l, ok := a.messageQueue.Get(sessionID)
  794. if !ok {
  795. return 0
  796. }
  797. return len(l)
  798. }
  799. func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
  800. l, ok := a.messageQueue.Get(sessionID)
  801. if !ok {
  802. return nil
  803. }
  804. prompts := make([]string, len(l))
  805. for i, call := range l {
  806. prompts[i] = call.Prompt
  807. }
  808. return prompts
  809. }
  810. func (a *sessionAgent) SetModels(large Model, small Model) {
  811. a.largeModel = large
  812. a.smallModel = small
  813. }
  814. func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
  815. a.tools = tools
  816. }
  817. func (a *sessionAgent) Model() Model {
  818. return a.largeModel
  819. }
  820. func (a *sessionAgent) promptPrefix() string {
  821. if a.isClaudeCode() {
  822. return "You are Claude Code, Anthropic's official CLI for Claude."
  823. }
  824. return a.systemPromptPrefix
  825. }
  826. func (a *sessionAgent) isClaudeCode() bool {
  827. cfg := config.Get()
  828. pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
  829. return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
  830. }
  831. // convertToToolResult converts a fantasy tool result to a message tool result.
  832. func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
  833. baseResult := message.ToolResult{
  834. ToolCallID: result.ToolCallID,
  835. Name: result.ToolName,
  836. Metadata: result.ClientMetadata,
  837. }
  838. switch result.Result.GetType() {
  839. case fantasy.ToolResultContentTypeText:
  840. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
  841. baseResult.Content = r.Text
  842. }
  843. case fantasy.ToolResultContentTypeError:
  844. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
  845. baseResult.Content = r.Error.Error()
  846. baseResult.IsError = true
  847. }
  848. case fantasy.ToolResultContentTypeMedia:
  849. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
  850. content := r.Text
  851. if content == "" {
  852. content = fmt.Sprintf("Loaded %s content", r.MediaType)
  853. }
  854. baseResult.Content = content
  855. baseResult.Data = r.Data
  856. baseResult.MIMEType = r.MediaType
  857. }
  858. }
  859. return baseResult
  860. }
  861. // workaroundProviderMediaLimitations converts media content in tool results to
  862. // user messages for providers that don't natively support images in tool results.
  863. //
  864. // Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
  865. // don't support sending images/media in tool result messages - they only accept
  866. // text in tool results. However, they DO support images in user messages.
  867. //
  868. // If we send media in tool results to these providers, the API returns an error.
  869. //
  870. // Solution: For these providers, we:
  871. // 1. Replace the media in the tool result with a text placeholder
  872. // 2. Inject a user message immediately after with the image as a file attachment
  873. // 3. This maintains the tool execution flow while working around API limitations
  874. //
  875. // Anthropic and Bedrock support images natively in tool results, so we skip
  876. // this workaround for them.
  877. //
  878. // Example transformation:
  879. //
  880. // BEFORE: [tool result: image data]
  881. // AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
  882. func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
  883. providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
  884. a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
  885. if providerSupportsMedia {
  886. return messages
  887. }
  888. convertedMessages := make([]fantasy.Message, 0, len(messages))
  889. for _, msg := range messages {
  890. if msg.Role != fantasy.MessageRoleTool {
  891. convertedMessages = append(convertedMessages, msg)
  892. continue
  893. }
  894. textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
  895. var mediaFiles []fantasy.FilePart
  896. for _, part := range msg.Content {
  897. toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
  898. if !ok {
  899. textParts = append(textParts, part)
  900. continue
  901. }
  902. if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
  903. decoded, err := base64.StdEncoding.DecodeString(media.Data)
  904. if err != nil {
  905. slog.Warn("failed to decode media data", "error", err)
  906. textParts = append(textParts, part)
  907. continue
  908. }
  909. mediaFiles = append(mediaFiles, fantasy.FilePart{
  910. Data: decoded,
  911. MediaType: media.MediaType,
  912. Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
  913. })
  914. textParts = append(textParts, fantasy.ToolResultPart{
  915. ToolCallID: toolResult.ToolCallID,
  916. Output: fantasy.ToolResultOutputContentText{
  917. Text: "[Image/media content loaded - see attached file]",
  918. },
  919. ProviderOptions: toolResult.ProviderOptions,
  920. })
  921. } else {
  922. textParts = append(textParts, part)
  923. }
  924. }
  925. convertedMessages = append(convertedMessages, fantasy.Message{
  926. Role: fantasy.MessageRoleTool,
  927. Content: textParts,
  928. })
  929. if len(mediaFiles) > 0 {
  930. convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
  931. "Here is the media content from the tool result:",
  932. mediaFiles...,
  933. ))
  934. }
  935. }
  936. return convertedMessages
  937. }