agent.go 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769
  1. package agent
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "sync"
  8. "time"
  9. "github.com/opencode-ai/opencode/internal/config"
  10. "github.com/opencode-ai/opencode/internal/llm/models"
  11. "github.com/opencode-ai/opencode/internal/llm/prompt"
  12. "github.com/opencode-ai/opencode/internal/llm/provider"
  13. "github.com/opencode-ai/opencode/internal/llm/tools"
  14. "github.com/opencode-ai/opencode/internal/logging"
  15. "github.com/opencode-ai/opencode/internal/message"
  16. "github.com/opencode-ai/opencode/internal/permission"
  17. "github.com/opencode-ai/opencode/internal/session"
  18. )
  // ErrRequestCancelled is returned when the user cancels an in-flight request.
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned by Run when the session already has a request in
// flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
  24. type AgentEvent struct {
  25. message message.Message
  26. err error
  27. }
  28. func (e *AgentEvent) Err() error {
  29. return e.err
  30. }
  31. func (e *AgentEvent) Response() message.Message {
  32. return e.message
  33. }
  34. type Service interface {
  35. Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  36. Cancel(sessionID string)
  37. IsSessionBusy(sessionID string) bool
  38. IsBusy() bool
  39. Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
  40. CompactSession(ctx context.Context, sessionID string) error
  41. PauseSession(sessionID string) error
  42. ResumeSession(sessionID string) error
  43. }
  44. type agent struct {
  45. sessions session.Service
  46. messages message.Service
  47. tools []tools.BaseTool
  48. provider provider.Provider
  49. titleProvider provider.Provider
  50. activeRequests sync.Map
  51. pauseLock sync.RWMutex // Lock for pausing message processing
  52. }
  53. func NewAgent(
  54. agentName config.AgentName,
  55. sessions session.Service,
  56. messages message.Service,
  57. agentTools []tools.BaseTool,
  58. ) (Service, error) {
  59. agentProvider, err := createAgentProvider(agentName)
  60. if err != nil {
  61. return nil, err
  62. }
  63. var titleProvider provider.Provider
  64. // Only generate titles for the coder agent
  65. if agentName == config.AgentCoder {
  66. titleProvider, err = createAgentProvider(config.AgentTitle)
  67. if err != nil {
  68. return nil, err
  69. }
  70. }
  71. agent := &agent{
  72. provider: agentProvider,
  73. messages: messages,
  74. sessions: sessions,
  75. tools: agentTools,
  76. titleProvider: titleProvider,
  77. activeRequests: sync.Map{},
  78. }
  79. return agent, nil
  80. }
  81. func (a *agent) Cancel(sessionID string) {
  82. if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
  83. if cancel, ok := cancelFunc.(context.CancelFunc); ok {
  84. logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
  85. cancel()
  86. }
  87. }
  88. }
  89. func (a *agent) IsBusy() bool {
  90. busy := false
  91. a.activeRequests.Range(func(key, value interface{}) bool {
  92. if cancelFunc, ok := value.(context.CancelFunc); ok {
  93. if cancelFunc != nil {
  94. busy = true
  95. return false // Stop iterating
  96. }
  97. }
  98. return true // Continue iterating
  99. })
  100. return busy
  101. }
  102. func (a *agent) IsSessionBusy(sessionID string) bool {
  103. _, busy := a.activeRequests.Load(sessionID)
  104. return busy
  105. }
  106. func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
  107. if content == "" {
  108. return nil
  109. }
  110. if a.titleProvider == nil {
  111. return nil
  112. }
  113. session, err := a.sessions.Get(ctx, sessionID)
  114. if err != nil {
  115. return err
  116. }
  117. parts := []message.ContentPart{message.TextContent{Text: content}}
  118. response, err := a.titleProvider.SendMessages(
  119. ctx,
  120. []message.Message{
  121. {
  122. Role: message.User,
  123. Parts: parts,
  124. },
  125. },
  126. make([]tools.BaseTool, 0),
  127. )
  128. if err != nil {
  129. return err
  130. }
  131. title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
  132. if title == "" {
  133. return nil
  134. }
  135. session.Title = title
  136. _, err = a.sessions.Save(ctx, session)
  137. return err
  138. }
  139. func (a *agent) err(err error) AgentEvent {
  140. return AgentEvent{
  141. err: err,
  142. }
  143. }
  144. func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
  145. if !a.provider.Model().SupportsAttachments && attachments != nil {
  146. attachments = nil
  147. }
  148. events := make(chan AgentEvent)
  149. if a.IsSessionBusy(sessionID) {
  150. return nil, ErrSessionBusy
  151. }
  152. genCtx, cancel := context.WithCancel(ctx)
  153. a.activeRequests.Store(sessionID, cancel)
  154. go func() {
  155. logging.Debug("Request started", "sessionID", sessionID)
  156. defer logging.RecoverPanic("agent.Run", func() {
  157. events <- a.err(fmt.Errorf("panic while running the agent"))
  158. })
  159. var attachmentParts []message.ContentPart
  160. for _, attachment := range attachments {
  161. attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
  162. }
  163. result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
  164. if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
  165. logging.ErrorPersist(result.Err().Error())
  166. }
  167. logging.Debug("Request completed", "sessionID", sessionID)
  168. a.activeRequests.Delete(sessionID)
  169. cancel()
  170. events <- result
  171. close(events)
  172. }()
  173. return events, nil
  174. }
  175. func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
  176. // Get the current session to check for summary
  177. currentSession, err := a.sessions.Get(ctx, sessionID)
  178. if err != nil {
  179. return a.err(fmt.Errorf("failed to get session: %w", err))
  180. }
  181. // Fetch messages based on whether a summary exists
  182. var sessionMessages []message.Message
  183. if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
  184. // If summary exists, only fetch messages after the summarization timestamp
  185. sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
  186. if err != nil {
  187. return a.err(fmt.Errorf("failed to list messages after summary: %w", err))
  188. }
  189. } else {
  190. // If no summary, fetch all messages
  191. sessionMessages, err = a.messages.List(ctx, sessionID)
  192. if err != nil {
  193. return a.err(fmt.Errorf("failed to list messages: %w", err))
  194. }
  195. }
  196. // If this is a new session, start title generation asynchronously
  197. if len(sessionMessages) == 0 && currentSession.Summary == "" {
  198. go func() {
  199. defer logging.RecoverPanic("agent.Run", func() {
  200. logging.ErrorPersist("panic while generating title")
  201. })
  202. titleErr := a.generateTitle(context.Background(), sessionID, content)
  203. if titleErr != nil {
  204. logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
  205. }
  206. }()
  207. }
  208. userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
  209. if err != nil {
  210. return a.err(fmt.Errorf("failed to create user message: %w", err))
  211. }
  212. // Prepare the message history for the LLM
  213. var messages []message.Message
  214. if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
  215. // If summary exists, create a temporary message for the summary
  216. summaryMessage := message.Message{
  217. Role: message.Assistant,
  218. Parts: []message.ContentPart{
  219. message.TextContent{Text: currentSession.Summary},
  220. },
  221. }
  222. // Start with the summary, then add messages after the summary timestamp
  223. messages = append([]message.Message{summaryMessage}, sessionMessages...)
  224. } else {
  225. // If no summary, just use all messages
  226. messages = sessionMessages
  227. }
  228. // Append the new user message to the conversation history
  229. messages = append(messages, userMsg)
  230. for {
  231. // Check for cancellation before each iteration
  232. select {
  233. case <-ctx.Done():
  234. return a.err(ctx.Err())
  235. default:
  236. // Continue processing
  237. }
  238. agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, messages)
  239. if err != nil {
  240. if errors.Is(err, context.Canceled) {
  241. agentMessage.AddFinish(message.FinishReasonCanceled)
  242. a.messages.Update(context.Background(), agentMessage)
  243. return a.err(ErrRequestCancelled)
  244. }
  245. return a.err(fmt.Errorf("failed to process events: %w", err))
  246. }
  247. logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
  248. if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
  249. // We are not done, we need to respond with the tool response
  250. messages = append(messages, agentMessage, *toolResults)
  251. continue
  252. }
  253. return AgentEvent{
  254. message: agentMessage,
  255. }
  256. }
  257. }
  258. func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
  259. parts := []message.ContentPart{message.TextContent{Text: content}}
  260. parts = append(parts, attachmentParts...)
  261. return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  262. Role: message.User,
  263. Parts: parts,
  264. })
  265. }
  266. // estimateTokens provides a rough estimate of token count based on character count
  267. // using a simple heuristic of ~4 characters per token
  268. func estimateTokens(messages []message.Message) int64 {
  269. totalChars := 0
  270. for _, msg := range messages {
  271. // Get text content from all parts
  272. for _, part := range msg.Parts {
  273. if textContent, ok := part.(message.TextContent); ok {
  274. totalChars += len(textContent.Text)
  275. } else {
  276. // For non-text parts, add a conservative estimate
  277. totalChars += 100
  278. }
  279. }
  280. // Add chars for role (conservative estimate)
  281. totalChars += 10
  282. }
  283. // Heuristic: ~4 chars per token
  284. return int64(totalChars / 4)
  285. }
  286. func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
  287. // Check if we need to auto-compact based on token count
  288. contextWindow := a.provider.Model().ContextWindow
  289. threshold := int64(float64(contextWindow) * 0.80)
  290. estimatedTokens := estimateTokens(msgHistory)
  291. // If we're approaching the context window limit, trigger auto-compaction
  292. if estimatedTokens >= threshold {
  293. logging.InfoPersist(fmt.Sprintf("Auto-compaction triggered for session %s. Estimated tokens: %d, Threshold: %d", sessionID, estimatedTokens, threshold))
  294. // Perform compaction with pause/resume to ensure safety
  295. if err := a.CompactSession(ctx, sessionID); err != nil {
  296. logging.ErrorPersist(fmt.Sprintf("Auto-compaction failed: %v", err))
  297. // Continue with the request even if compaction fails
  298. } else {
  299. // Re-fetch session details after compaction
  300. currentSession, err := a.sessions.Get(ctx, sessionID)
  301. if err != nil {
  302. return message.Message{}, nil, fmt.Errorf("failed to get session after compaction: %w", err)
  303. }
  304. // Re-prepare messages using the new summary
  305. var sessionMessages []message.Message
  306. if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
  307. // If summary exists, only fetch messages after the summarization timestamp
  308. sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
  309. if err != nil {
  310. return message.Message{}, nil, fmt.Errorf("failed to list messages after compaction: %w", err)
  311. }
  312. // Create a new message history with the summary and messages after summarization
  313. summaryMessage := message.Message{
  314. Role: message.Assistant,
  315. Parts: []message.ContentPart{
  316. message.TextContent{Text: currentSession.Summary},
  317. },
  318. }
  319. // Replace msgHistory with the new compacted version
  320. msgHistory = append([]message.Message{summaryMessage}, sessionMessages...)
  321. // Log the new token estimate after compaction
  322. newEstimate := estimateTokens(msgHistory)
  323. logging.InfoPersist(fmt.Sprintf("After compaction: Estimated tokens: %d (reduced by %d)",
  324. newEstimate, estimatedTokens-newEstimate))
  325. }
  326. }
  327. }
  328. eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
  329. assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  330. Role: message.Assistant,
  331. Parts: []message.ContentPart{},
  332. Model: a.provider.Model().ID,
  333. })
  334. if err != nil {
  335. return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
  336. }
  337. // Add the session and message ID into the context if needed by tools.
  338. ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
  339. ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
  340. // Process each event in the stream.
  341. for event := range eventChan {
  342. if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
  343. a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
  344. return assistantMsg, nil, processErr
  345. }
  346. if ctx.Err() != nil {
  347. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
  348. return assistantMsg, nil, ctx.Err()
  349. }
  350. }
  351. toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
  352. toolCalls := assistantMsg.ToolCalls()
  353. for i, toolCall := range toolCalls {
  354. select {
  355. case <-ctx.Done():
  356. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
  357. // Make all future tool calls cancelled
  358. for j := i; j < len(toolCalls); j++ {
  359. toolResults[j] = message.ToolResult{
  360. ToolCallID: toolCalls[j].ID,
  361. Content: "Tool execution canceled by user",
  362. IsError: true,
  363. }
  364. }
  365. goto out
  366. default:
  367. // Continue processing
  368. var tool tools.BaseTool
  369. for _, availableTools := range a.tools {
  370. if availableTools.Info().Name == toolCall.Name {
  371. tool = availableTools
  372. }
  373. }
  374. // Tool not found
  375. if tool == nil {
  376. toolResults[i] = message.ToolResult{
  377. ToolCallID: toolCall.ID,
  378. Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
  379. IsError: true,
  380. }
  381. continue
  382. }
  383. toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
  384. ID: toolCall.ID,
  385. Name: toolCall.Name,
  386. Input: toolCall.Input,
  387. })
  388. if toolErr != nil {
  389. if errors.Is(toolErr, permission.ErrorPermissionDenied) {
  390. toolResults[i] = message.ToolResult{
  391. ToolCallID: toolCall.ID,
  392. Content: "Permission denied",
  393. IsError: true,
  394. }
  395. for j := i + 1; j < len(toolCalls); j++ {
  396. toolResults[j] = message.ToolResult{
  397. ToolCallID: toolCalls[j].ID,
  398. Content: "Tool execution canceled by user",
  399. IsError: true,
  400. }
  401. }
  402. a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
  403. break
  404. }
  405. }
  406. toolResults[i] = message.ToolResult{
  407. ToolCallID: toolCall.ID,
  408. Content: toolResult.Content,
  409. Metadata: toolResult.Metadata,
  410. IsError: toolResult.IsError,
  411. }
  412. }
  413. }
  414. out:
  415. if len(toolResults) == 0 {
  416. return assistantMsg, nil, nil
  417. }
  418. parts := make([]message.ContentPart, 0)
  419. for _, tr := range toolResults {
  420. parts = append(parts, tr)
  421. }
  422. msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
  423. Role: message.Tool,
  424. Parts: parts,
  425. })
  426. if err != nil {
  427. return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
  428. }
  429. return assistantMsg, &msg, err
  430. }
  431. func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
  432. msg.AddFinish(finishReson)
  433. _ = a.messages.Update(ctx, *msg)
  434. }
  435. func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
  436. select {
  437. case <-ctx.Done():
  438. return ctx.Err()
  439. default:
  440. // Continue processing.
  441. }
  442. // Check if session is paused - use RLock to allow concurrent reads but block during pause
  443. a.pauseLock.RLock()
  444. defer a.pauseLock.RUnlock()
  445. switch event.Type {
  446. case provider.EventThinkingDelta:
  447. assistantMsg.AppendReasoningContent(event.Content)
  448. return a.messages.Update(ctx, *assistantMsg)
  449. case provider.EventContentDelta:
  450. assistantMsg.AppendContent(event.Content)
  451. return a.messages.Update(ctx, *assistantMsg)
  452. case provider.EventToolUseStart:
  453. assistantMsg.AddToolCall(*event.ToolCall)
  454. return a.messages.Update(ctx, *assistantMsg)
  455. // TODO: see how to handle this
  456. // case provider.EventToolUseDelta:
  457. // tm := time.Unix(assistantMsg.UpdatedAt, 0)
  458. // assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
  459. // if time.Since(tm) > 1000*time.Millisecond {
  460. // err := a.messages.Update(ctx, *assistantMsg)
  461. // assistantMsg.UpdatedAt = time.Now().Unix()
  462. // return err
  463. // }
  464. case provider.EventToolUseStop:
  465. assistantMsg.FinishToolCall(event.ToolCall.ID)
  466. return a.messages.Update(ctx, *assistantMsg)
  467. case provider.EventError:
  468. if errors.Is(event.Error, context.Canceled) {
  469. logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
  470. return context.Canceled
  471. }
  472. logging.ErrorPersist(event.Error.Error())
  473. return event.Error
  474. case provider.EventComplete:
  475. assistantMsg.SetToolCalls(event.Response.ToolCalls)
  476. assistantMsg.AddFinish(event.Response.FinishReason)
  477. if err := a.messages.Update(ctx, *assistantMsg); err != nil {
  478. return fmt.Errorf("failed to update message: %w", err)
  479. }
  480. return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
  481. }
  482. return nil
  483. }
  484. func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
  485. sess, err := a.sessions.Get(ctx, sessionID)
  486. if err != nil {
  487. return fmt.Errorf("failed to get session: %w", err)
  488. }
  489. cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  490. model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  491. model.CostPer1MIn/1e6*float64(usage.InputTokens) +
  492. model.CostPer1MOut/1e6*float64(usage.OutputTokens)
  493. sess.Cost += cost
  494. sess.CompletionTokens += usage.OutputTokens
  495. sess.PromptTokens += usage.InputTokens
  496. _, err = a.sessions.Save(ctx, sess)
  497. if err != nil {
  498. return fmt.Errorf("failed to save session: %w", err)
  499. }
  500. return nil
  501. }
  502. func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
  503. if a.IsBusy() {
  504. return models.Model{}, fmt.Errorf("cannot change model while processing requests")
  505. }
  506. if err := config.UpdateAgentModel(agentName, modelID); err != nil {
  507. return models.Model{}, fmt.Errorf("failed to update config: %w", err)
  508. }
  509. provider, err := createAgentProvider(agentName)
  510. if err != nil {
  511. return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
  512. }
  513. a.provider = provider
  514. return a.provider.Model(), nil
  515. }
  516. // PauseSession pauses message processing for a specific session
  517. // This should be called before performing operations that require exclusive access
  518. func (a *agent) PauseSession(sessionID string) error {
  519. if !a.IsSessionBusy(sessionID) {
  520. return nil // Session is not active, no need to pause
  521. }
  522. logging.InfoPersist(fmt.Sprintf("Pausing session: %s", sessionID))
  523. a.pauseLock.Lock() // Acquire write lock to block new operations
  524. return nil
  525. }
  526. // ResumeSession resumes message processing for a session
  527. // This should be called after completing operations that required exclusive access
  528. func (a *agent) ResumeSession(sessionID string) error {
  529. logging.InfoPersist(fmt.Sprintf("Resuming session: %s", sessionID))
  530. a.pauseLock.Unlock() // Release write lock to allow operations to continue
  531. return nil
  532. }
  533. func (a *agent) CompactSession(ctx context.Context, sessionID string) error {
  534. // Check if the session is busy
  535. if a.IsSessionBusy(sessionID) {
  536. // Pause the session before compaction
  537. if err := a.PauseSession(sessionID); err != nil {
  538. return fmt.Errorf("failed to pause session: %w", err)
  539. }
  540. // Make sure to resume the session when we're done
  541. defer a.ResumeSession(sessionID)
  542. logging.InfoPersist(fmt.Sprintf("Session %s paused for compaction", sessionID))
  543. }
  544. // Create a cancellable context
  545. ctx, cancel := context.WithCancel(ctx)
  546. defer cancel()
  547. // Mark the session as busy during compaction
  548. compactionCancelFunc := func() {}
  549. a.activeRequests.Store(sessionID+"-compact", compactionCancelFunc)
  550. defer a.activeRequests.Delete(sessionID + "-compact")
  551. // Fetch the session
  552. session, err := a.sessions.Get(ctx, sessionID)
  553. if err != nil {
  554. return fmt.Errorf("failed to get session: %w", err)
  555. }
  556. // Fetch all messages for the session
  557. sessionMessages, err := a.messages.List(ctx, sessionID)
  558. if err != nil {
  559. return fmt.Errorf("failed to list messages: %w", err)
  560. }
  561. var existingSummary string
  562. if session.Summary != "" && session.SummarizedAt > 0 {
  563. // Filter messages that were created after the last summarization
  564. var newMessages []message.Message
  565. for _, msg := range sessionMessages {
  566. if msg.CreatedAt > session.SummarizedAt {
  567. newMessages = append(newMessages, msg)
  568. }
  569. }
  570. sessionMessages = newMessages
  571. existingSummary = session.Summary
  572. }
  573. // If there are no messages to summarize and no existing summary, return early
  574. if len(sessionMessages) == 0 && existingSummary == "" {
  575. return nil
  576. }
  577. messages := []message.Message{
  578. message.Message{
  579. Role: message.System,
  580. Parts: []message.ContentPart{
  581. message.TextContent{
  582. Text: "You are a helpful AI assistant tasked with summarizing conversations.",
  583. },
  584. },
  585. },
  586. }
  587. // If there's an existing summary, include it
  588. if existingSummary != "" {
  589. messages = append(messages, message.Message{
  590. Role: message.Assistant, // TODO: should this be system or user instead?
  591. Parts: []message.ContentPart{
  592. message.TextContent{
  593. Text: existingSummary,
  594. },
  595. },
  596. })
  597. }
  598. // Add all messages since the last summarized message
  599. messages = append(messages, sessionMessages...)
  600. // Add a final user message requesting the summary
  601. messages = append(messages, message.Message{
  602. Role: message.User,
  603. Parts: []message.ContentPart{
  604. message.TextContent{
  605. Text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
  606. },
  607. },
  608. })
  609. // Call provider to get the summary
  610. response, err := a.provider.SendMessages(ctx, messages, a.tools)
  611. if err != nil {
  612. return fmt.Errorf("failed to get summary from the assistant: %w", err)
  613. }
  614. // Extract the summary text
  615. summaryText := strings.TrimSpace(response.Content)
  616. if summaryText == "" {
  617. return fmt.Errorf("received empty summary from the assistant")
  618. }
  619. // Update the session with the new summary
  620. currentTime := time.Now().UnixMilli()
  621. session.Summary = summaryText
  622. session.SummarizedAt = currentTime
  623. // Save the updated session
  624. _, err = a.sessions.Save(ctx, session)
  625. if err != nil {
  626. return fmt.Errorf("failed to save session with summary: %w", err)
  627. }
  628. // Track token usage
  629. err = a.TrackUsage(ctx, sessionID, a.provider.Model(), response.Usage)
  630. if err != nil {
  631. return fmt.Errorf("failed to track usage: %w", err)
  632. }
  633. return nil
  634. }
  635. func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
  636. cfg := config.Get()
  637. agentConfig, ok := cfg.Agents[agentName]
  638. if !ok {
  639. return nil, fmt.Errorf("agent %s not found", agentName)
  640. }
  641. model, ok := models.SupportedModels[agentConfig.Model]
  642. if !ok {
  643. return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
  644. }
  645. providerCfg, ok := cfg.Providers[model.Provider]
  646. if !ok {
  647. return nil, fmt.Errorf("provider %s not supported", model.Provider)
  648. }
  649. if providerCfg.Disabled {
  650. return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
  651. }
  652. maxTokens := model.DefaultMaxTokens
  653. if agentConfig.MaxTokens > 0 {
  654. maxTokens = agentConfig.MaxTokens
  655. }
  656. opts := []provider.ProviderClientOption{
  657. provider.WithAPIKey(providerCfg.APIKey),
  658. provider.WithModel(model),
  659. provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
  660. provider.WithMaxTokens(maxTokens),
  661. }
  662. if model.Provider == models.ProviderOpenAI && model.CanReason {
  663. opts = append(
  664. opts,
  665. provider.WithOpenAIOptions(
  666. provider.WithReasoningEffort(agentConfig.ReasoningEffort),
  667. ),
  668. )
  669. } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
  670. opts = append(
  671. opts,
  672. provider.WithAnthropicOptions(
  673. provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
  674. ),
  675. )
  676. }
  677. agentProvider, err := provider.NewProvider(
  678. model.Provider,
  679. opts...,
  680. )
  681. if err != nil {
  682. return nil, fmt.Errorf("could not create provider: %v", err)
  683. }
  684. return agentProvider, nil
  685. }