agent.go 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000
  1. // Package agent is the core orchestration layer for Crush AI agents.
  2. //
  3. // It provides session-based AI agent functionality for managing
  4. // conversations, tool execution, and message handling. It coordinates
  5. // interactions between language models, messages, sessions, and tools while
  6. // handling features like automatic summarization, queuing, and token
  7. // management.
  8. package agent
  9. import (
  10. "cmp"
  11. "context"
  12. _ "embed"
  13. "encoding/base64"
  14. "errors"
  15. "fmt"
  16. "log/slog"
  17. "os"
  18. "strconv"
  19. "strings"
  20. "sync"
  21. "time"
  22. "charm.land/fantasy"
  23. "charm.land/fantasy/providers/anthropic"
  24. "charm.land/fantasy/providers/bedrock"
  25. "charm.land/fantasy/providers/google"
  26. "charm.land/fantasy/providers/openai"
  27. "charm.land/fantasy/providers/openrouter"
  28. "github.com/charmbracelet/catwalk/pkg/catwalk"
  29. "github.com/charmbracelet/crush/internal/agent/tools"
  30. "github.com/charmbracelet/crush/internal/config"
  31. "github.com/charmbracelet/crush/internal/csync"
  32. "github.com/charmbracelet/crush/internal/message"
  33. "github.com/charmbracelet/crush/internal/permission"
  34. "github.com/charmbracelet/crush/internal/session"
  35. "github.com/charmbracelet/crush/internal/stringext"
  36. )
// titlePrompt is the system prompt used by generateTitle, embedded from
// templates/title.md at build time.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize, embedded from
// templates/summary.md at build time.
//
//go:embed templates/summary.md
var summaryPrompt []byte
// SessionAgentCall is a single prompt submitted to a SessionAgent.
//
// The sampling fields are forwarded as-is to the model stream call in Run;
// nil pointers are passed through as nil (presumably meaning "provider
// default" — confirm against fantasy).
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Required.
	SessionID string
	// Prompt is the user's text. Required (must be non-empty).
	Prompt string
	// ProviderOptions are passed to the provider for this call (and to
	// Summarize if auto-summarization is triggered).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message and sent as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is always passed by pointer to the stream call, so a
	// zero value is sent as an explicit 0.
	MaxOutputTokens int64
	Temperature *float64
	TopP *float64
	TopK *int64
	FrequencyPenalty *float64
	PresencePenalty *float64
}
  53. type SessionAgent interface {
  54. Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
  55. SetModels(large Model, small Model)
  56. SetTools(tools []fantasy.AgentTool)
  57. Cancel(sessionID string)
  58. CancelAll()
  59. IsSessionBusy(sessionID string) bool
  60. IsBusy() bool
  61. QueuedPrompts(sessionID string) int
  62. ClearQueue(sessionID string)
  63. Summarize(context.Context, string, fantasy.ProviderOptions) error
  64. Model() Model
  65. }
// Model bundles everything the agent needs to know about one model.
type Model struct {
	// Model is the language model client requests are sent to.
	Model fantasy.LanguageModel
	// CatwalkCfg is catalog metadata for the model; this file reads its
	// context window, per-1M-token costs, SupportsImages, CanReason,
	// DefaultMaxTokens and Name.
	CatwalkCfg catwalk.Model
	// ModelCfg is the user's selected-model configuration; its Model and
	// Provider fields are recorded on assistant messages and used to look up
	// the provider config.
	ModelCfg config.SelectedModel
}
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel handles conversations and summaries; smallModel generates
	// session titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix is prepended as an extra system message (replaced by
	// the Claude Code identity string when promptPrefix says so).
	systemPromptPrefix string
	systemPrompt string
	tools []fantasy.AgentTool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	isYolo bool
	// messageQueue holds calls submitted while their session was busy, keyed
	// by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; presence of a key is what IsSessionBusy checks.
	activeRequests *csync.Map[string, context.CancelFunc]
}
// SessionAgentOptions configures NewSessionAgent. Each field is copied to the
// agent field of the same purpose.
type SessionAgentOptions struct {
	// LargeModel is used for conversations and summaries.
	LargeModel Model
	// SmallModel is used for title generation.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as a leading system message.
	SystemPromptPrefix string
	// SystemPrompt is the main system prompt for conversations.
	SystemPrompt string
	// DisableAutoSummarize prevents Run from summarizing near the context limit.
	DisableAutoSummarize bool
	IsYolo bool
	Sessions session.Service
	Messages message.Service
	// Tools are offered to the large model during Run.
	Tools []fantasy.AgentTool
}
  95. func NewSessionAgent(
  96. opts SessionAgentOptions,
  97. ) SessionAgent {
  98. return &sessionAgent{
  99. largeModel: opts.LargeModel,
  100. smallModel: opts.SmallModel,
  101. systemPromptPrefix: opts.SystemPromptPrefix,
  102. systemPrompt: opts.SystemPrompt,
  103. sessions: opts.Sessions,
  104. messages: opts.Messages,
  105. disableAutoSummarize: opts.DisableAutoSummarize,
  106. tools: opts.Tools,
  107. isYolo: opts.IsYolo,
  108. messageQueue: csync.NewMap[string, []SessionAgentCall](),
  109. activeRequests: csync.NewMap[string, context.CancelFunc](),
  110. }
  111. }
  112. func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
  113. if call.Prompt == "" {
  114. return nil, ErrEmptyPrompt
  115. }
  116. if call.SessionID == "" {
  117. return nil, ErrSessionMissing
  118. }
  119. // Queue the message if busy
  120. if a.IsSessionBusy(call.SessionID) {
  121. existing, ok := a.messageQueue.Get(call.SessionID)
  122. if !ok {
  123. existing = []SessionAgentCall{}
  124. }
  125. existing = append(existing, call)
  126. a.messageQueue.Set(call.SessionID, existing)
  127. return nil, nil
  128. }
  129. if len(a.tools) > 0 {
  130. // Add Anthropic caching to the last tool.
  131. a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
  132. }
  133. agent := fantasy.NewAgent(
  134. a.largeModel.Model,
  135. fantasy.WithSystemPrompt(a.systemPrompt),
  136. fantasy.WithTools(a.tools...),
  137. )
  138. sessionLock := sync.Mutex{}
  139. currentSession, err := a.sessions.Get(ctx, call.SessionID)
  140. if err != nil {
  141. return nil, fmt.Errorf("failed to get session: %w", err)
  142. }
  143. msgs, err := a.getSessionMessages(ctx, currentSession)
  144. if err != nil {
  145. return nil, fmt.Errorf("failed to get session messages: %w", err)
  146. }
  147. var wg sync.WaitGroup
  148. // Generate title if first message.
  149. if len(msgs) == 0 {
  150. wg.Go(func() {
  151. sessionLock.Lock()
  152. a.generateTitle(ctx, &currentSession, call.Prompt)
  153. sessionLock.Unlock()
  154. })
  155. }
  156. // Add the user message to the session.
  157. _, err = a.createUserMessage(ctx, call)
  158. if err != nil {
  159. return nil, err
  160. }
  161. // Add the session to the context.
  162. ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
  163. genCtx, cancel := context.WithCancel(ctx)
  164. a.activeRequests.Set(call.SessionID, cancel)
  165. defer cancel()
  166. defer a.activeRequests.Del(call.SessionID)
  167. history, files := a.preparePrompt(msgs, call.Attachments...)
  168. startTime := time.Now()
  169. a.eventPromptSent(call.SessionID)
  170. var currentAssistant *message.Message
  171. var shouldSummarize bool
  172. result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  173. Prompt: call.Prompt,
  174. Files: files,
  175. Messages: history,
  176. ProviderOptions: call.ProviderOptions,
  177. MaxOutputTokens: &call.MaxOutputTokens,
  178. TopP: call.TopP,
  179. Temperature: call.Temperature,
  180. PresencePenalty: call.PresencePenalty,
  181. TopK: call.TopK,
  182. FrequencyPenalty: call.FrequencyPenalty,
  183. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  184. prepared.Messages = options.Messages
  185. for i := range prepared.Messages {
  186. prepared.Messages[i].ProviderOptions = nil
  187. }
  188. queuedCalls, _ := a.messageQueue.Get(call.SessionID)
  189. a.messageQueue.Del(call.SessionID)
  190. for _, queued := range queuedCalls {
  191. userMessage, createErr := a.createUserMessage(callContext, queued)
  192. if createErr != nil {
  193. return callContext, prepared, createErr
  194. }
  195. prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
  196. }
  197. prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
  198. lastSystemRoleInx := 0
  199. systemMessageUpdated := false
  200. for i, msg := range prepared.Messages {
  201. // Only add cache control to the last message.
  202. if msg.Role == fantasy.MessageRoleSystem {
  203. lastSystemRoleInx = i
  204. } else if !systemMessageUpdated {
  205. prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
  206. systemMessageUpdated = true
  207. }
  208. // Than add cache control to the last 2 messages.
  209. if i > len(prepared.Messages)-3 {
  210. prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
  211. }
  212. }
  213. if promptPrefix := a.promptPrefix(); promptPrefix != "" {
  214. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
  215. }
  216. var assistantMsg message.Message
  217. assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
  218. Role: message.Assistant,
  219. Parts: []message.ContentPart{},
  220. Model: a.largeModel.ModelCfg.Model,
  221. Provider: a.largeModel.ModelCfg.Provider,
  222. })
  223. if err != nil {
  224. return callContext, prepared, err
  225. }
  226. callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
  227. callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
  228. callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
  229. currentAssistant = &assistantMsg
  230. return callContext, prepared, err
  231. },
  232. OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
  233. currentAssistant.AppendReasoningContent(reasoning.Text)
  234. return a.messages.Update(genCtx, *currentAssistant)
  235. },
  236. OnReasoningDelta: func(id string, text string) error {
  237. currentAssistant.AppendReasoningContent(text)
  238. return a.messages.Update(genCtx, *currentAssistant)
  239. },
  240. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  241. // handle anthropic signature
  242. if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
  243. if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
  244. currentAssistant.AppendReasoningSignature(reasoning.Signature)
  245. }
  246. }
  247. if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
  248. if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
  249. currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
  250. }
  251. }
  252. if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
  253. if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
  254. currentAssistant.SetReasoningResponsesData(reasoning)
  255. }
  256. }
  257. currentAssistant.FinishThinking()
  258. return a.messages.Update(genCtx, *currentAssistant)
  259. },
  260. OnTextDelta: func(id string, text string) error {
  261. // Strip leading newline from initial text content. This is is
  262. // particularly important in non-interactive mode where leading
  263. // newlines are very visible.
  264. if len(currentAssistant.Parts) == 0 {
  265. text = strings.TrimPrefix(text, "\n")
  266. }
  267. currentAssistant.AppendContent(text)
  268. return a.messages.Update(genCtx, *currentAssistant)
  269. },
  270. OnToolInputStart: func(id string, toolName string) error {
  271. toolCall := message.ToolCall{
  272. ID: id,
  273. Name: toolName,
  274. ProviderExecuted: false,
  275. Finished: false,
  276. }
  277. currentAssistant.AddToolCall(toolCall)
  278. return a.messages.Update(genCtx, *currentAssistant)
  279. },
  280. OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
  281. // TODO: implement
  282. },
  283. OnToolCall: func(tc fantasy.ToolCallContent) error {
  284. toolCall := message.ToolCall{
  285. ID: tc.ToolCallID,
  286. Name: tc.ToolName,
  287. Input: tc.Input,
  288. ProviderExecuted: false,
  289. Finished: true,
  290. }
  291. currentAssistant.AddToolCall(toolCall)
  292. return a.messages.Update(genCtx, *currentAssistant)
  293. },
  294. OnToolResult: func(result fantasy.ToolResultContent) error {
  295. toolResult := a.convertToToolResult(result)
  296. _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
  297. Role: message.Tool,
  298. Parts: []message.ContentPart{
  299. toolResult,
  300. },
  301. })
  302. return createMsgErr
  303. },
  304. OnStepFinish: func(stepResult fantasy.StepResult) error {
  305. finishReason := message.FinishReasonUnknown
  306. switch stepResult.FinishReason {
  307. case fantasy.FinishReasonLength:
  308. finishReason = message.FinishReasonMaxTokens
  309. case fantasy.FinishReasonStop:
  310. finishReason = message.FinishReasonEndTurn
  311. case fantasy.FinishReasonToolCalls:
  312. finishReason = message.FinishReasonToolUse
  313. }
  314. currentAssistant.AddFinish(finishReason, "", "")
  315. a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
  316. sessionLock.Lock()
  317. _, sessionErr := a.sessions.Save(genCtx, currentSession)
  318. sessionLock.Unlock()
  319. if sessionErr != nil {
  320. return sessionErr
  321. }
  322. return a.messages.Update(genCtx, *currentAssistant)
  323. },
  324. StopWhen: []fantasy.StopCondition{
  325. func(_ []fantasy.StepResult) bool {
  326. cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
  327. tokens := currentSession.CompletionTokens + currentSession.PromptTokens
  328. remaining := cw - tokens
  329. var threshold int64
  330. if cw > 200_000 {
  331. threshold = 20_000
  332. } else {
  333. threshold = int64(float64(cw) * 0.2)
  334. }
  335. if (remaining <= threshold) && !a.disableAutoSummarize {
  336. shouldSummarize = true
  337. return true
  338. }
  339. return false
  340. },
  341. },
  342. })
  343. a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
  344. if err != nil {
  345. isCancelErr := errors.Is(err, context.Canceled)
  346. isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
  347. if currentAssistant == nil {
  348. return result, err
  349. }
  350. // Ensure we finish thinking on error to close the reasoning state.
  351. currentAssistant.FinishThinking()
  352. toolCalls := currentAssistant.ToolCalls()
  353. // INFO: we use the parent context here because the genCtx has been cancelled.
  354. msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
  355. if createErr != nil {
  356. return nil, createErr
  357. }
  358. for _, tc := range toolCalls {
  359. if !tc.Finished {
  360. tc.Finished = true
  361. tc.Input = "{}"
  362. currentAssistant.AddToolCall(tc)
  363. updateErr := a.messages.Update(ctx, *currentAssistant)
  364. if updateErr != nil {
  365. return nil, updateErr
  366. }
  367. }
  368. found := false
  369. for _, msg := range msgs {
  370. if msg.Role == message.Tool {
  371. for _, tr := range msg.ToolResults() {
  372. if tr.ToolCallID == tc.ID {
  373. found = true
  374. break
  375. }
  376. }
  377. }
  378. if found {
  379. break
  380. }
  381. }
  382. if found {
  383. continue
  384. }
  385. content := "There was an error while executing the tool"
  386. if isCancelErr {
  387. content = "Tool execution canceled by user"
  388. } else if isPermissionErr {
  389. content = "User denied permission"
  390. }
  391. toolResult := message.ToolResult{
  392. ToolCallID: tc.ID,
  393. Name: tc.Name,
  394. Content: content,
  395. IsError: true,
  396. }
  397. _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
  398. Role: message.Tool,
  399. Parts: []message.ContentPart{
  400. toolResult,
  401. },
  402. })
  403. if createErr != nil {
  404. return nil, createErr
  405. }
  406. }
  407. var fantasyErr *fantasy.Error
  408. var providerErr *fantasy.ProviderError
  409. const defaultTitle = "Provider Error"
  410. if isCancelErr {
  411. currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
  412. } else if isPermissionErr {
  413. currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
  414. } else if errors.As(err, &providerErr) {
  415. currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
  416. } else if errors.As(err, &fantasyErr) {
  417. currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
  418. } else {
  419. currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
  420. }
  421. // Note: we use the parent context here because the genCtx has been
  422. // cancelled.
  423. updateErr := a.messages.Update(ctx, *currentAssistant)
  424. if updateErr != nil {
  425. return nil, updateErr
  426. }
  427. return nil, err
  428. }
  429. wg.Wait()
  430. if shouldSummarize {
  431. a.activeRequests.Del(call.SessionID)
  432. if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
  433. return nil, summarizeErr
  434. }
  435. // If the agent wasn't done...
  436. if len(currentAssistant.ToolCalls()) > 0 {
  437. existing, ok := a.messageQueue.Get(call.SessionID)
  438. if !ok {
  439. existing = []SessionAgentCall{}
  440. }
  441. call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
  442. existing = append(existing, call)
  443. a.messageQueue.Set(call.SessionID, existing)
  444. }
  445. }
  446. // Release active request before processing queued messages.
  447. a.activeRequests.Del(call.SessionID)
  448. cancel()
  449. queuedMessages, ok := a.messageQueue.Get(call.SessionID)
  450. if !ok || len(queuedMessages) == 0 {
  451. return result, err
  452. }
  453. // There are queued messages restart the loop.
  454. firstQueuedMessage := queuedMessages[0]
  455. a.messageQueue.Set(call.SessionID, queuedMessages[1:])
  456. return a.Run(ctx, firstQueuedMessage)
  457. }
  458. func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
  459. if a.IsSessionBusy(sessionID) {
  460. return ErrSessionBusy
  461. }
  462. currentSession, err := a.sessions.Get(ctx, sessionID)
  463. if err != nil {
  464. return fmt.Errorf("failed to get session: %w", err)
  465. }
  466. msgs, err := a.getSessionMessages(ctx, currentSession)
  467. if err != nil {
  468. return err
  469. }
  470. if len(msgs) == 0 {
  471. // Nothing to summarize.
  472. return nil
  473. }
  474. aiMsgs, _ := a.preparePrompt(msgs)
  475. genCtx, cancel := context.WithCancel(ctx)
  476. a.activeRequests.Set(sessionID, cancel)
  477. defer a.activeRequests.Del(sessionID)
  478. defer cancel()
  479. agent := fantasy.NewAgent(a.largeModel.Model,
  480. fantasy.WithSystemPrompt(string(summaryPrompt)),
  481. )
  482. summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  483. Role: message.Assistant,
  484. Model: a.largeModel.Model.Model(),
  485. Provider: a.largeModel.Model.Provider(),
  486. IsSummaryMessage: true,
  487. })
  488. if err != nil {
  489. return err
  490. }
  491. resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  492. Prompt: "Provide a detailed summary of our conversation above.",
  493. Messages: aiMsgs,
  494. ProviderOptions: opts,
  495. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  496. prepared.Messages = options.Messages
  497. if a.systemPromptPrefix != "" {
  498. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
  499. }
  500. return callContext, prepared, nil
  501. },
  502. OnReasoningDelta: func(id string, text string) error {
  503. summaryMessage.AppendReasoningContent(text)
  504. return a.messages.Update(genCtx, summaryMessage)
  505. },
  506. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  507. // Handle anthropic signature.
  508. if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
  509. if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
  510. summaryMessage.AppendReasoningSignature(signature.Signature)
  511. }
  512. }
  513. summaryMessage.FinishThinking()
  514. return a.messages.Update(genCtx, summaryMessage)
  515. },
  516. OnTextDelta: func(id, text string) error {
  517. summaryMessage.AppendContent(text)
  518. return a.messages.Update(genCtx, summaryMessage)
  519. },
  520. })
  521. if err != nil {
  522. isCancelErr := errors.Is(err, context.Canceled)
  523. if isCancelErr {
  524. // User cancelled summarize we need to remove the summary message.
  525. deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
  526. return deleteErr
  527. }
  528. return err
  529. }
  530. summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
  531. err = a.messages.Update(genCtx, summaryMessage)
  532. if err != nil {
  533. return err
  534. }
  535. var openrouterCost *float64
  536. for _, step := range resp.Steps {
  537. stepCost := a.openrouterCost(step.ProviderMetadata)
  538. if stepCost != nil {
  539. newCost := *stepCost
  540. if openrouterCost != nil {
  541. newCost += *openrouterCost
  542. }
  543. openrouterCost = &newCost
  544. }
  545. }
  546. a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
  547. // Just in case, get just the last usage info.
  548. usage := resp.Response.Usage
  549. currentSession.SummaryMessageID = summaryMessage.ID
  550. currentSession.CompletionTokens = usage.OutputTokens
  551. currentSession.PromptTokens = 0
  552. _, err = a.sessions.Save(genCtx, currentSession)
  553. return err
  554. }
  555. func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
  556. if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
  557. return fantasy.ProviderOptions{}
  558. }
  559. return fantasy.ProviderOptions{
  560. anthropic.Name: &anthropic.ProviderCacheControlOptions{
  561. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  562. },
  563. bedrock.Name: &anthropic.ProviderCacheControlOptions{
  564. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  565. },
  566. }
  567. }
  568. func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
  569. var attachmentParts []message.ContentPart
  570. for _, attachment := range call.Attachments {
  571. attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
  572. }
  573. parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
  574. parts = append(parts, attachmentParts...)
  575. msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
  576. Role: message.User,
  577. Parts: parts,
  578. })
  579. if err != nil {
  580. return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
  581. }
  582. return msg, nil
  583. }
  584. func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
  585. var history []fantasy.Message
  586. for _, m := range msgs {
  587. if len(m.Parts) == 0 {
  588. continue
  589. }
  590. // Assistant message without content or tool calls (cancelled before it
  591. // returned anything).
  592. if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
  593. continue
  594. }
  595. history = append(history, m.ToAIMessage()...)
  596. }
  597. var files []fantasy.FilePart
  598. for _, attachment := range attachments {
  599. files = append(files, fantasy.FilePart{
  600. Filename: attachment.FileName,
  601. Data: attachment.Content,
  602. MediaType: attachment.MimeType,
  603. })
  604. }
  605. return history, files
  606. }
  607. func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
  608. msgs, err := a.messages.List(ctx, session.ID)
  609. if err != nil {
  610. return nil, fmt.Errorf("failed to list messages: %w", err)
  611. }
  612. if session.SummaryMessageID != "" {
  613. summaryMsgInex := -1
  614. for i, msg := range msgs {
  615. if msg.ID == session.SummaryMessageID {
  616. summaryMsgInex = i
  617. break
  618. }
  619. }
  620. if summaryMsgInex != -1 {
  621. msgs = msgs[summaryMsgInex:]
  622. msgs[0].Role = message.User
  623. }
  624. }
  625. return msgs, nil
  626. }
  627. func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
  628. if prompt == "" {
  629. return
  630. }
  631. var maxOutput int64 = 40
  632. if a.smallModel.CatwalkCfg.CanReason {
  633. maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
  634. }
  635. agent := fantasy.NewAgent(a.smallModel.Model,
  636. fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
  637. fantasy.WithMaxOutputTokens(maxOutput),
  638. )
  639. resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
  640. Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
  641. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  642. prepared.Messages = options.Messages
  643. if a.systemPromptPrefix != "" {
  644. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
  645. }
  646. return callContext, prepared, nil
  647. },
  648. })
  649. if err != nil {
  650. slog.Error("error generating title", "err", err)
  651. return
  652. }
  653. title := resp.Response.Content.Text()
  654. title = strings.ReplaceAll(title, "\n", " ")
  655. // Remove thinking tags if present.
  656. if idx := strings.Index(title, "</think>"); idx > 0 {
  657. title = title[idx+len("</think>"):]
  658. }
  659. title = strings.TrimSpace(title)
  660. if title == "" {
  661. slog.Warn("failed to generate title", "warn", "empty title")
  662. return
  663. }
  664. session.Title = title
  665. var openrouterCost *float64
  666. for _, step := range resp.Steps {
  667. stepCost := a.openrouterCost(step.ProviderMetadata)
  668. if stepCost != nil {
  669. newCost := *stepCost
  670. if openrouterCost != nil {
  671. newCost += *openrouterCost
  672. }
  673. openrouterCost = &newCost
  674. }
  675. }
  676. a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
  677. _, saveErr := a.sessions.Save(ctx, *session)
  678. if saveErr != nil {
  679. slog.Error("failed to save session title & usage", "error", saveErr)
  680. return
  681. }
  682. }
  683. func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
  684. openrouterMetadata, ok := metadata[openrouter.Name]
  685. if !ok {
  686. return nil
  687. }
  688. opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
  689. if !ok {
  690. return nil
  691. }
  692. return &opts.Usage.Cost
  693. }
  694. func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
  695. modelConfig := model.CatwalkCfg
  696. cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  697. modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  698. modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
  699. modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
  700. if a.isClaudeCode() {
  701. cost = 0
  702. }
  703. a.eventTokensUsed(session.ID, model, usage, cost)
  704. if overrideCost != nil {
  705. session.Cost += *overrideCost
  706. } else {
  707. session.Cost += cost
  708. }
  709. session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
  710. session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
  711. }
  712. func (a *sessionAgent) Cancel(sessionID string) {
  713. // Cancel regular requests.
  714. if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
  715. slog.Info("Request cancellation initiated", "session_id", sessionID)
  716. cancel()
  717. }
  718. // Also check for summarize requests.
  719. if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
  720. slog.Info("Summarize cancellation initiated", "session_id", sessionID)
  721. cancel()
  722. }
  723. if a.QueuedPrompts(sessionID) > 0 {
  724. slog.Info("Clearing queued prompts", "session_id", sessionID)
  725. a.messageQueue.Del(sessionID)
  726. }
  727. }
  728. func (a *sessionAgent) ClearQueue(sessionID string) {
  729. if a.QueuedPrompts(sessionID) > 0 {
  730. slog.Info("Clearing queued prompts", "session_id", sessionID)
  731. a.messageQueue.Del(sessionID)
  732. }
  733. }
  734. func (a *sessionAgent) CancelAll() {
  735. if !a.IsBusy() {
  736. return
  737. }
  738. for key := range a.activeRequests.Seq2() {
  739. a.Cancel(key) // key is sessionID
  740. }
  741. timeout := time.After(5 * time.Second)
  742. for a.IsBusy() {
  743. select {
  744. case <-timeout:
  745. return
  746. default:
  747. time.Sleep(200 * time.Millisecond)
  748. }
  749. }
  750. }
  751. func (a *sessionAgent) IsBusy() bool {
  752. var busy bool
  753. for cancelFunc := range a.activeRequests.Seq() {
  754. if cancelFunc != nil {
  755. busy = true
  756. break
  757. }
  758. }
  759. return busy
  760. }
  761. func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
  762. _, busy := a.activeRequests.Get(sessionID)
  763. return busy
  764. }
  765. func (a *sessionAgent) QueuedPrompts(sessionID string) int {
  766. l, ok := a.messageQueue.Get(sessionID)
  767. if !ok {
  768. return 0
  769. }
  770. return len(l)
  771. }
  772. func (a *sessionAgent) SetModels(large Model, small Model) {
  773. a.largeModel = large
  774. a.smallModel = small
  775. }
  776. func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
  777. a.tools = tools
  778. }
  779. func (a *sessionAgent) Model() Model {
  780. return a.largeModel
  781. }
  782. func (a *sessionAgent) promptPrefix() string {
  783. if a.isClaudeCode() {
  784. return "You are Claude Code, Anthropic's official CLI for Claude."
  785. }
  786. return a.systemPromptPrefix
  787. }
  788. func (a *sessionAgent) isClaudeCode() bool {
  789. cfg := config.Get()
  790. pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
  791. return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
  792. }
  793. // convertToToolResult converts a fantasy tool result to a message tool result.
  794. func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
  795. baseResult := message.ToolResult{
  796. ToolCallID: result.ToolCallID,
  797. Name: result.ToolName,
  798. Metadata: result.ClientMetadata,
  799. }
  800. switch result.Result.GetType() {
  801. case fantasy.ToolResultContentTypeText:
  802. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
  803. baseResult.Content = r.Text
  804. }
  805. case fantasy.ToolResultContentTypeError:
  806. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
  807. baseResult.Content = r.Error.Error()
  808. baseResult.IsError = true
  809. }
  810. case fantasy.ToolResultContentTypeMedia:
  811. if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
  812. content := r.Text
  813. if content == "" {
  814. content = fmt.Sprintf("Loaded %s content", r.MediaType)
  815. }
  816. baseResult.Content = content
  817. baseResult.Data = r.Data
  818. baseResult.MIMEType = r.MediaType
  819. }
  820. }
  821. return baseResult
  822. }
  823. // workaroundProviderMediaLimitations converts media content in tool results to
  824. // user messages for providers that don't natively support images in tool results.
  825. //
  826. // Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
  827. // don't support sending images/media in tool result messages - they only accept
  828. // text in tool results. However, they DO support images in user messages.
  829. //
  830. // If we send media in tool results to these providers, the API returns an error.
  831. //
  832. // Solution: For these providers, we:
  833. // 1. Replace the media in the tool result with a text placeholder
  834. // 2. Inject a user message immediately after with the image as a file attachment
  835. // 3. This maintains the tool execution flow while working around API limitations
  836. //
  837. // Anthropic and Bedrock support images natively in tool results, so we skip
  838. // this workaround for them.
  839. //
  840. // Example transformation:
  841. //
  842. // BEFORE: [tool result: image data]
  843. // AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
  844. func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
  845. providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
  846. a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
  847. if providerSupportsMedia {
  848. return messages
  849. }
  850. convertedMessages := make([]fantasy.Message, 0, len(messages))
  851. for _, msg := range messages {
  852. if msg.Role != fantasy.MessageRoleTool {
  853. convertedMessages = append(convertedMessages, msg)
  854. continue
  855. }
  856. textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
  857. var mediaFiles []fantasy.FilePart
  858. for _, part := range msg.Content {
  859. toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
  860. if !ok {
  861. textParts = append(textParts, part)
  862. continue
  863. }
  864. if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
  865. decoded, err := base64.StdEncoding.DecodeString(media.Data)
  866. if err != nil {
  867. slog.Warn("failed to decode media data", "error", err)
  868. textParts = append(textParts, part)
  869. continue
  870. }
  871. mediaFiles = append(mediaFiles, fantasy.FilePart{
  872. Data: decoded,
  873. MediaType: media.MediaType,
  874. Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
  875. })
  876. textParts = append(textParts, fantasy.ToolResultPart{
  877. ToolCallID: toolResult.ToolCallID,
  878. Output: fantasy.ToolResultOutputContentText{
  879. Text: "[Image/media content loaded - see attached file]",
  880. },
  881. ProviderOptions: toolResult.ProviderOptions,
  882. })
  883. } else {
  884. textParts = append(textParts, part)
  885. }
  886. }
  887. convertedMessages = append(convertedMessages, fantasy.Message{
  888. Role: fantasy.MessageRoleTool,
  889. Content: textParts,
  890. })
  891. if len(mediaFiles) > 0 {
  892. convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
  893. "Here is the media content from the tool result:",
  894. mediaFiles...,
  895. ))
  896. }
  897. }
  898. return convertedMessages
  899. }