agent.go 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764
  1. package agent
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "sync"
  8. "time"
  9. "log/slog"
  10. "github.com/opencode-ai/opencode/internal/config"
  11. "github.com/opencode-ai/opencode/internal/llm/models"
  12. "github.com/opencode-ai/opencode/internal/llm/prompt"
  13. "github.com/opencode-ai/opencode/internal/llm/provider"
  14. "github.com/opencode-ai/opencode/internal/llm/tools"
  15. "github.com/opencode-ai/opencode/internal/logging"
  16. "github.com/opencode-ai/opencode/internal/message"
  17. "github.com/opencode-ai/opencode/internal/permission"
  18. "github.com/opencode-ai/opencode/internal/session"
  19. "github.com/opencode-ai/opencode/internal/status"
  20. )
  // Sentinel errors surfaced by the agent service; compare with errors.Is.

  // ErrRequestCancelled is reported when a running request is cancelled.
  var ErrRequestCancelled = errors.New("request cancelled by user")

  // ErrSessionBusy is returned by Run when the session already has a request
  // in flight.
  var ErrSessionBusy = errors.New("session is currently processing another request")
  26. type AgentEvent struct {
  27. message message.Message
  28. err error
  29. }
  30. func (e *AgentEvent) Err() error {
  31. return e.err
  32. }
  33. func (e *AgentEvent) Response() message.Message {
  34. return e.message
  35. }
  36. type Service interface {
  37. Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  38. Cancel(sessionID string)
  39. IsSessionBusy(sessionID string) bool
  40. IsBusy() bool
  41. Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
  42. CompactSession(ctx context.Context, sessionID string) error
  43. PauseSession(sessionID string) error
  44. ResumeSession(sessionID string) error
  45. }
  46. type agent struct {
  47. sessions session.Service
  48. messages message.Service
  49. tools []tools.BaseTool
  50. provider provider.Provider
  51. titleProvider provider.Provider
  52. activeRequests sync.Map
  53. pauseLock sync.RWMutex // Lock for pausing message processing
  54. }
  55. func NewAgent(
  56. agentName config.AgentName,
  57. sessions session.Service,
  58. messages message.Service,
  59. agentTools []tools.BaseTool,
  60. ) (Service, error) {
  61. agentProvider, err := createAgentProvider(agentName)
  62. if err != nil {
  63. return nil, err
  64. }
  65. var titleProvider provider.Provider
  66. // Only generate titles for the coder agent
  67. if agentName == config.AgentCoder {
  68. titleProvider, err = createAgentProvider(config.AgentTitle)
  69. if err != nil {
  70. return nil, err
  71. }
  72. }
  73. agent := &agent{
  74. provider: agentProvider,
  75. messages: messages,
  76. sessions: sessions,
  77. tools: agentTools,
  78. titleProvider: titleProvider,
  79. activeRequests: sync.Map{},
  80. }
  81. return agent, nil
  82. }
  83. func (a *agent) Cancel(sessionID string) {
  84. if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
  85. if cancel, ok := cancelFunc.(context.CancelFunc); ok {
  86. status.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
  87. cancel()
  88. }
  89. }
  90. }
  91. func (a *agent) IsBusy() bool {
  92. busy := false
  93. a.activeRequests.Range(func(key, value interface{}) bool {
  94. if cancelFunc, ok := value.(context.CancelFunc); ok {
  95. if cancelFunc != nil {
  96. busy = true
  97. return false // Stop iterating
  98. }
  99. }
  100. return true // Continue iterating
  101. })
  102. return busy
  103. }
  104. func (a *agent) IsSessionBusy(sessionID string) bool {
  105. _, busy := a.activeRequests.Load(sessionID)
  106. return busy
  107. }
  108. func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
  109. if content == "" {
  110. return nil
  111. }
  112. if a.titleProvider == nil {
  113. return nil
  114. }
  115. session, err := a.sessions.Get(ctx, sessionID)
  116. if err != nil {
  117. return err
  118. }
  119. parts := []message.ContentPart{message.TextContent{Text: content}}
  120. response, err := a.titleProvider.SendMessages(
  121. ctx,
  122. []message.Message{
  123. {
  124. Role: message.User,
  125. Parts: parts,
  126. },
  127. },
  128. make([]tools.BaseTool, 0),
  129. )
  130. if err != nil {
  131. return err
  132. }
  133. title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
  134. if title == "" {
  135. return nil
  136. }
  137. session.Title = title
  138. _, err = a.sessions.Update(ctx, session)
  139. return err
  140. }
  141. func (a *agent) err(err error) AgentEvent {
  142. return AgentEvent{
  143. err: err,
  144. }
  145. }
  146. func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
  147. if !a.provider.Model().SupportsAttachments && attachments != nil {
  148. attachments = nil
  149. }
  150. events := make(chan AgentEvent)
  151. if a.IsSessionBusy(sessionID) {
  152. return nil, ErrSessionBusy
  153. }
  154. genCtx, cancel := context.WithCancel(ctx)
  155. a.activeRequests.Store(sessionID, cancel)
  156. go func() {
  157. slog.Debug("Request started", "sessionID", sessionID)
  158. defer logging.RecoverPanic("agent.Run", func() {
  159. events <- a.err(fmt.Errorf("panic while running the agent"))
  160. })
  161. var attachmentParts []message.ContentPart
  162. for _, attachment := range attachments {
  163. attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
  164. }
  165. result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
  166. if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
  167. status.Error(result.Err().Error())
  168. }
  169. slog.Debug("Request completed", "sessionID", sessionID)
  170. a.activeRequests.Delete(sessionID)
  171. cancel()
  172. events <- result
  173. close(events)
  174. }()
  175. return events, nil
  176. }
  177. func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
  178. // Get the current session to check for summary
  179. currentSession, err := a.sessions.Get(ctx, sessionID)
  180. if err != nil {
  181. return a.err(fmt.Errorf("failed to get session: %w", err))
  182. }
  183. // Fetch messages based on whether a summary exists
  184. var sessionMessages []message.Message
  185. if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
  186. // If summary exists, only fetch messages after the summarization timestamp
  187. sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
  188. if err != nil {
  189. return a.err(fmt.Errorf("failed to list messages after summary: %w", err))
  190. }
  191. } else {
  192. // If no summary, fetch all messages
  193. sessionMessages, err = a.messages.List(ctx, sessionID)
  194. if err != nil {
  195. return a.err(fmt.Errorf("failed to list messages: %w", err))
  196. }
  197. }
  198. // If this is a new session, start title generation asynchronously
  199. if len(sessionMessages) == 0 && currentSession.Summary == "" {
  200. go func() {
  201. defer logging.RecoverPanic("agent.Run", func() {
  202. status.Error("panic while generating title")
  203. })
  204. titleErr := a.generateTitle(context.Background(), sessionID, content)
  205. if titleErr != nil {
  206. status.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
  207. }
  208. }()
  209. }
  210. userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
  211. if err != nil {
  212. return a.err(fmt.Errorf("failed to create user message: %w", err))
  213. }
  214. // Prepare the message history for the LLM
  215. var messages []message.Message
  216. if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
  217. // If summary exists, create a temporary message for the summary
  218. summaryMessage := message.Message{
  219. Role: message.Assistant,
  220. Parts: []message.ContentPart{
  221. message.TextContent{Text: currentSession.Summary},
  222. },
  223. }
  224. // Start with the summary, then add messages after the summary timestamp
  225. messages = append([]message.Message{summaryMessage}, sessionMessages...)
  226. } else {
  227. // If no summary, just use all messages
  228. messages = sessionMessages
  229. }
  230. // Append the new user message to the conversation history
  231. messages = append(messages, userMsg)
  232. for {
  233. // Check for cancellation before each iteration
  234. select {
  235. case <-ctx.Done():
  236. return a.err(ctx.Err())
  237. default:
  238. // Continue processing
  239. }
  240. agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, messages)
  241. if err != nil {
  242. if errors.Is(err, context.Canceled) {
  243. agentMessage.AddFinish(message.FinishReasonCanceled)
  244. a.messages.Update(context.Background(), agentMessage)
  245. return a.err(ErrRequestCancelled)
  246. }
  247. return a.err(fmt.Errorf("failed to process events: %w", err))
  248. }
  249. slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
  250. if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
  251. // We are not done, we need to respond with the tool response
  252. messages = append(messages, agentMessage, *toolResults)
  253. continue
  254. }
  255. return AgentEvent{
  256. message: agentMessage,
  257. }
  258. }
  259. }
  260. func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
  261. parts := []message.ContentPart{message.TextContent{Text: content}}
  262. parts = append(parts, attachmentParts...)
  263. return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  264. Role: message.User,
  265. Parts: parts,
  266. })
  267. }
  268. func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
  269. // Check if we need to auto-compact based on token count
  270. contextWindow := a.provider.Model().ContextWindow
  271. maxTokens := a.provider.MaxTokens()
  272. threshold := int64(float64(contextWindow) * 0.80)
  273. usage, err := a.GetUsage(ctx, sessionID)
  274. if err != nil || usage == nil {
  275. return message.Message{}, nil, fmt.Errorf("failed to get usage: %w", err)
  276. }
  277. // If we're approaching the context window limit, trigger auto-compaction
  278. if false && (*usage+maxTokens) >= threshold {
  279. status.Info(fmt.Sprintf("Auto-compaction triggered for session %s. Estimated tokens: %d, Threshold: %d", sessionID, usage, threshold))
  280. // Perform compaction with pause/resume to ensure safety
  281. if err := a.CompactSession(ctx, sessionID); err != nil {
  282. status.Error(fmt.Sprintf("Auto-compaction failed: %v", err))
  283. // Continue with the request even if compaction fails
  284. } else {
  285. // Re-fetch session details after compaction
  286. currentSession, err := a.sessions.Get(ctx, sessionID)
  287. if err != nil {
  288. return message.Message{}, nil, fmt.Errorf("failed to get session after compaction: %w", err)
  289. }
  290. // Re-prepare messages using the new summary
  291. var sessionMessages []message.Message
  292. if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
  293. // If summary exists, only fetch messages after the summarization timestamp
  294. sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
  295. if err != nil {
  296. return message.Message{}, nil, fmt.Errorf("failed to list messages after compaction: %w", err)
  297. }
  298. // Create a new message history with the summary and messages after summarization
  299. summaryMessage := message.Message{
  300. Role: message.Assistant,
  301. Parts: []message.ContentPart{
  302. message.TextContent{Text: currentSession.Summary},
  303. },
  304. }
  305. // Replace msgHistory with the new compacted version
  306. msgHistory = append([]message.Message{summaryMessage}, sessionMessages...)
  307. }
  308. }
  309. }
  310. eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
  311. assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  312. Role: message.Assistant,
  313. Parts: []message.ContentPart{},
  314. Model: a.provider.Model().ID,
  315. })
  316. if err != nil {
  317. return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
  318. }
  319. // Add the session and message ID into the context if needed by tools.
  320. ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
  321. ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
  322. // Process each event in the stream.
  323. for event := range eventChan {
  324. if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
  325. a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
  326. return assistantMsg, nil, processErr
  327. }
  328. if ctx.Err() != nil {
  329. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
  330. return assistantMsg, nil, ctx.Err()
  331. }
  332. }
  333. toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
  334. toolCalls := assistantMsg.ToolCalls()
  335. for i, toolCall := range toolCalls {
  336. select {
  337. case <-ctx.Done():
  338. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
  339. // Make all future tool calls cancelled
  340. for j := i; j < len(toolCalls); j++ {
  341. toolResults[j] = message.ToolResult{
  342. ToolCallID: toolCalls[j].ID,
  343. Content: "Tool execution canceled by user",
  344. IsError: true,
  345. }
  346. }
  347. goto out
  348. default:
  349. // Continue processing
  350. var tool tools.BaseTool
  351. for _, availableTools := range a.tools {
  352. if availableTools.Info().Name == toolCall.Name {
  353. tool = availableTools
  354. }
  355. }
  356. // Tool not found
  357. if tool == nil {
  358. toolResults[i] = message.ToolResult{
  359. ToolCallID: toolCall.ID,
  360. Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
  361. IsError: true,
  362. }
  363. continue
  364. }
  365. toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
  366. ID: toolCall.ID,
  367. Name: toolCall.Name,
  368. Input: toolCall.Input,
  369. })
  370. if toolErr != nil {
  371. if errors.Is(toolErr, permission.ErrorPermissionDenied) {
  372. toolResults[i] = message.ToolResult{
  373. ToolCallID: toolCall.ID,
  374. Content: "Permission denied",
  375. IsError: true,
  376. }
  377. for j := i + 1; j < len(toolCalls); j++ {
  378. toolResults[j] = message.ToolResult{
  379. ToolCallID: toolCalls[j].ID,
  380. Content: "Tool execution canceled by user",
  381. IsError: true,
  382. }
  383. }
  384. a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
  385. break
  386. }
  387. }
  388. toolResults[i] = message.ToolResult{
  389. ToolCallID: toolCall.ID,
  390. Content: toolResult.Content,
  391. Metadata: toolResult.Metadata,
  392. IsError: toolResult.IsError,
  393. }
  394. }
  395. }
  396. out:
  397. if len(toolResults) == 0 {
  398. return assistantMsg, nil, nil
  399. }
  400. parts := make([]message.ContentPart, 0)
  401. for _, tr := range toolResults {
  402. parts = append(parts, tr)
  403. }
  404. msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
  405. Role: message.Tool,
  406. Parts: parts,
  407. })
  408. if err != nil {
  409. return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
  410. }
  411. return assistantMsg, &msg, err
  412. }
  413. func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
  414. msg.AddFinish(finishReson)
  415. _, _ = a.messages.Update(ctx, *msg)
  416. }
  417. func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
  418. select {
  419. case <-ctx.Done():
  420. return ctx.Err()
  421. default:
  422. // Continue processing.
  423. }
  424. // Check if session is paused - use RLock to allow concurrent reads but block during pause
  425. a.pauseLock.RLock()
  426. defer a.pauseLock.RUnlock()
  427. switch event.Type {
  428. case provider.EventThinkingDelta:
  429. assistantMsg.AppendReasoningContent(event.Content)
  430. _, err := a.messages.Update(ctx, *assistantMsg)
  431. return err
  432. case provider.EventContentDelta:
  433. assistantMsg.AppendContent(event.Content)
  434. _, err := a.messages.Update(ctx, *assistantMsg)
  435. return err
  436. case provider.EventToolUseStart:
  437. assistantMsg.AddToolCall(*event.ToolCall)
  438. _, err := a.messages.Update(ctx, *assistantMsg)
  439. return err
  440. // TODO: see how to handle this
  441. // case provider.EventToolUseDelta:
  442. // tm := time.Unix(assistantMsg.UpdatedAt, 0)
  443. // assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
  444. // if time.Since(tm) > 1000*time.Millisecond {
  445. // err := a.messages.Update(ctx, *assistantMsg)
  446. // assistantMsg.UpdatedAt = time.Now().Unix()
  447. // return err
  448. // }
  449. case provider.EventToolUseStop:
  450. assistantMsg.FinishToolCall(event.ToolCall.ID)
  451. _, err := a.messages.Update(ctx, *assistantMsg)
  452. return err
  453. case provider.EventError:
  454. if errors.Is(event.Error, context.Canceled) {
  455. status.Info(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
  456. return context.Canceled
  457. }
  458. status.Error(event.Error.Error())
  459. return event.Error
  460. case provider.EventComplete:
  461. assistantMsg.SetToolCalls(event.Response.ToolCalls)
  462. assistantMsg.AddFinish(event.Response.FinishReason)
  463. if _, err := a.messages.Update(ctx, *assistantMsg); err != nil {
  464. return fmt.Errorf("failed to update message: %w", err)
  465. }
  466. return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
  467. }
  468. return nil
  469. }
  470. func (a *agent) GetUsage(ctx context.Context, sessionID string) (*int64, error) {
  471. session, err := a.sessions.Get(ctx, sessionID)
  472. if err != nil {
  473. return nil, fmt.Errorf("failed to get session: %w", err)
  474. }
  475. usage := session.PromptTokens + session.CompletionTokens
  476. return &usage, nil
  477. }
  478. func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
  479. sess, err := a.sessions.Get(ctx, sessionID)
  480. if err != nil {
  481. return fmt.Errorf("failed to get session: %w", err)
  482. }
  483. cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  484. model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  485. model.CostPer1MIn/1e6*float64(usage.InputTokens) +
  486. model.CostPer1MOut/1e6*float64(usage.OutputTokens)
  487. sess.Cost += cost
  488. sess.CompletionTokens += usage.OutputTokens
  489. sess.PromptTokens += usage.InputTokens
  490. _, err = a.sessions.Update(ctx, sess)
  491. if err != nil {
  492. return fmt.Errorf("failed to save session: %w", err)
  493. }
  494. return nil
  495. }
  496. func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
  497. if a.IsBusy() {
  498. return models.Model{}, fmt.Errorf("cannot change model while processing requests")
  499. }
  500. if err := config.UpdateAgentModel(agentName, modelID); err != nil {
  501. return models.Model{}, fmt.Errorf("failed to update config: %w", err)
  502. }
  503. provider, err := createAgentProvider(agentName)
  504. if err != nil {
  505. return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
  506. }
  507. a.provider = provider
  508. return a.provider.Model(), nil
  509. }
  510. // PauseSession pauses message processing for a specific session
  511. // This should be called before performing operations that require exclusive access
  512. func (a *agent) PauseSession(sessionID string) error {
  513. if !a.IsSessionBusy(sessionID) {
  514. return nil // Session is not active, no need to pause
  515. }
  516. status.Info(fmt.Sprintf("Pausing session: %s", sessionID))
  517. a.pauseLock.Lock() // Acquire write lock to block new operations
  518. return nil
  519. }
  520. // ResumeSession resumes message processing for a session
  521. // This should be called after completing operations that required exclusive access
  522. func (a *agent) ResumeSession(sessionID string) error {
  523. status.Info(fmt.Sprintf("Resuming session: %s", sessionID))
  524. a.pauseLock.Unlock() // Release write lock to allow operations to continue
  525. return nil
  526. }
  527. func (a *agent) CompactSession(ctx context.Context, sessionID string) error {
  528. // Check if the session is busy
  529. if a.IsSessionBusy(sessionID) {
  530. // Pause the session before compaction
  531. if err := a.PauseSession(sessionID); err != nil {
  532. return fmt.Errorf("failed to pause session: %w", err)
  533. }
  534. // Make sure to resume the session when we're done
  535. defer a.ResumeSession(sessionID)
  536. status.Info(fmt.Sprintf("Session %s paused for compaction", sessionID))
  537. }
  538. // Create a cancellable context
  539. ctx, cancel := context.WithCancel(ctx)
  540. defer cancel()
  541. // Mark the session as busy during compaction
  542. compactionCancelFunc := func() {}
  543. a.activeRequests.Store(sessionID+"-compact", compactionCancelFunc)
  544. defer a.activeRequests.Delete(sessionID + "-compact")
  545. // Fetch the session
  546. session, err := a.sessions.Get(ctx, sessionID)
  547. if err != nil {
  548. return fmt.Errorf("failed to get session: %w", err)
  549. }
  550. // Fetch all messages for the session
  551. sessionMessages, err := a.messages.List(ctx, sessionID)
  552. if err != nil {
  553. return fmt.Errorf("failed to list messages: %w", err)
  554. }
  555. var existingSummary string
  556. if session.Summary != "" && session.SummarizedAt > 0 {
  557. // Filter messages that were created after the last summarization
  558. var newMessages []message.Message
  559. for _, msg := range sessionMessages {
  560. if msg.CreatedAt > session.SummarizedAt {
  561. newMessages = append(newMessages, msg)
  562. }
  563. }
  564. sessionMessages = newMessages
  565. existingSummary = session.Summary
  566. }
  567. // If there are no messages to summarize and no existing summary, return early
  568. if len(sessionMessages) == 0 && existingSummary == "" {
  569. return nil
  570. }
  571. messages := []message.Message{
  572. message.Message{
  573. Role: message.System,
  574. Parts: []message.ContentPart{
  575. message.TextContent{
  576. Text: "You are a helpful AI assistant tasked with summarizing conversations.",
  577. },
  578. },
  579. },
  580. }
  581. // If there's an existing summary, include it
  582. if existingSummary != "" {
  583. messages = append(messages, message.Message{
  584. Role: message.Assistant, // TODO: should this be system or user instead?
  585. Parts: []message.ContentPart{
  586. message.TextContent{
  587. Text: existingSummary,
  588. },
  589. },
  590. })
  591. }
  592. // Add all messages since the last summarized message
  593. messages = append(messages, sessionMessages...)
  594. // Add a final user message requesting the summary
  595. messages = append(messages, message.Message{
  596. Role: message.User,
  597. Parts: []message.ContentPart{
  598. message.TextContent{
  599. Text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
  600. },
  601. },
  602. })
  603. // Call provider to get the summary
  604. response, err := a.provider.SendMessages(ctx, messages, a.tools)
  605. if err != nil {
  606. return fmt.Errorf("failed to get summary from the assistant: %w", err)
  607. }
  608. // Extract the summary text
  609. summaryText := strings.TrimSpace(response.Content)
  610. if summaryText == "" {
  611. return fmt.Errorf("received empty summary from the assistant")
  612. }
  613. // Update the session with the new summary
  614. currentTime := time.Now().UnixMilli()
  615. session.Summary = summaryText
  616. session.SummarizedAt = currentTime
  617. // Save the updated session
  618. _, err = a.sessions.Update(ctx, session)
  619. if err != nil {
  620. return fmt.Errorf("failed to save session with summary: %w", err)
  621. }
  622. // Track token usage
  623. err = a.TrackUsage(ctx, sessionID, a.provider.Model(), response.Usage)
  624. if err != nil {
  625. return fmt.Errorf("failed to track usage: %w", err)
  626. }
  627. return nil
  628. }
  629. func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
  630. cfg := config.Get()
  631. agentConfig, ok := cfg.Agents[agentName]
  632. if !ok {
  633. return nil, fmt.Errorf("agent %s not found", agentName)
  634. }
  635. model, ok := models.SupportedModels[agentConfig.Model]
  636. if !ok {
  637. return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
  638. }
  639. providerCfg, ok := cfg.Providers[model.Provider]
  640. if !ok {
  641. return nil, fmt.Errorf("provider %s not supported", model.Provider)
  642. }
  643. if providerCfg.Disabled {
  644. return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
  645. }
  646. maxTokens := model.DefaultMaxTokens
  647. if agentConfig.MaxTokens > 0 {
  648. maxTokens = agentConfig.MaxTokens
  649. }
  650. opts := []provider.ProviderClientOption{
  651. provider.WithAPIKey(providerCfg.APIKey),
  652. provider.WithModel(model),
  653. provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
  654. provider.WithMaxTokens(maxTokens),
  655. }
  656. if model.Provider == models.ProviderOpenAI && model.CanReason {
  657. opts = append(
  658. opts,
  659. provider.WithOpenAIOptions(
  660. provider.WithReasoningEffort(agentConfig.ReasoningEffort),
  661. ),
  662. )
  663. } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
  664. opts = append(
  665. opts,
  666. provider.WithAnthropicOptions(
  667. provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
  668. ),
  669. )
  670. }
  671. agentProvider, err := provider.NewProvider(
  672. model.Provider,
  673. opts...,
  674. )
  675. if err != nil {
  676. return nil, fmt.Errorf("could not create provider: %v", err)
  677. }
  678. return agentProvider, nil
  679. }