agent.go 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967
  1. // Package agent is the core orchestration layer for Crush AI agents.
  2. //
  3. // It provides session-based AI agent functionality for managing
  4. // conversations, tool execution, and message handling. It coordinates
  5. // interactions between language models, messages, sessions, and tools while
  6. // handling features like automatic summarization, queuing, and token
  7. // management.
  8. package agent
  9. import (
  10. "cmp"
  11. "context"
  12. _ "embed"
  13. "errors"
  14. "fmt"
  15. "log/slog"
  16. "os"
  17. "strconv"
  18. "strings"
  19. "sync"
  20. "time"
  21. "charm.land/fantasy"
  22. "charm.land/fantasy/providers/anthropic"
  23. "charm.land/fantasy/providers/bedrock"
  24. "charm.land/fantasy/providers/google"
  25. "charm.land/fantasy/providers/openai"
  26. "charm.land/fantasy/providers/openrouter"
  27. "github.com/charmbracelet/catwalk/pkg/catwalk"
  28. "github.com/charmbracelet/crush/internal/agent/tools"
  29. "github.com/charmbracelet/crush/internal/config"
  30. "github.com/charmbracelet/crush/internal/csync"
  31. "github.com/charmbracelet/crush/internal/hooks"
  32. "github.com/charmbracelet/crush/internal/message"
  33. "github.com/charmbracelet/crush/internal/permission"
  34. "github.com/charmbracelet/crush/internal/session"
  35. "github.com/charmbracelet/crush/internal/stringext"
  36. )
// titlePrompt is the system prompt used by generateTitle, embedded at build
// time from templates/title.md.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize, embedded at build
// time from templates/summary.md.
//
//go:embed templates/summary.md
var summaryPrompt []byte
// SessionAgentCall describes one prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID is the session the prompt belongs to. sessionAgent.Run
	// rejects an empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. sessionAgent.Run rejects an empty value
	// with ErrEmptyPrompt.
	Prompt string
	// ProviderOptions are forwarded unchanged to the model call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts of the user message and sent to
	// the model as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is always passed to the model call by pointer, even
	// when it is zero. NOTE(review): confirm how providers treat 0.
	MaxOutputTokens int64
	// The sampling parameters below are forwarded unchanged to the model
	// call. Presumably nil means "use the provider default" (TODO confirm in
	// fantasy).
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
// SessionAgent runs prompts against a language model on behalf of chat
// sessions. sessionAgent (built by NewSessionAgent) is the implementation in
// this package.
type SessionAgent interface {
	// Run processes one prompt for call.SessionID. In sessionAgent, a prompt
	// for a session that is already busy is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels presumably replaces the large and small models (the
	// implementation is not shown here).
	SetModels(large Model, small Model)
	// SetTools presumably replaces the tools offered to the model
	// (implementation not shown here).
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the in-flight request for sessionID (including a
	// "-summarize" request) and drops the session's queued prompts.
	Cancel(sessionID string)
	// CancelAll presumably cancels every session's request (implementation
	// not shown here).
	CancelAll()
	// IsSessionBusy reports whether sessionID has a request in flight. Run
	// queues prompts and Summarize returns ErrSessionBusy when it is true.
	IsSessionBusy(sessionID string) bool
	// IsBusy presumably reports whether any session is busy (implementation
	// not shown here).
	IsBusy() bool
	// QueuedPrompts returns the number of prompts waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue presumably drops the prompts queued for sessionID
	// (implementation not shown here).
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-generated
	// summary. Its arguments are the context, the session ID and the
	// provider options for the summary call.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model presumably returns the large model (implementation not shown
	// here; TODO confirm).
	Model() Model
}
// Model bundles a language model client with the configuration used to
// drive it and to account for its cost.
type Model struct {
	// Model is the client used to stream completions.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies model metadata: context window, reasoning support,
	// default max tokens and per-million-token costs.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model. Its Model and Provider fields are
	// recorded on the assistant messages that Run creates.
	ModelCfg config.SelectedModel
}
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel runs conversation turns and summaries.
	largeModel Model
	// smallModel generates session titles.
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message to every model request.
	systemPromptPrefix string
	// systemPrompt is the main system prompt for conversation turns.
	systemPrompt string
	// tools are offered to the large model during Run.
	tools    []fantasy.AgentTool
	sessions session.Service
	messages message.Service
	// hooks is presumably used by executeUserPromptSubmitHook
	// (implementation not shown here).
	hooks hooks.Service
	// disableAutoSummarize turns off the context-window-based summarization
	// in Run.
	disableAutoSummarize bool
	// isYolo is not read by the code shown here. NOTE(review): presumably it
	// skips permission prompts; confirm.
	isYolo bool
	// messageQueue holds, per session ID, prompts submitted while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request.
	activeRequests *csync.Map[string, context.CancelFunc]
}
// SessionAgentOptions holds the dependencies and settings for
// NewSessionAgent. Each field is copied into the matching sessionAgent field.
type SessionAgentOptions struct {
	// LargeModel runs conversation turns and summaries.
	LargeModel Model
	// SmallModel generates session titles.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message to every request.
	SystemPromptPrefix string
	// SystemPrompt is the main system prompt for conversation turns.
	SystemPrompt string
	// DisableAutoSummarize turns off automatic summarization near the
	// context-window limit.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Hooks                hooks.Service
	// Tools are offered to the large model.
	Tools []fantasy.AgentTool
}
  97. func NewSessionAgent(
  98. opts SessionAgentOptions,
  99. ) SessionAgent {
  100. return &sessionAgent{
  101. largeModel: opts.LargeModel,
  102. smallModel: opts.SmallModel,
  103. systemPromptPrefix: opts.SystemPromptPrefix,
  104. systemPrompt: opts.SystemPrompt,
  105. sessions: opts.Sessions,
  106. messages: opts.Messages,
  107. disableAutoSummarize: opts.DisableAutoSummarize,
  108. hooks: opts.Hooks,
  109. tools: opts.Tools,
  110. isYolo: opts.IsYolo,
  111. messageQueue: csync.NewMap[string, []SessionAgentCall](),
  112. activeRequests: csync.NewMap[string, context.CancelFunc](),
  113. }
  114. }
  115. func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
  116. if call.Prompt == "" {
  117. return nil, ErrEmptyPrompt
  118. }
  119. if call.SessionID == "" {
  120. return nil, ErrSessionMissing
  121. }
  122. // Queue the message if busy
  123. if a.IsSessionBusy(call.SessionID) {
  124. existing, ok := a.messageQueue.Get(call.SessionID)
  125. if !ok {
  126. existing = []SessionAgentCall{}
  127. }
  128. existing = append(existing, call)
  129. a.messageQueue.Set(call.SessionID, existing)
  130. return nil, nil
  131. }
  132. if len(a.tools) > 0 {
  133. // Add Anthropic caching to the last tool.
  134. a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
  135. }
  136. agent := fantasy.NewAgent(
  137. a.largeModel.Model,
  138. fantasy.WithSystemPrompt(a.systemPrompt),
  139. fantasy.WithTools(a.tools...),
  140. )
  141. sessionLock := sync.Mutex{}
  142. currentSession, err := a.sessions.Get(ctx, call.SessionID)
  143. if err != nil {
  144. return nil, fmt.Errorf("failed to get session: %w", err)
  145. }
  146. msgs, err := a.getSessionMessages(ctx, currentSession)
  147. if err != nil {
  148. return nil, fmt.Errorf("failed to get session messages: %w", err)
  149. }
  150. var wg sync.WaitGroup
  151. // Generate title if first message.
  152. if len(msgs) == 0 {
  153. wg.Go(func() {
  154. sessionLock.Lock()
  155. a.generateTitle(ctx, &currentSession, call.Prompt)
  156. sessionLock.Unlock()
  157. })
  158. }
  159. // Add the user message to the session.
  160. userMsg, err := a.createUserMessage(ctx, call)
  161. if err != nil {
  162. return nil, err
  163. }
  164. // Add the session to the context.
  165. ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
  166. genCtx, cancel := context.WithCancel(ctx)
  167. a.activeRequests.Set(call.SessionID, cancel)
  168. defer cancel()
  169. defer a.activeRequests.Del(call.SessionID)
  170. // create the agent message asap to show loading
  171. var currentAssistant *message.Message
  172. assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
  173. Role: message.Assistant,
  174. Parts: []message.ContentPart{},
  175. Model: a.largeModel.ModelCfg.Model,
  176. Provider: a.largeModel.ModelCfg.Provider,
  177. })
  178. if err != nil {
  179. return nil, err
  180. }
  181. currentAssistant = &assistantMessage
  182. // run hooks after assistant message is created
  183. // this way we show loading for the user
  184. hooks, err := a.executeUserPromptSubmitHook(genCtx, call.SessionID, call.Prompt)
  185. if err != nil {
  186. return nil, err
  187. }
  188. userMsg.AddHookOutputs(hooks...)
  189. if updateErr := a.messages.Update(genCtx, userMsg); updateErr != nil {
  190. return nil, updateErr
  191. }
  192. for _, hook := range hooks {
  193. // execution stopped
  194. if hook.Stop {
  195. deleteErr := a.messages.Delete(genCtx, assistantMessage.ID)
  196. if deleteErr != nil {
  197. return nil, deleteErr
  198. }
  199. return nil, nil
  200. }
  201. }
  202. history, files := a.preparePrompt(msgs, call.Attachments...)
  203. startTime := time.Now()
  204. a.eventPromptSent(call.SessionID)
  205. var shouldSummarize bool
  206. result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  207. Prompt: userMsg.ContentWithHooksContext(),
  208. Files: files,
  209. Messages: history,
  210. ProviderOptions: call.ProviderOptions,
  211. MaxOutputTokens: &call.MaxOutputTokens,
  212. TopP: call.TopP,
  213. Temperature: call.Temperature,
  214. PresencePenalty: call.PresencePenalty,
  215. TopK: call.TopK,
  216. FrequencyPenalty: call.FrequencyPenalty,
  217. // Before each step create a new assistant message.
  218. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  219. // only add new assistant message when its not the first step
  220. if options.StepNumber != 0 {
  221. var assistantMsg message.Message
  222. assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
  223. Role: message.Assistant,
  224. Parts: []message.ContentPart{},
  225. Model: a.largeModel.ModelCfg.Model,
  226. Provider: a.largeModel.ModelCfg.Provider,
  227. })
  228. currentAssistant = &assistantMsg
  229. // create the message first so we show loading asap
  230. if err != nil {
  231. return callContext, prepared, err
  232. }
  233. }
  234. prepared.Messages = options.Messages
  235. // Reset all cached items.
  236. for i := range prepared.Messages {
  237. prepared.Messages[i].ProviderOptions = nil
  238. }
  239. queuedCalls, _ := a.messageQueue.Get(call.SessionID)
  240. a.messageQueue.Del(call.SessionID)
  241. for _, queued := range queuedCalls {
  242. userMessage, createErr := a.createUserMessage(callContext, queued)
  243. if createErr != nil {
  244. return callContext, prepared, createErr
  245. }
  246. // run hooks after assistant message is created
  247. // this way we show loading for the user
  248. hooks, hookErr := a.executeUserPromptSubmitHook(genCtx, call.SessionID, call.Prompt)
  249. if hookErr != nil {
  250. return callContext, prepared, hookErr
  251. }
  252. userMsg.AddHookOutputs(hooks...)
  253. for _, hook := range hooks {
  254. // execution stopped
  255. if hook.Stop {
  256. deleteErr := a.messages.Delete(genCtx, assistantMessage.ID)
  257. if deleteErr != nil {
  258. return callContext, prepared, deleteErr
  259. }
  260. return callContext, prepared, ErrHookCancellation
  261. }
  262. }
  263. if updateErr := a.messages.Update(genCtx, userMsg); updateErr != nil {
  264. return callContext, prepared, updateErr
  265. }
  266. prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
  267. }
  268. lastSystemRoleInx := 0
  269. systemMessageUpdated := false
  270. for i, msg := range prepared.Messages {
  271. // Only add cache control to the last message.
  272. if msg.Role == fantasy.MessageRoleSystem {
  273. lastSystemRoleInx = i
  274. } else if !systemMessageUpdated {
  275. prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
  276. systemMessageUpdated = true
  277. }
  278. // Than add cache control to the last 2 messages.
  279. if i > len(prepared.Messages)-3 {
  280. prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
  281. }
  282. }
  283. if a.systemPromptPrefix != "" {
  284. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
  285. }
  286. callContext = context.WithValue(callContext, tools.MessageIDContextKey, currentAssistant.ID)
  287. return callContext, prepared, err
  288. },
  289. OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
  290. currentAssistant.AppendReasoningContent(reasoning.Text)
  291. return a.messages.Update(genCtx, *currentAssistant)
  292. },
  293. OnReasoningDelta: func(id string, text string) error {
  294. currentAssistant.AppendReasoningContent(text)
  295. return a.messages.Update(genCtx, *currentAssistant)
  296. },
  297. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  298. // handle anthropic signature
  299. if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
  300. if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
  301. currentAssistant.AppendReasoningSignature(reasoning.Signature)
  302. }
  303. }
  304. if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
  305. if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
  306. currentAssistant.AppendReasoningSignature(reasoning.Signature)
  307. }
  308. }
  309. if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
  310. if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
  311. currentAssistant.SetReasoningResponsesData(reasoning)
  312. }
  313. }
  314. currentAssistant.FinishThinking()
  315. return a.messages.Update(genCtx, *currentAssistant)
  316. },
  317. OnTextDelta: func(id string, text string) error {
  318. // Strip leading newline from initial text content. This is is
  319. // particularly important in non-interactive mode where leading
  320. // newlines are very visible.
  321. if len(currentAssistant.Parts) == 0 {
  322. text = strings.TrimPrefix(text, "\n")
  323. }
  324. currentAssistant.AppendContent(text)
  325. return a.messages.Update(genCtx, *currentAssistant)
  326. },
  327. OnToolInputStart: func(id string, toolName string) error {
  328. toolCall := message.ToolCall{
  329. ID: id,
  330. Name: toolName,
  331. ProviderExecuted: false,
  332. Finished: false,
  333. }
  334. currentAssistant.AddToolCall(toolCall)
  335. return a.messages.Update(genCtx, *currentAssistant)
  336. },
  337. OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
  338. // TODO: implement
  339. },
  340. OnToolCall: func(tc fantasy.ToolCallContent) error {
  341. toolCall := message.ToolCall{
  342. ID: tc.ToolCallID,
  343. Name: tc.ToolName,
  344. Input: tc.Input,
  345. ProviderExecuted: false,
  346. Finished: true,
  347. }
  348. currentAssistant.AddToolCall(toolCall)
  349. return a.messages.Update(genCtx, *currentAssistant)
  350. },
  351. OnToolResult: func(result fantasy.ToolResultContent) error {
  352. var resultContent string
  353. isError := false
  354. switch result.Result.GetType() {
  355. case fantasy.ToolResultContentTypeText:
  356. r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
  357. if ok {
  358. resultContent = r.Text
  359. }
  360. case fantasy.ToolResultContentTypeError:
  361. r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
  362. if ok {
  363. isError = true
  364. resultContent = r.Error.Error()
  365. }
  366. case fantasy.ToolResultContentTypeMedia:
  367. // TODO: handle this message type
  368. }
  369. toolResult := message.ToolResult{
  370. ToolCallID: result.ToolCallID,
  371. Name: result.ToolName,
  372. Content: resultContent,
  373. IsError: isError,
  374. Metadata: result.ClientMetadata,
  375. }
  376. _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
  377. Role: message.Tool,
  378. Parts: []message.ContentPart{
  379. toolResult,
  380. },
  381. })
  382. if createMsgErr != nil {
  383. return createMsgErr
  384. }
  385. return nil
  386. },
  387. OnStepFinish: func(stepResult fantasy.StepResult) error {
  388. finishReason := message.FinishReasonUnknown
  389. switch stepResult.FinishReason {
  390. case fantasy.FinishReasonLength:
  391. finishReason = message.FinishReasonMaxTokens
  392. case fantasy.FinishReasonStop:
  393. finishReason = message.FinishReasonEndTurn
  394. case fantasy.FinishReasonToolCalls:
  395. finishReason = message.FinishReasonToolUse
  396. }
  397. currentAssistant.AddFinish(finishReason, "", "")
  398. a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
  399. sessionLock.Lock()
  400. _, sessionErr := a.sessions.Save(genCtx, currentSession)
  401. sessionLock.Unlock()
  402. if sessionErr != nil {
  403. return sessionErr
  404. }
  405. return a.messages.Update(genCtx, *currentAssistant)
  406. },
  407. StopWhen: []fantasy.StopCondition{
  408. func(_ []fantasy.StepResult) bool {
  409. cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
  410. tokens := currentSession.CompletionTokens + currentSession.PromptTokens
  411. remaining := cw - tokens
  412. var threshold int64
  413. if cw > 200_000 {
  414. threshold = 20_000
  415. } else {
  416. threshold = int64(float64(cw) * 0.2)
  417. }
  418. if (remaining <= threshold) && !a.disableAutoSummarize {
  419. shouldSummarize = true
  420. return true
  421. }
  422. return false
  423. },
  424. },
  425. })
  426. a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
  427. if err != nil {
  428. isCancelErr := errors.Is(err, context.Canceled)
  429. isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
  430. isHookCancelled := errors.Is(err, ErrHookCancellation)
  431. if currentAssistant == nil {
  432. return result, err
  433. }
  434. // Ensure we finish thinking on error to close the reasoning state.
  435. currentAssistant.FinishThinking()
  436. toolCalls := currentAssistant.ToolCalls()
  437. // INFO: we use the parent context here because the genCtx has been cancelled.
  438. msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
  439. if createErr != nil {
  440. return nil, createErr
  441. }
  442. for _, tc := range toolCalls {
  443. if !tc.Finished {
  444. tc.Finished = true
  445. tc.Input = "{}"
  446. currentAssistant.AddToolCall(tc)
  447. updateErr := a.messages.Update(ctx, *currentAssistant)
  448. if updateErr != nil {
  449. return nil, updateErr
  450. }
  451. }
  452. found := false
  453. for _, msg := range msgs {
  454. if msg.Role == message.Tool {
  455. for _, tr := range msg.ToolResults() {
  456. if tr.ToolCallID == tc.ID {
  457. found = true
  458. break
  459. }
  460. }
  461. }
  462. if found {
  463. break
  464. }
  465. }
  466. if found {
  467. continue
  468. }
  469. content := "There was an error while executing the tool"
  470. if isCancelErr {
  471. content = "Tool execution canceled by user"
  472. } else if isPermissionErr {
  473. content = "User denied permission"
  474. }
  475. toolResult := message.ToolResult{
  476. ToolCallID: tc.ID,
  477. Name: tc.Name,
  478. Content: content,
  479. IsError: true,
  480. }
  481. _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
  482. Role: message.Tool,
  483. Parts: []message.ContentPart{
  484. toolResult,
  485. },
  486. })
  487. if createErr != nil {
  488. return nil, createErr
  489. }
  490. }
  491. var fantasyErr *fantasy.Error
  492. var providerErr *fantasy.ProviderError
  493. const defaultTitle = "Provider Error"
  494. if isCancelErr {
  495. currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
  496. } else if isPermissionErr {
  497. currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
  498. } else if isHookCancelled {
  499. currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Hook canceled request", "")
  500. } else if errors.As(err, &providerErr) {
  501. currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
  502. } else if errors.As(err, &fantasyErr) {
  503. currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
  504. } else {
  505. currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
  506. }
  507. // Note: we use the parent context here because the genCtx has been
  508. // cancelled.
  509. updateErr := a.messages.Update(ctx, *currentAssistant)
  510. if updateErr != nil {
  511. return nil, updateErr
  512. }
  513. return nil, err
  514. }
  515. wg.Wait()
  516. if shouldSummarize {
  517. a.activeRequests.Del(call.SessionID)
  518. if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
  519. return nil, summarizeErr
  520. }
  521. // If the agent wasn't done...
  522. if len(currentAssistant.ToolCalls()) > 0 {
  523. existing, ok := a.messageQueue.Get(call.SessionID)
  524. if !ok {
  525. existing = []SessionAgentCall{}
  526. }
  527. call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
  528. existing = append(existing, call)
  529. a.messageQueue.Set(call.SessionID, existing)
  530. }
  531. }
  532. // Release active request before processing queued messages.
  533. a.activeRequests.Del(call.SessionID)
  534. cancel()
  535. queuedMessages, ok := a.messageQueue.Get(call.SessionID)
  536. if !ok || len(queuedMessages) == 0 {
  537. return result, err
  538. }
  539. // There are queued messages restart the loop.
  540. firstQueuedMessage := queuedMessages[0]
  541. a.messageQueue.Set(call.SessionID, queuedMessages[1:])
  542. return a.Run(ctx, firstQueuedMessage)
  543. }
  544. func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
  545. if a.IsSessionBusy(sessionID) {
  546. return ErrSessionBusy
  547. }
  548. currentSession, err := a.sessions.Get(ctx, sessionID)
  549. if err != nil {
  550. return fmt.Errorf("failed to get session: %w", err)
  551. }
  552. msgs, err := a.getSessionMessages(ctx, currentSession)
  553. if err != nil {
  554. return err
  555. }
  556. if len(msgs) == 0 {
  557. // Nothing to summarize.
  558. return nil
  559. }
  560. aiMsgs, _ := a.preparePrompt(msgs)
  561. genCtx, cancel := context.WithCancel(ctx)
  562. a.activeRequests.Set(sessionID, cancel)
  563. defer a.activeRequests.Del(sessionID)
  564. defer cancel()
  565. agent := fantasy.NewAgent(a.largeModel.Model,
  566. fantasy.WithSystemPrompt(string(summaryPrompt)),
  567. )
  568. summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  569. Role: message.Assistant,
  570. Model: a.largeModel.Model.Model(),
  571. Provider: a.largeModel.Model.Provider(),
  572. IsSummaryMessage: true,
  573. })
  574. if err != nil {
  575. return err
  576. }
  577. resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  578. Prompt: "Provide a detailed summary of our conversation above.",
  579. Messages: aiMsgs,
  580. ProviderOptions: opts,
  581. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  582. prepared.Messages = options.Messages
  583. if a.systemPromptPrefix != "" {
  584. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
  585. }
  586. return callContext, prepared, nil
  587. },
  588. OnReasoningDelta: func(id string, text string) error {
  589. summaryMessage.AppendReasoningContent(text)
  590. return a.messages.Update(genCtx, summaryMessage)
  591. },
  592. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  593. // Handle anthropic signature.
  594. if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
  595. if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
  596. summaryMessage.AppendReasoningSignature(signature.Signature)
  597. }
  598. }
  599. summaryMessage.FinishThinking()
  600. return a.messages.Update(genCtx, summaryMessage)
  601. },
  602. OnTextDelta: func(id, text string) error {
  603. summaryMessage.AppendContent(text)
  604. return a.messages.Update(genCtx, summaryMessage)
  605. },
  606. })
  607. if err != nil {
  608. isCancelErr := errors.Is(err, context.Canceled)
  609. if isCancelErr {
  610. // User cancelled summarize we need to remove the summary message.
  611. deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
  612. return deleteErr
  613. }
  614. return err
  615. }
  616. summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
  617. err = a.messages.Update(genCtx, summaryMessage)
  618. if err != nil {
  619. return err
  620. }
  621. var openrouterCost *float64
  622. for _, step := range resp.Steps {
  623. stepCost := a.openrouterCost(step.ProviderMetadata)
  624. if stepCost != nil {
  625. newCost := *stepCost
  626. if openrouterCost != nil {
  627. newCost += *openrouterCost
  628. }
  629. openrouterCost = &newCost
  630. }
  631. }
  632. a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
  633. // Just in case, get just the last usage info.
  634. usage := resp.Response.Usage
  635. currentSession.SummaryMessageID = summaryMessage.ID
  636. currentSession.CompletionTokens = usage.OutputTokens
  637. currentSession.PromptTokens = 0
  638. _, err = a.sessions.Save(genCtx, currentSession)
  639. return err
  640. }
  641. func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
  642. if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
  643. return fantasy.ProviderOptions{}
  644. }
  645. return fantasy.ProviderOptions{
  646. anthropic.Name: &anthropic.ProviderCacheControlOptions{
  647. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  648. },
  649. bedrock.Name: &anthropic.ProviderCacheControlOptions{
  650. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  651. },
  652. }
  653. }
  654. func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
  655. var attachmentParts []message.ContentPart
  656. for _, attachment := range call.Attachments {
  657. attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
  658. }
  659. parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
  660. parts = append(parts, attachmentParts...)
  661. msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
  662. Role: message.User,
  663. Parts: parts,
  664. })
  665. if err != nil {
  666. return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
  667. }
  668. return msg, nil
  669. }
  670. func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
  671. var history []fantasy.Message
  672. for _, m := range msgs {
  673. if len(m.Parts) == 0 {
  674. continue
  675. }
  676. // Assistant message without content or tool calls (cancelled before it
  677. // returned anything).
  678. if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
  679. continue
  680. }
  681. history = append(history, m.ToAIMessage()...)
  682. }
  683. var files []fantasy.FilePart
  684. for _, attachment := range attachments {
  685. files = append(files, fantasy.FilePart{
  686. Filename: attachment.FileName,
  687. Data: attachment.Content,
  688. MediaType: attachment.MimeType,
  689. })
  690. }
  691. return history, files
  692. }
  693. func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
  694. msgs, err := a.messages.List(ctx, session.ID)
  695. if err != nil {
  696. return nil, fmt.Errorf("failed to list messages: %w", err)
  697. }
  698. if session.SummaryMessageID != "" {
  699. summaryMsgInex := -1
  700. for i, msg := range msgs {
  701. if msg.ID == session.SummaryMessageID {
  702. summaryMsgInex = i
  703. break
  704. }
  705. }
  706. if summaryMsgInex != -1 {
  707. msgs = msgs[summaryMsgInex:]
  708. msgs[0].Role = message.User
  709. }
  710. }
  711. return msgs, nil
  712. }
  713. func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
  714. if prompt == "" {
  715. return
  716. }
  717. var maxOutput int64 = 40
  718. if a.smallModel.CatwalkCfg.CanReason {
  719. maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
  720. }
  721. agent := fantasy.NewAgent(a.smallModel.Model,
  722. fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
  723. fantasy.WithMaxOutputTokens(maxOutput),
  724. )
  725. resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
  726. Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
  727. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  728. prepared.Messages = options.Messages
  729. if a.systemPromptPrefix != "" {
  730. prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
  731. }
  732. return callContext, prepared, nil
  733. },
  734. })
  735. if err != nil {
  736. slog.Error("error generating title", "err", err)
  737. return
  738. }
  739. title := resp.Response.Content.Text()
  740. title = strings.ReplaceAll(title, "\n", " ")
  741. // Remove thinking tags if present.
  742. if idx := strings.Index(title, "</think>"); idx > 0 {
  743. title = title[idx+len("</think>"):]
  744. }
  745. title = strings.TrimSpace(title)
  746. if title == "" {
  747. slog.Warn("failed to generate title", "warn", "empty title")
  748. return
  749. }
  750. session.Title = title
  751. var openrouterCost *float64
  752. for _, step := range resp.Steps {
  753. stepCost := a.openrouterCost(step.ProviderMetadata)
  754. if stepCost != nil {
  755. newCost := *stepCost
  756. if openrouterCost != nil {
  757. newCost += *openrouterCost
  758. }
  759. openrouterCost = &newCost
  760. }
  761. }
  762. a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
  763. _, saveErr := a.sessions.Save(ctx, *session)
  764. if saveErr != nil {
  765. slog.Error("failed to save session title & usage", "error", saveErr)
  766. return
  767. }
  768. }
  769. func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
  770. openrouterMetadata, ok := metadata[openrouter.Name]
  771. if !ok {
  772. return nil
  773. }
  774. opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
  775. if !ok {
  776. return nil
  777. }
  778. return &opts.Usage.Cost
  779. }
  780. func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
  781. modelConfig := model.CatwalkCfg
  782. cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  783. modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  784. modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
  785. modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
  786. a.eventTokensUsed(session.ID, model, usage, cost)
  787. if overrideCost != nil {
  788. session.Cost += *overrideCost
  789. } else {
  790. session.Cost += cost
  791. }
  792. session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
  793. session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
  794. }
  795. func (a *sessionAgent) Cancel(sessionID string) {
  796. // Cancel regular requests.
  797. if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
  798. slog.Info("Request cancellation initiated", "session_id", sessionID)
  799. cancel()
  800. }
  801. // Also check for summarize requests.
  802. if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
  803. slog.Info("Summarize cancellation initiated", "session_id", sessionID)
  804. cancel()
  805. }
  806. if a.QueuedPrompts(sessionID) > 0 {
  807. slog.Info("Clearing queued prompts", "session_id", sessionID)
  808. a.messageQueue.Del(sessionID)
  809. }
  810. }
  811. func (a *sessionAgent) ClearQueue(sessionID string) {
  812. if a.QueuedPrompts(sessionID) > 0 {
  813. slog.Info("Clearing queued prompts", "session_id", sessionID)
  814. a.messageQueue.Del(sessionID)
  815. }
  816. }
  817. func (a *sessionAgent) CancelAll() {
  818. if !a.IsBusy() {
  819. return
  820. }
  821. for key := range a.activeRequests.Seq2() {
  822. a.Cancel(key) // key is sessionID
  823. }
  824. timeout := time.After(5 * time.Second)
  825. for a.IsBusy() {
  826. select {
  827. case <-timeout:
  828. return
  829. default:
  830. time.Sleep(200 * time.Millisecond)
  831. }
  832. }
  833. }
  834. func (a *sessionAgent) IsBusy() bool {
  835. var busy bool
  836. for cancelFunc := range a.activeRequests.Seq() {
  837. if cancelFunc != nil {
  838. busy = true
  839. break
  840. }
  841. }
  842. return busy
  843. }
  844. func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
  845. _, busy := a.activeRequests.Get(sessionID)
  846. return busy
  847. }
  848. func (a *sessionAgent) QueuedPrompts(sessionID string) int {
  849. l, ok := a.messageQueue.Get(sessionID)
  850. if !ok {
  851. return 0
  852. }
  853. return len(l)
  854. }
  855. func (a *sessionAgent) SetModels(large Model, small Model) {
  856. a.largeModel = large
  857. a.smallModel = small
  858. }
  859. func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
  860. a.tools = tools
  861. }
  862. func (a *sessionAgent) Model() Model {
  863. return a.largeModel
  864. }
  865. func (a *sessionAgent) executeUserPromptSubmitHook(
  866. ctx context.Context,
  867. sessionID string,
  868. prompt string,
  869. ) ([]message.HookOutput, error) {
  870. if a.hooks == nil {
  871. return nil, nil
  872. }
  873. return a.hooks.Execute(ctx, hooks.HookContext{
  874. EventType: config.UserPromptSubmit,
  875. SessionID: sessionID,
  876. UserPrompt: prompt,
  877. Provider: a.largeModel.ModelCfg.Provider,
  878. Model: a.largeModel.ModelCfg.Model,
  879. })
  880. }