agent.go 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742
  1. package agent
  2. import (
  3. "context"
  4. _ "embed"
  5. "errors"
  6. "fmt"
  7. "log/slog"
  8. "strings"
  9. "sync"
  10. "time"
  11. "charm.land/fantasy"
  12. "charm.land/fantasy/providers/anthropic"
  13. "charm.land/fantasy/providers/openai"
  14. "github.com/charmbracelet/catwalk/pkg/catwalk"
  15. "github.com/charmbracelet/crush/internal/agent/tools"
  16. "github.com/charmbracelet/crush/internal/config"
  17. "github.com/charmbracelet/crush/internal/csync"
  18. "github.com/charmbracelet/crush/internal/message"
  19. "github.com/charmbracelet/crush/internal/permission"
  20. "github.com/charmbracelet/crush/internal/session"
  21. )
  22. //go:embed templates/title.md
  23. var titlePrompt []byte
  24. //go:embed templates/summary.md
  25. var summaryPrompt []byte
  26. type SessionAgentCall struct {
  27. SessionID string
  28. Prompt string
  29. ProviderOptions fantasy.ProviderOptions
  30. Attachments []message.Attachment
  31. MaxOutputTokens int64
  32. Temperature *float64
  33. TopP *float64
  34. TopK *int64
  35. FrequencyPenalty *float64
  36. PresencePenalty *float64
  37. }
  38. type SessionAgent interface {
  39. Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
  40. SetModels(large Model, small Model)
  41. SetTools(tools []fantasy.AgentTool)
  42. Cancel(sessionID string)
  43. CancelAll()
  44. IsSessionBusy(sessionID string) bool
  45. IsBusy() bool
  46. QueuedPrompts(sessionID string) int
  47. ClearQueue(sessionID string)
  48. Summarize(context.Context, string) error
  49. Model() Model
  50. }
  51. type Model struct {
  52. Model fantasy.LanguageModel
  53. CatwalkCfg catwalk.Model
  54. ModelCfg config.SelectedModel
  55. }
  56. type sessionAgent struct {
  57. largeModel Model
  58. smallModel Model
  59. systemPrompt string
  60. tools []fantasy.AgentTool
  61. sessions session.Service
  62. messages message.Service
  63. disableAutoSummarize bool
  64. messageQueue *csync.Map[string, []SessionAgentCall]
  65. activeRequests *csync.Map[string, context.CancelFunc]
  66. }
  67. type SessionAgentOptions struct {
  68. LargeModel Model
  69. SmallModel Model
  70. SystemPrompt string
  71. DisableAutoSummarize bool
  72. Sessions session.Service
  73. Messages message.Service
  74. Tools []fantasy.AgentTool
  75. }
  76. func NewSessionAgent(
  77. opts SessionAgentOptions,
  78. ) SessionAgent {
  79. return &sessionAgent{
  80. largeModel: opts.LargeModel,
  81. smallModel: opts.SmallModel,
  82. systemPrompt: opts.SystemPrompt,
  83. sessions: opts.Sessions,
  84. messages: opts.Messages,
  85. disableAutoSummarize: opts.DisableAutoSummarize,
  86. tools: opts.Tools,
  87. messageQueue: csync.NewMap[string, []SessionAgentCall](),
  88. activeRequests: csync.NewMap[string, context.CancelFunc](),
  89. }
  90. }
  91. func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
  92. if call.Prompt == "" {
  93. return nil, ErrEmptyPrompt
  94. }
  95. if call.SessionID == "" {
  96. return nil, ErrSessionMissing
  97. }
  98. // Queue the message if busy
  99. if a.IsSessionBusy(call.SessionID) {
  100. existing, ok := a.messageQueue.Get(call.SessionID)
  101. if !ok {
  102. existing = []SessionAgentCall{}
  103. }
  104. existing = append(existing, call)
  105. a.messageQueue.Set(call.SessionID, existing)
  106. return nil, nil
  107. }
  108. if len(a.tools) > 0 {
  109. // add anthropic caching to the last tool
  110. a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
  111. }
  112. agent := fantasy.NewAgent(
  113. a.largeModel.Model,
  114. fantasy.WithSystemPrompt(a.systemPrompt),
  115. fantasy.WithTools(a.tools...),
  116. )
  117. sessionLock := sync.Mutex{}
  118. currentSession, err := a.sessions.Get(ctx, call.SessionID)
  119. if err != nil {
  120. return nil, fmt.Errorf("failed to get session: %w", err)
  121. }
  122. msgs, err := a.getSessionMessages(ctx, currentSession)
  123. if err != nil {
  124. return nil, fmt.Errorf("failed to get session messages: %w", err)
  125. }
  126. var wg sync.WaitGroup
  127. // Generate title if first message
  128. if len(msgs) == 0 {
  129. wg.Go(func() {
  130. sessionLock.Lock()
  131. a.generateTitle(ctx, &currentSession, call.Prompt)
  132. sessionLock.Unlock()
  133. })
  134. }
  135. // Add the user message to the session
  136. _, err = a.createUserMessage(ctx, call)
  137. if err != nil {
  138. return nil, err
  139. }
  140. // add the session to the context
  141. ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
  142. genCtx, cancel := context.WithCancel(ctx)
  143. a.activeRequests.Set(call.SessionID, cancel)
  144. defer cancel()
  145. defer a.activeRequests.Del(call.SessionID)
  146. history, files := a.preparePrompt(msgs, call.Attachments...)
  147. var currentAssistant *message.Message
  148. var shouldSummarize bool
  149. result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  150. Prompt: call.Prompt,
  151. Files: files,
  152. Messages: history,
  153. ProviderOptions: call.ProviderOptions,
  154. MaxOutputTokens: &call.MaxOutputTokens,
  155. TopP: call.TopP,
  156. Temperature: call.Temperature,
  157. PresencePenalty: call.PresencePenalty,
  158. TopK: call.TopK,
  159. FrequencyPenalty: call.FrequencyPenalty,
  160. // Before each step create the new assistant message
  161. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
  162. var assistantMsg message.Message
  163. assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
  164. Role: message.Assistant,
  165. Parts: []message.ContentPart{},
  166. Model: a.largeModel.ModelCfg.Model,
  167. Provider: a.largeModel.ModelCfg.Provider,
  168. })
  169. if err != nil {
  170. return callContext, prepared, err
  171. }
  172. callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
  173. currentAssistant = &assistantMsg
  174. prepared.Messages = options.Messages
  175. // reset all cached items
  176. for i := range prepared.Messages {
  177. prepared.Messages[i].ProviderOptions = nil
  178. }
  179. queuedCalls, _ := a.messageQueue.Get(call.SessionID)
  180. a.messageQueue.Del(call.SessionID)
  181. for _, queued := range queuedCalls {
  182. userMessage, createErr := a.createUserMessage(callContext, queued)
  183. if createErr != nil {
  184. return callContext, prepared, createErr
  185. }
  186. prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
  187. }
  188. lastSystemRoleInx := 0
  189. systemMessageUpdated := false
  190. for i, msg := range prepared.Messages {
  191. // only add cache control to the last message
  192. if msg.Role == fantasy.MessageRoleSystem {
  193. lastSystemRoleInx = i
  194. } else if !systemMessageUpdated {
  195. prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
  196. systemMessageUpdated = true
  197. }
  198. // than add cache control to the last 2 messages
  199. if i > len(prepared.Messages)-3 {
  200. prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
  201. }
  202. }
  203. return callContext, prepared, err
  204. },
  205. OnReasoningDelta: func(id string, text string) error {
  206. currentAssistant.AppendReasoningContent(text)
  207. return a.messages.Update(genCtx, *currentAssistant)
  208. },
  209. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  210. // handle anthropic signature
  211. if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
  212. if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
  213. currentAssistant.AppendReasoningSignature(reasoning.Signature)
  214. }
  215. }
  216. if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
  217. if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
  218. currentAssistant.SetReasoningResponsesData(reasoning)
  219. }
  220. }
  221. currentAssistant.FinishThinking()
  222. return a.messages.Update(genCtx, *currentAssistant)
  223. },
  224. OnTextDelta: func(id string, text string) error {
  225. currentAssistant.AppendContent(text)
  226. return a.messages.Update(genCtx, *currentAssistant)
  227. },
  228. OnToolInputStart: func(id string, toolName string) error {
  229. toolCall := message.ToolCall{
  230. ID: id,
  231. Name: toolName,
  232. ProviderExecuted: false,
  233. Finished: false,
  234. }
  235. currentAssistant.AddToolCall(toolCall)
  236. return a.messages.Update(genCtx, *currentAssistant)
  237. },
  238. OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
  239. // TODO: implement
  240. },
  241. OnToolCall: func(tc fantasy.ToolCallContent) error {
  242. toolCall := message.ToolCall{
  243. ID: tc.ToolCallID,
  244. Name: tc.ToolName,
  245. Input: tc.Input,
  246. ProviderExecuted: false,
  247. Finished: true,
  248. }
  249. currentAssistant.AddToolCall(toolCall)
  250. return a.messages.Update(genCtx, *currentAssistant)
  251. },
  252. OnToolResult: func(result fantasy.ToolResultContent) error {
  253. var resultContent string
  254. isError := false
  255. switch result.Result.GetType() {
  256. case fantasy.ToolResultContentTypeText:
  257. r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
  258. if ok {
  259. resultContent = r.Text
  260. }
  261. case fantasy.ToolResultContentTypeError:
  262. r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
  263. if ok {
  264. isError = true
  265. resultContent = r.Error.Error()
  266. }
  267. case fantasy.ToolResultContentTypeMedia:
  268. // TODO: handle this message type
  269. }
  270. toolResult := message.ToolResult{
  271. ToolCallID: result.ToolCallID,
  272. Name: result.ToolName,
  273. Content: resultContent,
  274. IsError: isError,
  275. Metadata: result.ClientMetadata,
  276. }
  277. _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
  278. Role: message.Tool,
  279. Parts: []message.ContentPart{
  280. toolResult,
  281. },
  282. })
  283. if createMsgErr != nil {
  284. return createMsgErr
  285. }
  286. return nil
  287. },
  288. OnStepFinish: func(stepResult fantasy.StepResult) error {
  289. finishReason := message.FinishReasonUnknown
  290. switch stepResult.FinishReason {
  291. case fantasy.FinishReasonLength:
  292. finishReason = message.FinishReasonMaxTokens
  293. case fantasy.FinishReasonStop:
  294. finishReason = message.FinishReasonEndTurn
  295. case fantasy.FinishReasonToolCalls:
  296. finishReason = message.FinishReasonToolUse
  297. }
  298. currentAssistant.AddFinish(finishReason, "", "")
  299. a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
  300. sessionLock.Lock()
  301. _, sessionErr := a.sessions.Save(genCtx, currentSession)
  302. sessionLock.Unlock()
  303. if sessionErr != nil {
  304. return sessionErr
  305. }
  306. return a.messages.Update(genCtx, *currentAssistant)
  307. },
  308. StopWhen: []fantasy.StopCondition{
  309. func(_ []fantasy.StepResult) bool {
  310. contextWindow := a.largeModel.CatwalkCfg.ContextWindow
  311. tokens := currentSession.CompletionTokens + currentSession.PromptTokens
  312. percentage := (float64(tokens) / float64(contextWindow)) * 100
  313. if (percentage > 80) && !a.disableAutoSummarize {
  314. shouldSummarize = true
  315. return true
  316. }
  317. return false
  318. },
  319. },
  320. })
  321. if err != nil {
  322. isCancelErr := errors.Is(err, context.Canceled)
  323. isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
  324. if currentAssistant == nil {
  325. return result, err
  326. }
  327. toolCalls := currentAssistant.ToolCalls()
  328. // INFO: we use the parent context here because the genCtx has been cancelled
  329. msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
  330. if createErr != nil {
  331. return nil, createErr
  332. }
  333. for _, tc := range toolCalls {
  334. if !tc.Finished {
  335. tc.Finished = true
  336. tc.Input = "{}"
  337. currentAssistant.AddToolCall(tc)
  338. updateErr := a.messages.Update(ctx, *currentAssistant)
  339. if updateErr != nil {
  340. return nil, updateErr
  341. }
  342. }
  343. found := false
  344. for _, msg := range msgs {
  345. if msg.Role == message.Tool {
  346. for _, tr := range msg.ToolResults() {
  347. if tr.ToolCallID == tc.ID {
  348. found = true
  349. break
  350. }
  351. }
  352. }
  353. if found {
  354. break
  355. }
  356. }
  357. if found {
  358. continue
  359. }
  360. content := "There was an error while executing the tool"
  361. if isCancelErr {
  362. content = "Tool execution canceled by user"
  363. } else if isPermissionErr {
  364. content = "Permission denied"
  365. }
  366. toolResult := message.ToolResult{
  367. ToolCallID: tc.ID,
  368. Name: tc.Name,
  369. Content: content,
  370. IsError: true,
  371. }
  372. _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
  373. Role: message.Tool,
  374. Parts: []message.ContentPart{
  375. toolResult,
  376. },
  377. })
  378. if createErr != nil {
  379. return nil, createErr
  380. }
  381. }
  382. if isCancelErr {
  383. currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
  384. } else if isPermissionErr {
  385. currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
  386. } else {
  387. currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
  388. }
  389. // INFO: we use the parent context here because the genCtx has been cancelled
  390. updateErr := a.messages.Update(ctx, *currentAssistant)
  391. if updateErr != nil {
  392. return nil, updateErr
  393. }
  394. return nil, err
  395. }
  396. wg.Wait()
  397. if shouldSummarize {
  398. a.activeRequests.Del(call.SessionID)
  399. if summarizeErr := a.Summarize(genCtx, call.SessionID); summarizeErr != nil {
  400. return nil, summarizeErr
  401. }
  402. }
  403. queuedMessages, ok := a.messageQueue.Get(call.SessionID)
  404. if !ok || len(queuedMessages) == 0 {
  405. return result, err
  406. }
  407. // there are queued messages restart the loop
  408. firstQueuedMessage := queuedMessages[0]
  409. a.messageQueue.Set(call.SessionID, queuedMessages[1:])
  410. return a.Run(genCtx, firstQueuedMessage)
  411. }
  412. func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
  413. if a.IsSessionBusy(sessionID) {
  414. return ErrSessionBusy
  415. }
  416. currentSession, err := a.sessions.Get(ctx, sessionID)
  417. if err != nil {
  418. return fmt.Errorf("failed to get session: %w", err)
  419. }
  420. msgs, err := a.getSessionMessages(ctx, currentSession)
  421. if err != nil {
  422. return err
  423. }
  424. if len(msgs) == 0 {
  425. // nothing to summarize
  426. return nil
  427. }
  428. aiMsgs, _ := a.preparePrompt(msgs)
  429. genCtx, cancel := context.WithCancel(ctx)
  430. a.activeRequests.Set(sessionID, cancel)
  431. defer a.activeRequests.Del(sessionID)
  432. defer cancel()
  433. agent := fantasy.NewAgent(a.largeModel.Model,
  434. fantasy.WithSystemPrompt(string(summaryPrompt)),
  435. )
  436. summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  437. Role: message.Assistant,
  438. Model: a.largeModel.Model.Model(),
  439. Provider: a.largeModel.Model.Provider(),
  440. IsSummaryMessage: true,
  441. })
  442. if err != nil {
  443. return err
  444. }
  445. resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
  446. Prompt: "Provide a detailed summary of our conversation above.",
  447. Messages: aiMsgs,
  448. OnReasoningDelta: func(id string, text string) error {
  449. summaryMessage.AppendReasoningContent(text)
  450. return a.messages.Update(genCtx, summaryMessage)
  451. },
  452. OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
  453. // handle anthropic signature
  454. if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
  455. if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
  456. summaryMessage.AppendReasoningSignature(signature.Signature)
  457. }
  458. }
  459. summaryMessage.FinishThinking()
  460. return a.messages.Update(genCtx, summaryMessage)
  461. },
  462. OnTextDelta: func(id, text string) error {
  463. summaryMessage.AppendContent(text)
  464. return a.messages.Update(genCtx, summaryMessage)
  465. },
  466. })
  467. if err != nil {
  468. isCancelErr := errors.Is(err, context.Canceled)
  469. if isCancelErr {
  470. // User cancelled summarize we need to remove the summary message
  471. deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
  472. return deleteErr
  473. }
  474. return err
  475. }
  476. summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
  477. err = a.messages.Update(genCtx, summaryMessage)
  478. if err != nil {
  479. return err
  480. }
  481. a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)
  482. // just in case get just the last usage
  483. usage := resp.Response.Usage
  484. currentSession.SummaryMessageID = summaryMessage.ID
  485. currentSession.CompletionTokens = usage.OutputTokens
  486. currentSession.PromptTokens = 0
  487. _, err = a.sessions.Save(genCtx, currentSession)
  488. return err
  489. }
  490. func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
  491. return fantasy.ProviderOptions{
  492. anthropic.Name: &anthropic.ProviderCacheControlOptions{
  493. CacheControl: anthropic.CacheControl{Type: "ephemeral"},
  494. },
  495. }
  496. }
  497. func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
  498. var attachmentParts []message.ContentPart
  499. for _, attachment := range call.Attachments {
  500. attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
  501. }
  502. parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
  503. parts = append(parts, attachmentParts...)
  504. msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
  505. Role: message.User,
  506. Parts: parts,
  507. })
  508. if err != nil {
  509. return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
  510. }
  511. return msg, nil
  512. }
  513. func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
  514. var history []fantasy.Message
  515. for _, m := range msgs {
  516. if len(m.Parts) == 0 {
  517. continue
  518. }
  519. // Assistant message without content or tool calls (cancelled before it returned anything)
  520. if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
  521. continue
  522. }
  523. history = append(history, m.ToAIMessage()...)
  524. }
  525. var files []fantasy.FilePart
  526. for _, attachment := range attachments {
  527. files = append(files, fantasy.FilePart{
  528. Filename: attachment.FileName,
  529. Data: attachment.Content,
  530. MediaType: attachment.MimeType,
  531. })
  532. }
  533. return history, files
  534. }
  535. func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
  536. msgs, err := a.messages.List(ctx, session.ID)
  537. if err != nil {
  538. return nil, fmt.Errorf("failed to list messages: %w", err)
  539. }
  540. if session.SummaryMessageID != "" {
  541. summaryMsgInex := -1
  542. for i, msg := range msgs {
  543. if msg.ID == session.SummaryMessageID {
  544. summaryMsgInex = i
  545. break
  546. }
  547. }
  548. if summaryMsgInex != -1 {
  549. msgs = msgs[summaryMsgInex:]
  550. msgs[0].Role = message.User
  551. }
  552. }
  553. return msgs, nil
  554. }
  555. func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
  556. if prompt == "" {
  557. return
  558. }
  559. var maxOutput int64 = 40
  560. if a.smallModel.CatwalkCfg.CanReason {
  561. maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
  562. }
  563. agent := fantasy.NewAgent(a.smallModel.Model,
  564. fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
  565. fantasy.WithMaxOutputTokens(maxOutput),
  566. )
  567. resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
  568. Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
  569. })
  570. if err != nil {
  571. slog.Error("error generating title", "err", err)
  572. return
  573. }
  574. title := resp.Response.Content.Text()
  575. title = strings.ReplaceAll(title, "\n", " ")
  576. // remove thinking tags if present
  577. if idx := strings.Index(title, "</think>"); idx > 0 {
  578. title = title[idx+len("</think>"):]
  579. }
  580. title = strings.TrimSpace(title)
  581. if title == "" {
  582. slog.Warn("failed to generate title", "warn", "empty title")
  583. return
  584. }
  585. session.Title = title
  586. a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
  587. _, saveErr := a.sessions.Save(ctx, *session)
  588. if saveErr != nil {
  589. slog.Error("failed to save session title & usage", "error", saveErr)
  590. return
  591. }
  592. }
  593. func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
  594. modelConfig := model.CatwalkCfg
  595. cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  596. modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  597. modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
  598. modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
  599. session.Cost += cost
  600. session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
  601. session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
  602. }
  603. func (a *sessionAgent) Cancel(sessionID string) {
  604. // Cancel regular requests
  605. if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
  606. slog.Info("Request cancellation initiated", "session_id", sessionID)
  607. cancel()
  608. }
  609. // Also check for summarize requests
  610. if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
  611. slog.Info("Summarize cancellation initiated", "session_id", sessionID)
  612. cancel()
  613. }
  614. if a.QueuedPrompts(sessionID) > 0 {
  615. slog.Info("Clearing queued prompts", "session_id", sessionID)
  616. a.messageQueue.Del(sessionID)
  617. }
  618. }
  619. func (a *sessionAgent) ClearQueue(sessionID string) {
  620. if a.QueuedPrompts(sessionID) > 0 {
  621. slog.Info("Clearing queued prompts", "session_id", sessionID)
  622. a.messageQueue.Del(sessionID)
  623. }
  624. }
  625. func (a *sessionAgent) CancelAll() {
  626. if !a.IsBusy() {
  627. return
  628. }
  629. for key := range a.activeRequests.Seq2() {
  630. a.Cancel(key) // key is sessionID
  631. }
  632. timeout := time.After(5 * time.Second)
  633. for a.IsBusy() {
  634. select {
  635. case <-timeout:
  636. return
  637. default:
  638. time.Sleep(200 * time.Millisecond)
  639. }
  640. }
  641. }
  642. func (a *sessionAgent) IsBusy() bool {
  643. var busy bool
  644. for cancelFunc := range a.activeRequests.Seq() {
  645. if cancelFunc != nil {
  646. busy = true
  647. break
  648. }
  649. }
  650. return busy
  651. }
  652. func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
  653. _, busy := a.activeRequests.Get(sessionID)
  654. return busy
  655. }
  656. func (a *sessionAgent) QueuedPrompts(sessionID string) int {
  657. l, ok := a.messageQueue.Get(sessionID)
  658. if !ok {
  659. return 0
  660. }
  661. return len(l)
  662. }
  663. func (a *sessionAgent) SetModels(large Model, small Model) {
  664. a.largeModel = large
  665. a.smallModel = small
  666. }
  667. func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
  668. a.tools = tools
  669. }
  670. func (a *sessionAgent) Model() Model {
  671. return a.largeModel
  672. }