agent.go 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134
  1. // Package agent contains the implementation of the AI agent service.
  2. package agent
  3. import (
  4. "context"
  5. "errors"
  6. "fmt"
  7. "log/slog"
  8. "maps"
  9. "slices"
  10. "strings"
  11. "time"
  12. "github.com/charmbracelet/catwalk/pkg/catwalk"
  13. "github.com/charmbracelet/crush/internal/config"
  14. "github.com/charmbracelet/crush/internal/csync"
  15. "github.com/charmbracelet/crush/internal/event"
  16. "github.com/charmbracelet/crush/internal/history"
  17. "github.com/charmbracelet/crush/internal/llm/prompt"
  18. "github.com/charmbracelet/crush/internal/llm/provider"
  19. "github.com/charmbracelet/crush/internal/llm/tools"
  20. "github.com/charmbracelet/crush/internal/log"
  21. "github.com/charmbracelet/crush/internal/lsp"
  22. "github.com/charmbracelet/crush/internal/message"
  23. "github.com/charmbracelet/crush/internal/permission"
  24. "github.com/charmbracelet/crush/internal/pubsub"
  25. "github.com/charmbracelet/crush/internal/session"
  26. "github.com/charmbracelet/crush/internal/shell"
  27. )
  28. type AgentEventType string
  29. const (
  30. AgentEventTypeError AgentEventType = "error"
  31. AgentEventTypeResponse AgentEventType = "response"
  32. AgentEventTypeSummarize AgentEventType = "summarize"
  33. )
  34. type AgentEvent struct {
  35. Type AgentEventType
  36. Message message.Message
  37. Error error
  38. // When summarizing
  39. SessionID string
  40. Progress string
  41. Done bool
  42. }
  43. type Service interface {
  44. pubsub.Suscriber[AgentEvent]
  45. Model() catwalk.Model
  46. Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  47. Cancel(sessionID string)
  48. CancelAll()
  49. IsSessionBusy(sessionID string) bool
  50. IsBusy() bool
  51. Summarize(ctx context.Context, sessionID string) error
  52. UpdateModel() error
  53. QueuedPrompts(sessionID string) int
  54. ClearQueue(sessionID string)
  55. }
  56. type agent struct {
  57. *pubsub.Broker[AgentEvent]
  58. agentCfg config.Agent
  59. sessions session.Service
  60. messages message.Service
  61. permissions permission.Service
  62. baseTools *csync.Map[string, tools.BaseTool]
  63. mcpTools *csync.Map[string, tools.BaseTool]
  64. lspClients *csync.Map[string, *lsp.Client]
  65. // We need this to be able to update it when model changes
  66. agentToolFn func() (tools.BaseTool, error)
  67. cleanupFuncs []func()
  68. provider provider.Provider
  69. providerID string
  70. titleProvider provider.Provider
  71. summarizeProvider provider.Provider
  72. summarizeProviderID string
  73. activeRequests *csync.Map[string, context.CancelFunc]
  74. promptQueue *csync.Map[string, []string]
  75. }
  76. var agentPromptMap = map[string]prompt.PromptID{
  77. "coder": prompt.PromptCoder,
  78. "task": prompt.PromptTask,
  79. }
  80. func NewAgent(
  81. ctx context.Context,
  82. agentCfg config.Agent,
  83. // These services are needed in the tools
  84. permissions permission.Service,
  85. sessions session.Service,
  86. messages message.Service,
  87. history history.Service,
  88. lspClients *csync.Map[string, *lsp.Client],
  89. ) (Service, error) {
  90. cfg := config.Get()
  91. var agentToolFn func() (tools.BaseTool, error)
  92. if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
  93. agentToolFn = func() (tools.BaseTool, error) {
  94. taskAgentCfg := config.Get().Agents["task"]
  95. if taskAgentCfg.ID == "" {
  96. return nil, fmt.Errorf("task agent not found in config")
  97. }
  98. taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
  99. if err != nil {
  100. return nil, fmt.Errorf("failed to create task agent: %w", err)
  101. }
  102. return NewAgentTool(taskAgent, sessions, messages), nil
  103. }
  104. }
  105. providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
  106. if providerCfg == nil {
  107. return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
  108. }
  109. model := config.Get().GetModelByType(agentCfg.Model)
  110. if model == nil {
  111. return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
  112. }
  113. promptID := agentPromptMap[agentCfg.ID]
  114. if promptID == "" {
  115. promptID = prompt.PromptDefault
  116. }
  117. opts := []provider.ProviderClientOption{
  118. provider.WithModel(agentCfg.Model),
  119. provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
  120. }
  121. agentProvider, err := provider.NewProvider(*providerCfg, opts...)
  122. if err != nil {
  123. return nil, err
  124. }
  125. smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
  126. var smallModelProviderCfg *config.ProviderConfig
  127. if smallModelCfg.Provider == providerCfg.ID {
  128. smallModelProviderCfg = providerCfg
  129. } else {
  130. smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
  131. if smallModelProviderCfg.ID == "" {
  132. return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
  133. }
  134. }
  135. smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
  136. if smallModel.ID == "" {
  137. return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
  138. }
  139. titleOpts := []provider.ProviderClientOption{
  140. provider.WithModel(config.SelectedModelTypeSmall),
  141. provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
  142. }
  143. titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
  144. if err != nil {
  145. return nil, err
  146. }
  147. summarizeOpts := []provider.ProviderClientOption{
  148. provider.WithModel(config.SelectedModelTypeLarge),
  149. provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
  150. }
  151. summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
  152. if err != nil {
  153. return nil, err
  154. }
  155. baseToolsFn := func() map[string]tools.BaseTool {
  156. slog.Debug("Initializing agent base tools", "agent", agentCfg.ID)
  157. defer func() {
  158. slog.Debug("Initialized agent base tools", "agent", agentCfg.ID)
  159. }()
  160. // Base tools available to all agents
  161. cwd := cfg.WorkingDir()
  162. result := make(map[string]tools.BaseTool)
  163. for _, tool := range []tools.BaseTool{
  164. tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
  165. tools.NewDownloadTool(permissions, cwd),
  166. tools.NewEditTool(lspClients, permissions, history, cwd),
  167. tools.NewMultiEditTool(lspClients, permissions, history, cwd),
  168. tools.NewFetchTool(permissions, cwd),
  169. tools.NewGlobTool(cwd),
  170. tools.NewGrepTool(cwd),
  171. tools.NewLsTool(permissions, cwd),
  172. tools.NewSourcegraphTool(),
  173. tools.NewViewTool(lspClients, permissions, cwd),
  174. tools.NewWriteTool(lspClients, permissions, history, cwd),
  175. } {
  176. result[tool.Name()] = tool
  177. }
  178. return result
  179. }
  180. mcpToolsFn := func() map[string]tools.BaseTool {
  181. slog.Debug("Initializing agent mcp tools", "agent", agentCfg.ID)
  182. defer func() {
  183. slog.Debug("Initialized agent mcp tools", "agent", agentCfg.ID)
  184. }()
  185. mcpToolsOnce.Do(func() {
  186. doGetMCPTools(ctx, permissions, cfg)
  187. })
  188. return maps.Collect(mcpTools.Seq2())
  189. }
  190. a := &agent{
  191. Broker: pubsub.NewBroker[AgentEvent](),
  192. agentCfg: agentCfg,
  193. provider: agentProvider,
  194. providerID: string(providerCfg.ID),
  195. messages: messages,
  196. sessions: sessions,
  197. titleProvider: titleProvider,
  198. summarizeProvider: summarizeProvider,
  199. summarizeProviderID: string(providerCfg.ID),
  200. agentToolFn: agentToolFn,
  201. activeRequests: csync.NewMap[string, context.CancelFunc](),
  202. mcpTools: csync.NewLazyMap(mcpToolsFn),
  203. baseTools: csync.NewLazyMap(baseToolsFn),
  204. promptQueue: csync.NewMap[string, []string](),
  205. permissions: permissions,
  206. lspClients: lspClients,
  207. }
  208. a.setupEvents(ctx)
  209. return a, nil
  210. }
  211. func (a *agent) Model() catwalk.Model {
  212. return *config.Get().GetModelByType(a.agentCfg.Model)
  213. }
  214. func (a *agent) Cancel(sessionID string) {
  215. // Cancel regular requests
  216. if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
  217. slog.Info("Request cancellation initiated", "session_id", sessionID)
  218. cancel()
  219. }
  220. // Also check for summarize requests
  221. if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
  222. slog.Info("Summarize cancellation initiated", "session_id", sessionID)
  223. cancel()
  224. }
  225. if a.QueuedPrompts(sessionID) > 0 {
  226. slog.Info("Clearing queued prompts", "session_id", sessionID)
  227. a.promptQueue.Del(sessionID)
  228. }
  229. }
  230. func (a *agent) IsBusy() bool {
  231. var busy bool
  232. for cancelFunc := range a.activeRequests.Seq() {
  233. if cancelFunc != nil {
  234. busy = true
  235. break
  236. }
  237. }
  238. return busy
  239. }
  240. func (a *agent) IsSessionBusy(sessionID string) bool {
  241. _, busy := a.activeRequests.Get(sessionID)
  242. return busy
  243. }
  244. func (a *agent) QueuedPrompts(sessionID string) int {
  245. l, ok := a.promptQueue.Get(sessionID)
  246. if !ok {
  247. return 0
  248. }
  249. return len(l)
  250. }
  251. func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
  252. if content == "" {
  253. return nil
  254. }
  255. if a.titleProvider == nil {
  256. return nil
  257. }
  258. session, err := a.sessions.Get(ctx, sessionID)
  259. if err != nil {
  260. return err
  261. }
  262. parts := []message.ContentPart{message.TextContent{
  263. Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
  264. }}
  265. // Use streaming approach like summarization
  266. response := a.titleProvider.StreamResponse(
  267. ctx,
  268. []message.Message{
  269. {
  270. Role: message.User,
  271. Parts: parts,
  272. },
  273. },
  274. nil,
  275. )
  276. var finalResponse *provider.ProviderResponse
  277. for r := range response {
  278. if r.Error != nil {
  279. return r.Error
  280. }
  281. finalResponse = r.Response
  282. }
  283. if finalResponse == nil {
  284. return fmt.Errorf("no response received from title provider")
  285. }
  286. title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
  287. if idx := strings.Index(title, "</think>"); idx > 0 {
  288. title = title[idx+len("</think>"):]
  289. }
  290. title = strings.TrimSpace(title)
  291. if title == "" {
  292. return nil
  293. }
  294. session.Title = title
  295. _, err = a.sessions.Save(ctx, session)
  296. return err
  297. }
  298. func (a *agent) err(err error) AgentEvent {
  299. return AgentEvent{
  300. Type: AgentEventTypeError,
  301. Error: err,
  302. }
  303. }
  304. func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
  305. if !a.Model().SupportsImages && attachments != nil {
  306. attachments = nil
  307. }
  308. events := make(chan AgentEvent, 1)
  309. if a.IsSessionBusy(sessionID) {
  310. existing, ok := a.promptQueue.Get(sessionID)
  311. if !ok {
  312. existing = []string{}
  313. }
  314. existing = append(existing, content)
  315. a.promptQueue.Set(sessionID, existing)
  316. return nil, nil
  317. }
  318. genCtx, cancel := context.WithCancel(ctx)
  319. a.activeRequests.Set(sessionID, cancel)
  320. startTime := time.Now()
  321. go func() {
  322. slog.Debug("Request started", "sessionID", sessionID)
  323. defer log.RecoverPanic("agent.Run", func() {
  324. events <- a.err(fmt.Errorf("panic while running the agent"))
  325. })
  326. var attachmentParts []message.ContentPart
  327. for _, attachment := range attachments {
  328. attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
  329. }
  330. result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
  331. if result.Error != nil {
  332. if isCancelledErr(result.Error) {
  333. slog.Error("Request canceled", "sessionID", sessionID)
  334. } else {
  335. slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
  336. event.Error(result.Error)
  337. }
  338. } else {
  339. slog.Debug("Request completed", "sessionID", sessionID)
  340. }
  341. a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
  342. a.activeRequests.Del(sessionID)
  343. cancel()
  344. a.Publish(pubsub.CreatedEvent, result)
  345. events <- result
  346. close(events)
  347. }()
  348. a.eventPromptSent(sessionID)
  349. return events, nil
  350. }
  351. func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
  352. cfg := config.Get()
  353. // List existing messages; if none, start title generation asynchronously.
  354. msgs, err := a.messages.List(ctx, sessionID)
  355. if err != nil {
  356. return a.err(fmt.Errorf("failed to list messages: %w", err))
  357. }
  358. if len(msgs) == 0 {
  359. go func() {
  360. defer log.RecoverPanic("agent.Run", func() {
  361. slog.Error("panic while generating title")
  362. })
  363. titleErr := a.generateTitle(ctx, sessionID, content)
  364. if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
  365. slog.Error("failed to generate title", "error", titleErr)
  366. }
  367. }()
  368. }
  369. session, err := a.sessions.Get(ctx, sessionID)
  370. if err != nil {
  371. return a.err(fmt.Errorf("failed to get session: %w", err))
  372. }
  373. if session.SummaryMessageID != "" {
  374. summaryMsgInex := -1
  375. for i, msg := range msgs {
  376. if msg.ID == session.SummaryMessageID {
  377. summaryMsgInex = i
  378. break
  379. }
  380. }
  381. if summaryMsgInex != -1 {
  382. msgs = msgs[summaryMsgInex:]
  383. msgs[0].Role = message.User
  384. }
  385. }
  386. userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
  387. if err != nil {
  388. return a.err(fmt.Errorf("failed to create user message: %w", err))
  389. }
  390. // Append the new user message to the conversation history.
  391. msgHistory := append(msgs, userMsg)
  392. for {
  393. // Check for cancellation before each iteration
  394. select {
  395. case <-ctx.Done():
  396. return a.err(ctx.Err())
  397. default:
  398. // Continue processing
  399. }
  400. agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
  401. if err != nil {
  402. if errors.Is(err, context.Canceled) {
  403. agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
  404. a.messages.Update(context.Background(), agentMessage)
  405. return a.err(ErrRequestCancelled)
  406. }
  407. return a.err(fmt.Errorf("failed to process events: %w", err))
  408. }
  409. if cfg.Options.Debug {
  410. slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
  411. }
  412. if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
  413. // We are not done, we need to respond with the tool response
  414. msgHistory = append(msgHistory, agentMessage, *toolResults)
  415. // If there are queued prompts, process the next one
  416. nextPrompt, ok := a.promptQueue.Take(sessionID)
  417. if ok {
  418. for _, prompt := range nextPrompt {
  419. // Create a new user message for the queued prompt
  420. userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
  421. if err != nil {
  422. return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
  423. }
  424. // Append the new user message to the conversation history
  425. msgHistory = append(msgHistory, userMsg)
  426. }
  427. }
  428. continue
  429. } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
  430. queuePrompts, ok := a.promptQueue.Take(sessionID)
  431. if ok {
  432. for _, prompt := range queuePrompts {
  433. if prompt == "" {
  434. continue
  435. }
  436. userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
  437. if err != nil {
  438. return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
  439. }
  440. msgHistory = append(msgHistory, userMsg)
  441. }
  442. continue
  443. }
  444. }
  445. if agentMessage.FinishReason() == "" {
  446. // Kujtim: could not track down where this is happening but this means its cancelled
  447. agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
  448. _ = a.messages.Update(context.Background(), agentMessage)
  449. return a.err(ErrRequestCancelled)
  450. }
  451. return AgentEvent{
  452. Type: AgentEventTypeResponse,
  453. Message: agentMessage,
  454. Done: true,
  455. }
  456. }
  457. }
  458. func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
  459. parts := []message.ContentPart{message.TextContent{Text: content}}
  460. parts = append(parts, attachmentParts...)
  461. return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  462. Role: message.User,
  463. Parts: parts,
  464. })
  465. }
  466. func (a *agent) getAllTools() ([]tools.BaseTool, error) {
  467. var allTools []tools.BaseTool
  468. for tool := range a.baseTools.Seq() {
  469. if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
  470. allTools = append(allTools, tool)
  471. }
  472. }
  473. if a.agentCfg.ID == "coder" {
  474. allTools = slices.AppendSeq(allTools, a.mcpTools.Seq())
  475. if a.lspClients.Len() > 0 {
  476. allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients), tools.NewReferencesTool(a.lspClients))
  477. }
  478. }
  479. if a.agentToolFn != nil {
  480. agentTool, agentToolErr := a.agentToolFn()
  481. if agentToolErr != nil {
  482. return nil, agentToolErr
  483. }
  484. allTools = append(allTools, agentTool)
  485. }
  486. return allTools, nil
  487. }
  488. func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
  489. ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
  490. // Create the assistant message first so the spinner shows immediately
  491. assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  492. Role: message.Assistant,
  493. Parts: []message.ContentPart{},
  494. Model: a.Model().ID,
  495. Provider: a.providerID,
  496. })
  497. if err != nil {
  498. return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
  499. }
  500. allTools, toolsErr := a.getAllTools()
  501. if toolsErr != nil {
  502. return assistantMsg, nil, toolsErr
  503. }
  504. // Now collect tools (which may block on MCP initialization)
  505. eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
  506. // Add the session and message ID into the context if needed by tools.
  507. ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
  508. loop:
  509. for {
  510. select {
  511. case event, ok := <-eventChan:
  512. if !ok {
  513. break loop
  514. }
  515. if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
  516. if errors.Is(processErr, context.Canceled) {
  517. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
  518. } else {
  519. a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
  520. }
  521. return assistantMsg, nil, processErr
  522. }
  523. case <-ctx.Done():
  524. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
  525. return assistantMsg, nil, ctx.Err()
  526. }
  527. }
  528. toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
  529. toolCalls := assistantMsg.ToolCalls()
  530. for i, toolCall := range toolCalls {
  531. select {
  532. case <-ctx.Done():
  533. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
  534. // Make all future tool calls cancelled
  535. for j := i; j < len(toolCalls); j++ {
  536. toolResults[j] = message.ToolResult{
  537. ToolCallID: toolCalls[j].ID,
  538. Content: "Tool execution canceled by user",
  539. IsError: true,
  540. }
  541. }
  542. goto out
  543. default:
  544. // Continue processing
  545. var tool tools.BaseTool
  546. allTools, _ = a.getAllTools()
  547. for _, availableTool := range allTools {
  548. if availableTool.Info().Name == toolCall.Name {
  549. tool = availableTool
  550. break
  551. }
  552. }
  553. // Tool not found
  554. if tool == nil {
  555. toolResults[i] = message.ToolResult{
  556. ToolCallID: toolCall.ID,
  557. Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
  558. IsError: true,
  559. }
  560. continue
  561. }
  562. // Run tool in goroutine to allow cancellation
  563. type toolExecResult struct {
  564. response tools.ToolResponse
  565. err error
  566. }
  567. resultChan := make(chan toolExecResult, 1)
  568. go func() {
  569. response, err := tool.Run(ctx, tools.ToolCall{
  570. ID: toolCall.ID,
  571. Name: toolCall.Name,
  572. Input: toolCall.Input,
  573. })
  574. resultChan <- toolExecResult{response: response, err: err}
  575. }()
  576. var toolResponse tools.ToolResponse
  577. var toolErr error
  578. select {
  579. case <-ctx.Done():
  580. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
  581. // Mark remaining tool calls as cancelled
  582. for j := i; j < len(toolCalls); j++ {
  583. toolResults[j] = message.ToolResult{
  584. ToolCallID: toolCalls[j].ID,
  585. Content: "Tool execution canceled by user",
  586. IsError: true,
  587. }
  588. }
  589. goto out
  590. case result := <-resultChan:
  591. toolResponse = result.response
  592. toolErr = result.err
  593. }
  594. if toolErr != nil {
  595. slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
  596. if errors.Is(toolErr, permission.ErrorPermissionDenied) {
  597. toolResults[i] = message.ToolResult{
  598. ToolCallID: toolCall.ID,
  599. Content: "Permission denied",
  600. IsError: true,
  601. }
  602. for j := i + 1; j < len(toolCalls); j++ {
  603. toolResults[j] = message.ToolResult{
  604. ToolCallID: toolCalls[j].ID,
  605. Content: "Tool execution canceled by user",
  606. IsError: true,
  607. }
  608. }
  609. a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
  610. break
  611. }
  612. }
  613. toolResults[i] = message.ToolResult{
  614. ToolCallID: toolCall.ID,
  615. Content: toolResponse.Content,
  616. Metadata: toolResponse.Metadata,
  617. IsError: toolResponse.IsError,
  618. }
  619. }
  620. }
  621. out:
  622. if len(toolResults) == 0 {
  623. return assistantMsg, nil, nil
  624. }
  625. parts := make([]message.ContentPart, 0)
  626. for _, tr := range toolResults {
  627. parts = append(parts, tr)
  628. }
  629. msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
  630. Role: message.Tool,
  631. Parts: parts,
  632. Provider: a.providerID,
  633. })
  634. if err != nil {
  635. return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
  636. }
  637. return assistantMsg, &msg, err
  638. }
  639. func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
  640. msg.AddFinish(finishReason, message, details)
  641. _ = a.messages.Update(ctx, *msg)
  642. }
  643. func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
  644. select {
  645. case <-ctx.Done():
  646. return ctx.Err()
  647. default:
  648. // Continue processing.
  649. }
  650. switch event.Type {
  651. case provider.EventThinkingDelta:
  652. assistantMsg.AppendReasoningContent(event.Thinking)
  653. return a.messages.Update(ctx, *assistantMsg)
  654. case provider.EventSignatureDelta:
  655. assistantMsg.AppendReasoningSignature(event.Signature)
  656. return a.messages.Update(ctx, *assistantMsg)
  657. case provider.EventContentDelta:
  658. assistantMsg.FinishThinking()
  659. assistantMsg.AppendContent(event.Content)
  660. return a.messages.Update(ctx, *assistantMsg)
  661. case provider.EventToolUseStart:
  662. assistantMsg.FinishThinking()
  663. slog.Info("Tool call started", "toolCall", event.ToolCall)
  664. assistantMsg.AddToolCall(*event.ToolCall)
  665. return a.messages.Update(ctx, *assistantMsg)
  666. case provider.EventToolUseDelta:
  667. assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
  668. return a.messages.Update(ctx, *assistantMsg)
  669. case provider.EventToolUseStop:
  670. slog.Info("Finished tool call", "toolCall", event.ToolCall)
  671. assistantMsg.FinishToolCall(event.ToolCall.ID)
  672. return a.messages.Update(ctx, *assistantMsg)
  673. case provider.EventError:
  674. return event.Error
  675. case provider.EventComplete:
  676. assistantMsg.FinishThinking()
  677. assistantMsg.SetToolCalls(event.Response.ToolCalls)
  678. assistantMsg.AddFinish(event.Response.FinishReason, "", "")
  679. if err := a.messages.Update(ctx, *assistantMsg); err != nil {
  680. return fmt.Errorf("failed to update message: %w", err)
  681. }
  682. return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
  683. }
  684. return nil
  685. }
  686. func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
  687. sess, err := a.sessions.Get(ctx, sessionID)
  688. if err != nil {
  689. return fmt.Errorf("failed to get session: %w", err)
  690. }
  691. cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  692. model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  693. model.CostPer1MIn/1e6*float64(usage.InputTokens) +
  694. model.CostPer1MOut/1e6*float64(usage.OutputTokens)
  695. a.eventTokensUsed(sessionID, usage, cost)
  696. sess.Cost += cost
  697. sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
  698. sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
  699. _, err = a.sessions.Save(ctx, sess)
  700. if err != nil {
  701. return fmt.Errorf("failed to save session: %w", err)
  702. }
  703. return nil
  704. }
  705. func (a *agent) Summarize(ctx context.Context, sessionID string) error {
  706. if a.summarizeProvider == nil {
  707. return fmt.Errorf("summarize provider not available")
  708. }
  709. // Check if session is busy
  710. if a.IsSessionBusy(sessionID) {
  711. return ErrSessionBusy
  712. }
  713. // Create a new context with cancellation
  714. summarizeCtx, cancel := context.WithCancel(ctx)
  715. // Store the cancel function in activeRequests to allow cancellation
  716. a.activeRequests.Set(sessionID+"-summarize", cancel)
  717. go func() {
  718. defer a.activeRequests.Del(sessionID + "-summarize")
  719. defer cancel()
  720. event := AgentEvent{
  721. Type: AgentEventTypeSummarize,
  722. Progress: "Starting summarization...",
  723. }
  724. a.Publish(pubsub.CreatedEvent, event)
  725. // Get all messages from the session
  726. msgs, err := a.messages.List(summarizeCtx, sessionID)
  727. if err != nil {
  728. event = AgentEvent{
  729. Type: AgentEventTypeError,
  730. Error: fmt.Errorf("failed to list messages: %w", err),
  731. Done: true,
  732. }
  733. a.Publish(pubsub.CreatedEvent, event)
  734. return
  735. }
  736. summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
  737. if len(msgs) == 0 {
  738. event = AgentEvent{
  739. Type: AgentEventTypeError,
  740. Error: fmt.Errorf("no messages to summarize"),
  741. Done: true,
  742. }
  743. a.Publish(pubsub.CreatedEvent, event)
  744. return
  745. }
  746. event = AgentEvent{
  747. Type: AgentEventTypeSummarize,
  748. Progress: "Analyzing conversation...",
  749. }
  750. a.Publish(pubsub.CreatedEvent, event)
  751. // Add a system message to guide the summarization
  752. summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
  753. // Create a new message with the summarize prompt
  754. promptMsg := message.Message{
  755. Role: message.User,
  756. Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
  757. }
  758. // Append the prompt to the messages
  759. msgsWithPrompt := append(msgs, promptMsg)
  760. event = AgentEvent{
  761. Type: AgentEventTypeSummarize,
  762. Progress: "Generating summary...",
  763. }
  764. a.Publish(pubsub.CreatedEvent, event)
  765. // Send the messages to the summarize provider
  766. response := a.summarizeProvider.StreamResponse(
  767. summarizeCtx,
  768. msgsWithPrompt,
  769. nil,
  770. )
  771. var finalResponse *provider.ProviderResponse
  772. for r := range response {
  773. if r.Error != nil {
  774. event = AgentEvent{
  775. Type: AgentEventTypeError,
  776. Error: fmt.Errorf("failed to summarize: %w", r.Error),
  777. Done: true,
  778. }
  779. a.Publish(pubsub.CreatedEvent, event)
  780. return
  781. }
  782. finalResponse = r.Response
  783. }
  784. summary := strings.TrimSpace(finalResponse.Content)
  785. if summary == "" {
  786. event = AgentEvent{
  787. Type: AgentEventTypeError,
  788. Error: fmt.Errorf("empty summary returned"),
  789. Done: true,
  790. }
  791. a.Publish(pubsub.CreatedEvent, event)
  792. return
  793. }
  794. shell := shell.GetPersistentShell(config.Get().WorkingDir())
  795. summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
  796. event = AgentEvent{
  797. Type: AgentEventTypeSummarize,
  798. Progress: "Creating new session...",
  799. }
  800. a.Publish(pubsub.CreatedEvent, event)
  801. oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
  802. if err != nil {
  803. event = AgentEvent{
  804. Type: AgentEventTypeError,
  805. Error: fmt.Errorf("failed to get session: %w", err),
  806. Done: true,
  807. }
  808. a.Publish(pubsub.CreatedEvent, event)
  809. return
  810. }
  811. // Create a message in the new session with the summary
  812. msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
  813. Role: message.Assistant,
  814. Parts: []message.ContentPart{
  815. message.TextContent{Text: summary},
  816. message.Finish{
  817. Reason: message.FinishReasonEndTurn,
  818. Time: time.Now().Unix(),
  819. },
  820. },
  821. Model: a.summarizeProvider.Model().ID,
  822. Provider: a.summarizeProviderID,
  823. })
  824. if err != nil {
  825. event = AgentEvent{
  826. Type: AgentEventTypeError,
  827. Error: fmt.Errorf("failed to create summary message: %w", err),
  828. Done: true,
  829. }
  830. a.Publish(pubsub.CreatedEvent, event)
  831. return
  832. }
  833. oldSession.SummaryMessageID = msg.ID
  834. oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
  835. oldSession.PromptTokens = 0
  836. model := a.summarizeProvider.Model()
  837. usage := finalResponse.Usage
  838. cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  839. model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  840. model.CostPer1MIn/1e6*float64(usage.InputTokens) +
  841. model.CostPer1MOut/1e6*float64(usage.OutputTokens)
  842. oldSession.Cost += cost
  843. _, err = a.sessions.Save(summarizeCtx, oldSession)
  844. if err != nil {
  845. event = AgentEvent{
  846. Type: AgentEventTypeError,
  847. Error: fmt.Errorf("failed to save session: %w", err),
  848. Done: true,
  849. }
  850. a.Publish(pubsub.CreatedEvent, event)
  851. }
  852. event = AgentEvent{
  853. Type: AgentEventTypeSummarize,
  854. SessionID: oldSession.ID,
  855. Progress: "Summary complete",
  856. Done: true,
  857. }
  858. a.Publish(pubsub.CreatedEvent, event)
  859. // Send final success event with the new session ID
  860. }()
  861. return nil
  862. }
  863. func (a *agent) ClearQueue(sessionID string) {
  864. if a.QueuedPrompts(sessionID) > 0 {
  865. slog.Info("Clearing queued prompts", "session_id", sessionID)
  866. a.promptQueue.Del(sessionID)
  867. }
  868. }
  869. func (a *agent) CancelAll() {
  870. if !a.IsBusy() {
  871. return
  872. }
  873. for key := range a.activeRequests.Seq2() {
  874. a.Cancel(key) // key is sessionID
  875. }
  876. for _, cleanup := range a.cleanupFuncs {
  877. if cleanup != nil {
  878. cleanup()
  879. }
  880. }
  881. timeout := time.After(5 * time.Second)
  882. for a.IsBusy() {
  883. select {
  884. case <-timeout:
  885. return
  886. default:
  887. time.Sleep(200 * time.Millisecond)
  888. }
  889. }
  890. }
  891. func (a *agent) UpdateModel() error {
  892. cfg := config.Get()
  893. // Get current provider configuration
  894. currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
  895. if currentProviderCfg == nil || currentProviderCfg.ID == "" {
  896. return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
  897. }
  898. // Check if provider has changed
  899. if string(currentProviderCfg.ID) != a.providerID {
  900. // Provider changed, need to recreate the main provider
  901. model := cfg.GetModelByType(a.agentCfg.Model)
  902. if model.ID == "" {
  903. return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
  904. }
  905. promptID := agentPromptMap[a.agentCfg.ID]
  906. if promptID == "" {
  907. promptID = prompt.PromptDefault
  908. }
  909. opts := []provider.ProviderClientOption{
  910. provider.WithModel(a.agentCfg.Model),
  911. provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
  912. }
  913. newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
  914. if err != nil {
  915. return fmt.Errorf("failed to create new provider: %w", err)
  916. }
  917. // Update the provider and provider ID
  918. a.provider = newProvider
  919. a.providerID = string(currentProviderCfg.ID)
  920. }
  921. // Check if providers have changed for title (small) and summarize (large)
  922. smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
  923. var smallModelProviderCfg config.ProviderConfig
  924. for p := range cfg.Providers.Seq() {
  925. if p.ID == smallModelCfg.Provider {
  926. smallModelProviderCfg = p
  927. break
  928. }
  929. }
  930. if smallModelProviderCfg.ID == "" {
  931. return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
  932. }
  933. largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
  934. var largeModelProviderCfg config.ProviderConfig
  935. for p := range cfg.Providers.Seq() {
  936. if p.ID == largeModelCfg.Provider {
  937. largeModelProviderCfg = p
  938. break
  939. }
  940. }
  941. if largeModelProviderCfg.ID == "" {
  942. return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
  943. }
  944. var maxTitleTokens int64 = 40
  945. // if the max output is too low for the gemini provider it won't return anything
  946. if smallModelCfg.Provider == "gemini" {
  947. maxTitleTokens = 1000
  948. }
  949. // Recreate title provider
  950. titleOpts := []provider.ProviderClientOption{
  951. provider.WithModel(config.SelectedModelTypeSmall),
  952. provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
  953. provider.WithMaxTokens(maxTitleTokens),
  954. }
  955. newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
  956. if err != nil {
  957. return fmt.Errorf("failed to create new title provider: %w", err)
  958. }
  959. a.titleProvider = newTitleProvider
  960. // Recreate summarize provider if provider changed (now large model)
  961. if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
  962. largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
  963. if largeModel == nil {
  964. return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
  965. }
  966. summarizeOpts := []provider.ProviderClientOption{
  967. provider.WithModel(config.SelectedModelTypeLarge),
  968. provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
  969. }
  970. newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
  971. if err != nil {
  972. return fmt.Errorf("failed to create new summarize provider: %w", err)
  973. }
  974. a.summarizeProvider = newSummarizeProvider
  975. a.summarizeProviderID = string(largeModelProviderCfg.ID)
  976. }
  977. return nil
  978. }
  979. func (a *agent) setupEvents(ctx context.Context) {
  980. ctx, cancel := context.WithCancel(ctx)
  981. go func() {
  982. subCh := SubscribeMCPEvents(ctx)
  983. for {
  984. select {
  985. case event, ok := <-subCh:
  986. if !ok {
  987. slog.Debug("MCPEvents subscription channel closed")
  988. return
  989. }
  990. switch event.Payload.Type {
  991. case MCPEventToolsListChanged:
  992. name := event.Payload.Name
  993. c, ok := mcpClients.Get(name)
  994. if !ok {
  995. slog.Warn("MCP client not found for tools update", "name", name)
  996. continue
  997. }
  998. cfg := config.Get()
  999. tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
  1000. if err != nil {
  1001. slog.Error("error listing tools", "error", err)
  1002. updateMCPState(name, MCPStateError, err, nil, 0)
  1003. _ = c.Close()
  1004. continue
  1005. }
  1006. updateMcpTools(name, tools)
  1007. a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
  1008. updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
  1009. default:
  1010. continue
  1011. }
  1012. case <-ctx.Done():
  1013. slog.Debug("MCPEvents subscription cancelled")
  1014. return
  1015. }
  1016. }
  1017. }()
  1018. a.cleanupFuncs = append(a.cleanupFuncs, cancel)
  1019. }