agent.go 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875
  1. // Package agent is the core orchestration layer for Crush AI agents.
  2. //
  3. // It provides session-based AI agent functionality for managing
  4. // conversations, tool execution, and message handling. It coordinates
  5. // interactions between language models, messages, sessions, and tools while
  6. // handling features like automatic summarization, queuing, and token
  7. // management.
  8. package agent
  9. import (
  10. "context"
  11. _ "embed"
  12. "errors"
  13. "fmt"
  14. "log/slog"
  15. "os"
  16. "strconv"
  17. "strings"
  18. "sync"
  19. "time"
  20. "charm.land/fantasy"
  21. "charm.land/fantasy/providers/anthropic"
  22. "charm.land/fantasy/providers/bedrock"
  23. "charm.land/fantasy/providers/google"
  24. "charm.land/fantasy/providers/openai"
  25. "charm.land/fantasy/providers/openrouter"
  26. "github.com/charmbracelet/catwalk/pkg/catwalk"
  27. "github.com/charmbracelet/crush/internal/agent/tools"
  28. "github.com/charmbracelet/crush/internal/config"
  29. "github.com/charmbracelet/crush/internal/csync"
  30. "github.com/charmbracelet/crush/internal/message"
  31. "github.com/charmbracelet/crush/internal/permission"
  32. "github.com/charmbracelet/crush/internal/session"
  33. )
// titlePrompt is the embedded system prompt used when generating a session
// title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when summarizing a
// session's conversation.
//
//go:embed templates/summary.md
var summaryPrompt []byte
// SessionAgentCall describes a single prompt submitted to a session agent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message and also handed to the
	// model as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// The sampling parameters below are forwarded to the model stream call
	// as given.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
// SessionAgent is the behaviour exposed by a session-scoped agent. The
// descriptions below reflect the sessionAgent implementation in this file.
type SessionAgent interface {
	// Run submits a prompt. If the session already has an active request the
	// call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set passed to the model on subsequent runs.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-generated
	// summary message. It returns ErrSessionBusy if the session has an
	// active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
}
// Model bundles a language model with the metadata and configuration the
// agent needs alongside it.
type Model struct {
	// Model is the language model used for generation.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the context window, pricing, reasoning capability
	// and default max tokens read by the agent.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider identifiers recorded on
	// assistant messages.
	ModelCfg config.SelectedModel
}
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives conversations and summaries; smallModel generates
	// session titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
	systemPromptPrefix string
	// systemPrompt is the system prompt for regular runs.
	systemPrompt string
	// tools is the tool set passed to the model in Run.
	tools []fantasy.AgentTool
	// sessions and messages are the persistence services for session state
	// and conversation messages.
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the low-context stop condition in Run.
	disableAutoSummarize bool
	// isYolo is stored from the options; it is not read anywhere in this
	// file.
	isYolo bool
	// messageQueue holds, per session ID, the calls received while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request; presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
// SessionAgentOptions carries the dependencies and settings consumed by
// NewSessionAgent. Each field is copied onto the matching field of the
// returned agent.
type SessionAgentOptions struct {
	// LargeModel drives conversations and summaries; SmallModel generates
	// session titles.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt for regular runs.
	SystemPrompt string
	// DisableAutoSummarize turns off automatic summarization when the
	// context window is nearly full.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent; its effect is not visible in this file.
	IsYolo bool
	// Sessions and Messages are the persistence services the agent uses.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set.
	Tools []fantasy.AgentTool
}
  92. func NewSessionAgent(
  93. opts SessionAgentOptions,
  94. ) SessionAgent {
  95. return &sessionAgent{
  96. largeModel: opts.LargeModel,
  97. smallModel: opts.SmallModel,
  98. systemPromptPrefix: opts.SystemPromptPrefix,
  99. systemPrompt: opts.SystemPrompt,
  100. sessions: opts.Sessions,
  101. messages: opts.Messages,
  102. disableAutoSummarize: opts.DisableAutoSummarize,
  103. tools: opts.Tools,
  104. isYolo: opts.IsYolo,
  105. messageQueue: csync.NewMap[string, []SessionAgentCall](),
  106. activeRequests: csync.NewMap[string, context.CancelFunc](),
  107. }
  108. }
  109. func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
  110. if call.Prompt == "" {
  111. return nil, ErrEmptyPrompt
  112. }
  113. if call.SessionID == "" {
  114. return nil, ErrSessionMissing
  115. }
  116. // Queue the message if busy
  117. if a.IsSessionBusy(call.SessionID) {
  118. existing, ok := a.messageQueue.Get(call.SessionID)
  119. if !ok {
  120. existing = []SessionAgentCall{}
  121. }
  122. existing = append(existing, call)
  123. a.messageQueue.Set(call.SessionID, existing)
  124. return nil, nil
  125. }
  126. if len(a.tools) > 0 {
  127. // Add Anthropic caching to the last tool.
  128. a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
  129. }
  130. agent := fantasy.NewAgent(
  131. a.largeModel.Model,
  132. fantasy.WithSystemPrompt(a.systemPrompt),
  133. fantasy.WithTools(a.tools...),
  134. )
  135. sessionLock := sync.Mutex{}
  136. currentSession, err := a.sessions.Get(ctx, call.SessionID)
  137. if err != nil {
  138. return nil, fmt.Errorf("failed to get session: %w", err)
  139. }
  140. msgs, err := a.getSessionMessages(ctx, currentSession)
  141. if err != nil {
  142. return nil, fmt.Errorf("failed to get session messages: %w", err)
  143. }
  144. var wg sync.WaitGroup
  145. // Generate title if first message.
  146. if len(msgs) == 0 {
  147. wg.Go(func() {
  148. sessionLock.Lock()
  149. a.generateTitle(ctx, &currentSession, call.Prompt)
  150. sessionLock.Unlock()
  151. })
  152. }
  153. // Add the user message to the session.
  154. _, err = a.createUserMessage(ctx, call)
  155. if err != nil {
  156. return nil, err
  157. }
  158. // Add the session to the context.
  159. ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
  160. genCtx, cancel := context.WithCancel(ctx)
  161. a.activeRequests.Set(call.SessionID, cancel)
  162. defer cancel()
  163. defer a.activeRequests.Del(call.SessionID)
  164. history, files := a.preparePrompt(msgs, call.Attachments...)
  165. startTime := time.Now()
  166. a.eventPromptSent(call.SessionID)
  167. var currentAssistant *message.Message
  168. var shouldSummarize bool
  169. result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  170. Prompt: call.Prompt,
  171. Files: files,
  172. Messages: history,
  173. ProviderOptions: call.ProviderOptions,
  174. MaxOutputTokens: &call.MaxOutputTokens,
  175. TopP: call.TopP,
  176. Temperature: call.Temperature,
  177. PresencePenalty: call.PresencePenalty,
  178. TopK: call.TopK,
  179. FrequencyPenalty: call.FrequencyPenalty,
  180. // Before each step create a new assistant message.
  181. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  182. prepared.Messages = options.Messages
  183. // Reset all cached items.
  184. for i := range prepared.Messages {
  185. prepared.Messages[i].ProviderOptions = nil
  186. }
  187. queuedCalls, _ := a.messageQueue.Get(call.SessionID)
  188. a.messageQueue.Del(call.SessionID)
  189. for _, queued := range queuedCalls {
  190. userMessage, createErr := a.createUserMessage(callContext, queued)
  191. if createErr != nil {
  192. return callContext, prepared, createErr
  193. }
  194. prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
  195. }
  196. lastSystemRoleInx := 0
  197. systemMessageUpdated := false
  198. for i, msg := range prepared.Messages {
  199. // Only add cache control to the last message.
  200. if msg.Role == fantasy.MessageRoleSystem {
  201. lastSystemRoleInx = i
  202. } else if !systemMessageUpdated {
  203. prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
  204. systemMessageUpdated = true
  205. }
  206. // Than add cache control to the last 2 messages.
  207. if i > len(prepared.Messages)-3 {
  208. prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
  209. }
  210. }
  211. if a.systemPromptPrefix != "" {
  212. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
  213. }
  214. var assistantMsg message.Message
  215. assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
  216. Role: message.Assistant,
  217. Parts: []message.ContentPart{},
  218. Model: a.largeModel.ModelCfg.Model,
  219. Provider: a.largeModel.ModelCfg.Provider,
  220. })
  221. if err != nil {
  222. return callContext, prepared, err
  223. }
  224. callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
  225. currentAssistant = &assistantMsg
  226. return callContext, prepared, err
  227. },
  228. OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
  229. currentAssistant.AppendReasoningContent(reasoning.Text)
  230. return a.messages.Update(genCtx, *currentAssistant)
  231. },
  232. OnReasoningDelta: func(id string, text string) error {
  233. currentAssistant.AppendReasoningContent(text)
  234. return a.messages.Update(genCtx, *currentAssistant)
  235. },
  236. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  237. // handle anthropic signature
  238. if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
  239. if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
  240. currentAssistant.AppendReasoningSignature(reasoning.Signature)
  241. }
  242. }
  243. if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
  244. if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
  245. currentAssistant.AppendReasoningSignature(reasoning.Signature)
  246. }
  247. }
  248. if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
  249. if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
  250. currentAssistant.SetReasoningResponsesData(reasoning)
  251. }
  252. }
  253. currentAssistant.FinishThinking()
  254. return a.messages.Update(genCtx, *currentAssistant)
  255. },
  256. OnTextDelta: func(id string, text string) error {
  257. // Strip leading newline from initial text content. This is is
  258. // particularly important in non-interactive mode where leading
  259. // newlines are very visible.
  260. if len(currentAssistant.Parts) == 0 {
  261. text = strings.TrimPrefix(text, "\n")
  262. }
  263. currentAssistant.AppendContent(text)
  264. return a.messages.Update(genCtx, *currentAssistant)
  265. },
  266. OnToolInputStart: func(id string, toolName string) error {
  267. toolCall := message.ToolCall{
  268. ID: id,
  269. Name: toolName,
  270. ProviderExecuted: false,
  271. Finished: false,
  272. }
  273. currentAssistant.AddToolCall(toolCall)
  274. return a.messages.Update(genCtx, *currentAssistant)
  275. },
  276. OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
  277. // TODO: implement
  278. },
  279. OnToolCall: func(tc fantasy.ToolCallContent) error {
  280. toolCall := message.ToolCall{
  281. ID: tc.ToolCallID,
  282. Name: tc.ToolName,
  283. Input: tc.Input,
  284. ProviderExecuted: false,
  285. Finished: true,
  286. }
  287. currentAssistant.AddToolCall(toolCall)
  288. return a.messages.Update(genCtx, *currentAssistant)
  289. },
  290. OnToolResult: func(result fantasy.ToolResultContent) error {
  291. var resultContent string
  292. isError := false
  293. switch result.Result.GetType() {
  294. case fantasy.ToolResultContentTypeText:
  295. r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
  296. if ok {
  297. resultContent = r.Text
  298. }
  299. case fantasy.ToolResultContentTypeError:
  300. r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
  301. if ok {
  302. isError = true
  303. resultContent = r.Error.Error()
  304. }
  305. case fantasy.ToolResultContentTypeMedia:
  306. // TODO: handle this message type
  307. }
  308. toolResult := message.ToolResult{
  309. ToolCallID: result.ToolCallID,
  310. Name: result.ToolName,
  311. Content: resultContent,
  312. IsError: isError,
  313. Metadata: result.ClientMetadata,
  314. }
  315. _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
  316. Role: message.Tool,
  317. Parts: []message.ContentPart{
  318. toolResult,
  319. },
  320. })
  321. if createMsgErr != nil {
  322. return createMsgErr
  323. }
  324. return nil
  325. },
  326. OnStepFinish: func(stepResult fantasy.StepResult) error {
  327. finishReason := message.FinishReasonUnknown
  328. switch stepResult.FinishReason {
  329. case fantasy.FinishReasonLength:
  330. finishReason = message.FinishReasonMaxTokens
  331. case fantasy.FinishReasonStop:
  332. finishReason = message.FinishReasonEndTurn
  333. case fantasy.FinishReasonToolCalls:
  334. finishReason = message.FinishReasonToolUse
  335. }
  336. currentAssistant.AddFinish(finishReason, "", "")
  337. a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
  338. sessionLock.Lock()
  339. _, sessionErr := a.sessions.Save(genCtx, currentSession)
  340. sessionLock.Unlock()
  341. if sessionErr != nil {
  342. return sessionErr
  343. }
  344. return a.messages.Update(genCtx, *currentAssistant)
  345. },
  346. StopWhen: []fantasy.StopCondition{
  347. func(_ []fantasy.StepResult) bool {
  348. cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
  349. tokens := currentSession.CompletionTokens + currentSession.PromptTokens
  350. remaining := cw - tokens
  351. var threshold int64
  352. if cw > 200_000 {
  353. threshold = 20_000
  354. } else {
  355. threshold = int64(float64(cw) * 0.2)
  356. }
  357. if (remaining <= threshold) && !a.disableAutoSummarize {
  358. shouldSummarize = true
  359. return true
  360. }
  361. return false
  362. },
  363. },
  364. })
  365. a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
  366. if err != nil {
  367. isCancelErr := errors.Is(err, context.Canceled)
  368. isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
  369. if currentAssistant == nil {
  370. return result, err
  371. }
  372. // Ensure we finish thinking on error to close the reasoning state.
  373. currentAssistant.FinishThinking()
  374. toolCalls := currentAssistant.ToolCalls()
  375. // INFO: we use the parent context here because the genCtx has been cancelled.
  376. msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
  377. if createErr != nil {
  378. return nil, createErr
  379. }
  380. for _, tc := range toolCalls {
  381. if !tc.Finished {
  382. tc.Finished = true
  383. tc.Input = "{}"
  384. currentAssistant.AddToolCall(tc)
  385. updateErr := a.messages.Update(ctx, *currentAssistant)
  386. if updateErr != nil {
  387. return nil, updateErr
  388. }
  389. }
  390. found := false
  391. for _, msg := range msgs {
  392. if msg.Role == message.Tool {
  393. for _, tr := range msg.ToolResults() {
  394. if tr.ToolCallID == tc.ID {
  395. found = true
  396. break
  397. }
  398. }
  399. }
  400. if found {
  401. break
  402. }
  403. }
  404. if found {
  405. continue
  406. }
  407. content := "There was an error while executing the tool"
  408. if isCancelErr {
  409. content = "Tool execution canceled by user"
  410. } else if isPermissionErr {
  411. content = "User denied permission"
  412. }
  413. toolResult := message.ToolResult{
  414. ToolCallID: tc.ID,
  415. Name: tc.Name,
  416. Content: content,
  417. IsError: true,
  418. }
  419. _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
  420. Role: message.Tool,
  421. Parts: []message.ContentPart{
  422. toolResult,
  423. },
  424. })
  425. if createErr != nil {
  426. return nil, createErr
  427. }
  428. }
  429. if isCancelErr {
  430. currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
  431. } else if isPermissionErr {
  432. currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
  433. } else {
  434. currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
  435. }
  436. // Note: we use the parent context here because the genCtx has been
  437. // cancelled.
  438. updateErr := a.messages.Update(ctx, *currentAssistant)
  439. if updateErr != nil {
  440. return nil, updateErr
  441. }
  442. return nil, err
  443. }
  444. wg.Wait()
  445. if shouldSummarize {
  446. a.activeRequests.Del(call.SessionID)
  447. if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
  448. return nil, summarizeErr
  449. }
  450. // If the agent wasn't done...
  451. if len(currentAssistant.ToolCalls()) > 0 {
  452. existing, ok := a.messageQueue.Get(call.SessionID)
  453. if !ok {
  454. existing = []SessionAgentCall{}
  455. }
  456. call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
  457. existing = append(existing, call)
  458. a.messageQueue.Set(call.SessionID, existing)
  459. }
  460. }
  461. // Release active request before processing queued messages.
  462. a.activeRequests.Del(call.SessionID)
  463. cancel()
  464. queuedMessages, ok := a.messageQueue.Get(call.SessionID)
  465. if !ok || len(queuedMessages) == 0 {
  466. return result, err
  467. }
  468. // There are queued messages restart the loop.
  469. firstQueuedMessage := queuedMessages[0]
  470. a.messageQueue.Set(call.SessionID, queuedMessages[1:])
  471. return a.Run(ctx, firstQueuedMessage)
  472. }
// Summarize asks the large model for a summary of the session's conversation
// and stores it as a summary message on the session.
//
// It returns ErrSessionBusy when the session already has an active request
// and returns nil without doing anything when the session has no messages.
// While running, the session is registered as busy so it can be cancelled. If
// the stream is cancelled, the partially written summary message is deleted
// and the result of that deletion is returned.
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// Nothing to summarize.
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	// Mark the session busy for the duration of the summary.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := fantasy.NewAgent(a.largeModel.Model,
		fantasy.WithSystemPrompt(string(summaryPrompt)),
	)
	// The summary is streamed into this message as it arrives.
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:             message.Assistant,
		Model:            a.largeModel.Model.Model(),
		Provider:         a.largeModel.Model.Provider(),
		IsSummaryMessage: true,
	})
	if err != nil {
		return err
	}

	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:          "Provide a detailed summary of our conversation above.",
		Messages:        aiMsgs,
		ProviderOptions: opts,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			if a.systemPromptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
			}
			return callContext, prepared, nil
		},
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Handle anthropic signature.
			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		if isCancelErr {
			// User cancelled summarize we need to remove the summary message.
			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
			return deleteErr
		}
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	// Sum the OpenRouter-reported cost over all steps; the result stays nil
	// when no step reported one.
	var openrouterCost *float64
	for _, step := range resp.Steps {
		stepCost := a.openrouterCost(step.ProviderMetadata)
		if stepCost != nil {
			newCost := *stepCost
			if openrouterCost != nil {
				newCost += *openrouterCost
			}
			openrouterCost = &newCost
		}
	}

	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)

	// Just in case, get just the last usage info.
	usage := resp.Response.Usage
	// Point the session at the summary and reset its token counters to the
	// summary's output size.
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}
  570. func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
  571. if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
  572. return fantasy.ProviderOptions{}
  573. }
  574. return fantasy.ProviderOptions{
  575. anthropic.Name: &anthropic.ProviderCacheControlOptions{
  576. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  577. },
  578. bedrock.Name: &anthropic.ProviderCacheControlOptions{
  579. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  580. },
  581. }
  582. }
  583. func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
  584. var attachmentParts []message.ContentPart
  585. for _, attachment := range call.Attachments {
  586. attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
  587. }
  588. parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
  589. parts = append(parts, attachmentParts...)
  590. msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
  591. Role: message.User,
  592. Parts: parts,
  593. })
  594. if err != nil {
  595. return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
  596. }
  597. return msg, nil
  598. }
  599. func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
  600. var history []fantasy.Message
  601. for _, m := range msgs {
  602. if len(m.Parts) == 0 {
  603. continue
  604. }
  605. // Assistant message without content or tool calls (cancelled before it
  606. // returned anything).
  607. if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
  608. continue
  609. }
  610. history = append(history, m.ToAIMessage()...)
  611. }
  612. var files []fantasy.FilePart
  613. for _, attachment := range attachments {
  614. files = append(files, fantasy.FilePart{
  615. Filename: attachment.FileName,
  616. Data: attachment.Content,
  617. MediaType: attachment.MimeType,
  618. })
  619. }
  620. return history, files
  621. }
  622. func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
  623. msgs, err := a.messages.List(ctx, session.ID)
  624. if err != nil {
  625. return nil, fmt.Errorf("failed to list messages: %w", err)
  626. }
  627. if session.SummaryMessageID != "" {
  628. summaryMsgInex := -1
  629. for i, msg := range msgs {
  630. if msg.ID == session.SummaryMessageID {
  631. summaryMsgInex = i
  632. break
  633. }
  634. }
  635. if summaryMsgInex != -1 {
  636. msgs = msgs[summaryMsgInex:]
  637. msgs[0].Role = message.User
  638. }
  639. }
  640. return msgs, nil
  641. }
  642. func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
  643. if prompt == "" {
  644. return
  645. }
  646. var maxOutput int64 = 40
  647. if a.smallModel.CatwalkCfg.CanReason {
  648. maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
  649. }
  650. agent := fantasy.NewAgent(a.smallModel.Model,
  651. fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
  652. fantasy.WithMaxOutputTokens(maxOutput),
  653. )
  654. resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
  655. Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
  656. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  657. prepared.Messages = options.Messages
  658. if a.systemPromptPrefix != "" {
  659. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
  660. }
  661. return callContext, prepared, nil
  662. },
  663. })
  664. if err != nil {
  665. slog.Error("error generating title", "err", err)
  666. return
  667. }
  668. title := resp.Response.Content.Text()
  669. title = strings.ReplaceAll(title, "\n", " ")
  670. // Remove thinking tags if present.
  671. if idx := strings.Index(title, "</think>"); idx > 0 {
  672. title = title[idx+len("</think>"):]
  673. }
  674. title = strings.TrimSpace(title)
  675. if title == "" {
  676. slog.Warn("failed to generate title", "warn", "empty title")
  677. return
  678. }
  679. session.Title = title
  680. var openrouterCost *float64
  681. for _, step := range resp.Steps {
  682. stepCost := a.openrouterCost(step.ProviderMetadata)
  683. if stepCost != nil {
  684. newCost := *stepCost
  685. if openrouterCost != nil {
  686. newCost += *openrouterCost
  687. }
  688. openrouterCost = &newCost
  689. }
  690. }
  691. a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
  692. _, saveErr := a.sessions.Save(ctx, *session)
  693. if saveErr != nil {
  694. slog.Error("failed to save session title & usage", "error", saveErr)
  695. return
  696. }
  697. }
  698. func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
  699. openrouterMetadata, ok := metadata[openrouter.Name]
  700. if !ok {
  701. return nil
  702. }
  703. opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
  704. if !ok {
  705. return nil
  706. }
  707. return &opts.Usage.Cost
  708. }
  709. func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
  710. modelConfig := model.CatwalkCfg
  711. cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  712. modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  713. modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
  714. modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
  715. a.eventTokensUsed(session.ID, model, usage, cost)
  716. if overrideCost != nil {
  717. session.Cost += *overrideCost
  718. } else {
  719. session.Cost += cost
  720. }
  721. session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
  722. session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
  723. }
  724. func (a *sessionAgent) Cancel(sessionID string) {
  725. // Cancel regular requests.
  726. if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
  727. slog.Info("Request cancellation initiated", "session_id", sessionID)
  728. cancel()
  729. }
  730. // Also check for summarize requests.
  731. if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
  732. slog.Info("Summarize cancellation initiated", "session_id", sessionID)
  733. cancel()
  734. }
  735. if a.QueuedPrompts(sessionID) > 0 {
  736. slog.Info("Clearing queued prompts", "session_id", sessionID)
  737. a.messageQueue.Del(sessionID)
  738. }
  739. }
  740. func (a *sessionAgent) ClearQueue(sessionID string) {
  741. if a.QueuedPrompts(sessionID) > 0 {
  742. slog.Info("Clearing queued prompts", "session_id", sessionID)
  743. a.messageQueue.Del(sessionID)
  744. }
  745. }
  746. func (a *sessionAgent) CancelAll() {
  747. if !a.IsBusy() {
  748. return
  749. }
  750. for key := range a.activeRequests.Seq2() {
  751. a.Cancel(key) // key is sessionID
  752. }
  753. timeout := time.After(5 * time.Second)
  754. for a.IsBusy() {
  755. select {
  756. case <-timeout:
  757. return
  758. default:
  759. time.Sleep(200 * time.Millisecond)
  760. }
  761. }
  762. }
  763. func (a *sessionAgent) IsBusy() bool {
  764. var busy bool
  765. for cancelFunc := range a.activeRequests.Seq() {
  766. if cancelFunc != nil {
  767. busy = true
  768. break
  769. }
  770. }
  771. return busy
  772. }
  773. func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
  774. _, busy := a.activeRequests.Get(sessionID)
  775. return busy
  776. }
  777. func (a *sessionAgent) QueuedPrompts(sessionID string) int {
  778. l, ok := a.messageQueue.Get(sessionID)
  779. if !ok {
  780. return 0
  781. }
  782. return len(l)
  783. }
  784. func (a *sessionAgent) SetModels(large Model, small Model) {
  785. a.largeModel = large
  786. a.smallModel = small
  787. }
// SetTools replaces the agent's tool set. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}