agent.go 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757
  1. package agent
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "sync"
  8. "time"
  9. "github.com/opencode-ai/opencode/internal/config"
  10. "github.com/opencode-ai/opencode/internal/llm/models"
  11. "github.com/opencode-ai/opencode/internal/llm/prompt"
  12. "github.com/opencode-ai/opencode/internal/llm/provider"
  13. "github.com/opencode-ai/opencode/internal/llm/tools"
  14. "github.com/opencode-ai/opencode/internal/logging"
  15. "github.com/opencode-ai/opencode/internal/message"
  16. "github.com/opencode-ai/opencode/internal/permission"
  17. "github.com/opencode-ai/opencode/internal/session"
  18. )
  19. // Common errors
  20. var (
  21. ErrRequestCancelled = errors.New("request cancelled by user")
  22. ErrSessionBusy = errors.New("session is currently processing another request")
  23. )
  24. type AgentEvent struct {
  25. message message.Message
  26. err error
  27. }
  28. func (e *AgentEvent) Err() error {
  29. return e.err
  30. }
  31. func (e *AgentEvent) Response() message.Message {
  32. return e.message
  33. }
  34. type Service interface {
  35. Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  36. Cancel(sessionID string)
  37. IsSessionBusy(sessionID string) bool
  38. IsBusy() bool
  39. Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
  40. CompactSession(ctx context.Context, sessionID string) error
  41. PauseSession(sessionID string) error
  42. ResumeSession(sessionID string) error
  43. }
  44. type agent struct {
  45. sessions session.Service
  46. messages message.Service
  47. tools []tools.BaseTool
  48. provider provider.Provider
  49. titleProvider provider.Provider
  50. activeRequests sync.Map
  51. pauseLock sync.RWMutex // Lock for pausing message processing
  52. }
  53. func NewAgent(
  54. agentName config.AgentName,
  55. sessions session.Service,
  56. messages message.Service,
  57. agentTools []tools.BaseTool,
  58. ) (Service, error) {
  59. agentProvider, err := createAgentProvider(agentName)
  60. if err != nil {
  61. return nil, err
  62. }
  63. var titleProvider provider.Provider
  64. // Only generate titles for the coder agent
  65. if agentName == config.AgentCoder {
  66. titleProvider, err = createAgentProvider(config.AgentTitle)
  67. if err != nil {
  68. return nil, err
  69. }
  70. }
  71. agent := &agent{
  72. provider: agentProvider,
  73. messages: messages,
  74. sessions: sessions,
  75. tools: agentTools,
  76. titleProvider: titleProvider,
  77. activeRequests: sync.Map{},
  78. }
  79. return agent, nil
  80. }
  81. func (a *agent) Cancel(sessionID string) {
  82. if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
  83. if cancel, ok := cancelFunc.(context.CancelFunc); ok {
  84. logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
  85. cancel()
  86. }
  87. }
  88. }
  89. func (a *agent) IsBusy() bool {
  90. busy := false
  91. a.activeRequests.Range(func(key, value interface{}) bool {
  92. if cancelFunc, ok := value.(context.CancelFunc); ok {
  93. if cancelFunc != nil {
  94. busy = true
  95. return false // Stop iterating
  96. }
  97. }
  98. return true // Continue iterating
  99. })
  100. return busy
  101. }
  102. func (a *agent) IsSessionBusy(sessionID string) bool {
  103. _, busy := a.activeRequests.Load(sessionID)
  104. return busy
  105. }
  106. func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
  107. if content == "" {
  108. return nil
  109. }
  110. if a.titleProvider == nil {
  111. return nil
  112. }
  113. session, err := a.sessions.Get(ctx, sessionID)
  114. if err != nil {
  115. return err
  116. }
  117. parts := []message.ContentPart{message.TextContent{Text: content}}
  118. response, err := a.titleProvider.SendMessages(
  119. ctx,
  120. []message.Message{
  121. {
  122. Role: message.User,
  123. Parts: parts,
  124. },
  125. },
  126. make([]tools.BaseTool, 0),
  127. )
  128. if err != nil {
  129. return err
  130. }
  131. title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
  132. if title == "" {
  133. return nil
  134. }
  135. session.Title = title
  136. _, err = a.sessions.Save(ctx, session)
  137. return err
  138. }
  139. func (a *agent) err(err error) AgentEvent {
  140. return AgentEvent{
  141. err: err,
  142. }
  143. }
  144. func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
  145. if !a.provider.Model().SupportsAttachments && attachments != nil {
  146. attachments = nil
  147. }
  148. events := make(chan AgentEvent)
  149. if a.IsSessionBusy(sessionID) {
  150. return nil, ErrSessionBusy
  151. }
  152. genCtx, cancel := context.WithCancel(ctx)
  153. a.activeRequests.Store(sessionID, cancel)
  154. go func() {
  155. logging.Debug("Request started", "sessionID", sessionID)
  156. defer logging.RecoverPanic("agent.Run", func() {
  157. events <- a.err(fmt.Errorf("panic while running the agent"))
  158. })
  159. var attachmentParts []message.ContentPart
  160. for _, attachment := range attachments {
  161. attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
  162. }
  163. result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
  164. if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
  165. logging.ErrorPersist(result.Err().Error())
  166. }
  167. logging.Debug("Request completed", "sessionID", sessionID)
  168. a.activeRequests.Delete(sessionID)
  169. cancel()
  170. events <- result
  171. close(events)
  172. }()
  173. return events, nil
  174. }
  175. func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
  176. // Get the current session to check for summary
  177. currentSession, err := a.sessions.Get(ctx, sessionID)
  178. if err != nil {
  179. return a.err(fmt.Errorf("failed to get session: %w", err))
  180. }
  181. // Fetch messages based on whether a summary exists
  182. var sessionMessages []message.Message
  183. if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
  184. // If summary exists, only fetch messages after the summarization timestamp
  185. sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
  186. if err != nil {
  187. return a.err(fmt.Errorf("failed to list messages after summary: %w", err))
  188. }
  189. } else {
  190. // If no summary, fetch all messages
  191. sessionMessages, err = a.messages.List(ctx, sessionID)
  192. if err != nil {
  193. return a.err(fmt.Errorf("failed to list messages: %w", err))
  194. }
  195. }
  196. // If this is a new session, start title generation asynchronously
  197. if len(sessionMessages) == 0 && currentSession.Summary == "" {
  198. go func() {
  199. defer logging.RecoverPanic("agent.Run", func() {
  200. logging.ErrorPersist("panic while generating title")
  201. })
  202. titleErr := a.generateTitle(context.Background(), sessionID, content)
  203. if titleErr != nil {
  204. logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
  205. }
  206. }()
  207. }
  208. userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
  209. if err != nil {
  210. return a.err(fmt.Errorf("failed to create user message: %w", err))
  211. }
  212. // Prepare the message history for the LLM
  213. var messages []message.Message
  214. if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
  215. // If summary exists, create a temporary message for the summary
  216. summaryMessage := message.Message{
  217. Role: message.Assistant,
  218. Parts: []message.ContentPart{
  219. message.TextContent{Text: currentSession.Summary},
  220. },
  221. }
  222. // Start with the summary, then add messages after the summary timestamp
  223. messages = append([]message.Message{summaryMessage}, sessionMessages...)
  224. } else {
  225. // If no summary, just use all messages
  226. messages = sessionMessages
  227. }
  228. // Append the new user message to the conversation history
  229. messages = append(messages, userMsg)
  230. for {
  231. // Check for cancellation before each iteration
  232. select {
  233. case <-ctx.Done():
  234. return a.err(ctx.Err())
  235. default:
  236. // Continue processing
  237. }
  238. agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, messages)
  239. if err != nil {
  240. if errors.Is(err, context.Canceled) {
  241. agentMessage.AddFinish(message.FinishReasonCanceled)
  242. a.messages.Update(context.Background(), agentMessage)
  243. return a.err(ErrRequestCancelled)
  244. }
  245. return a.err(fmt.Errorf("failed to process events: %w", err))
  246. }
  247. logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
  248. if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
  249. // We are not done, we need to respond with the tool response
  250. messages = append(messages, agentMessage, *toolResults)
  251. continue
  252. }
  253. return AgentEvent{
  254. message: agentMessage,
  255. }
  256. }
  257. }
  258. func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
  259. parts := []message.ContentPart{message.TextContent{Text: content}}
  260. parts = append(parts, attachmentParts...)
  261. return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  262. Role: message.User,
  263. Parts: parts,
  264. })
  265. }
  266. func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
  267. // Check if we need to auto-compact based on token count
  268. contextWindow := a.provider.Model().ContextWindow
  269. maxTokens := a.provider.MaxTokens()
  270. threshold := int64(float64(contextWindow) * 0.80)
  271. usage, err := a.GetUsage(ctx, sessionID)
  272. if err != nil || usage == nil {
  273. return message.Message{}, nil, fmt.Errorf("failed to get usage: %w", err)
  274. }
  275. // If we're approaching the context window limit, trigger auto-compaction
  276. if (*usage + maxTokens) >= threshold {
  277. logging.InfoPersist(fmt.Sprintf("Auto-compaction triggered for session %s. Estimated tokens: %d, Threshold: %d", sessionID, usage, threshold))
  278. // Perform compaction with pause/resume to ensure safety
  279. if err := a.CompactSession(ctx, sessionID); err != nil {
  280. logging.ErrorPersist(fmt.Sprintf("Auto-compaction failed: %v", err))
  281. // Continue with the request even if compaction fails
  282. } else {
  283. // Re-fetch session details after compaction
  284. currentSession, err := a.sessions.Get(ctx, sessionID)
  285. if err != nil {
  286. return message.Message{}, nil, fmt.Errorf("failed to get session after compaction: %w", err)
  287. }
  288. // Re-prepare messages using the new summary
  289. var sessionMessages []message.Message
  290. if currentSession.Summary != "" && currentSession.SummarizedAt > 0 {
  291. // If summary exists, only fetch messages after the summarization timestamp
  292. sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
  293. if err != nil {
  294. return message.Message{}, nil, fmt.Errorf("failed to list messages after compaction: %w", err)
  295. }
  296. // Create a new message history with the summary and messages after summarization
  297. summaryMessage := message.Message{
  298. Role: message.Assistant,
  299. Parts: []message.ContentPart{
  300. message.TextContent{Text: currentSession.Summary},
  301. },
  302. }
  303. // Replace msgHistory with the new compacted version
  304. msgHistory = append([]message.Message{summaryMessage}, sessionMessages...)
  305. }
  306. }
  307. }
  308. eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
  309. assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  310. Role: message.Assistant,
  311. Parts: []message.ContentPart{},
  312. Model: a.provider.Model().ID,
  313. })
  314. if err != nil {
  315. return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
  316. }
  317. // Add the session and message ID into the context if needed by tools.
  318. ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
  319. ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
  320. // Process each event in the stream.
  321. for event := range eventChan {
  322. if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
  323. a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
  324. return assistantMsg, nil, processErr
  325. }
  326. if ctx.Err() != nil {
  327. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
  328. return assistantMsg, nil, ctx.Err()
  329. }
  330. }
  331. toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
  332. toolCalls := assistantMsg.ToolCalls()
  333. for i, toolCall := range toolCalls {
  334. select {
  335. case <-ctx.Done():
  336. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
  337. // Make all future tool calls cancelled
  338. for j := i; j < len(toolCalls); j++ {
  339. toolResults[j] = message.ToolResult{
  340. ToolCallID: toolCalls[j].ID,
  341. Content: "Tool execution canceled by user",
  342. IsError: true,
  343. }
  344. }
  345. goto out
  346. default:
  347. // Continue processing
  348. var tool tools.BaseTool
  349. for _, availableTools := range a.tools {
  350. if availableTools.Info().Name == toolCall.Name {
  351. tool = availableTools
  352. }
  353. }
  354. // Tool not found
  355. if tool == nil {
  356. toolResults[i] = message.ToolResult{
  357. ToolCallID: toolCall.ID,
  358. Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
  359. IsError: true,
  360. }
  361. continue
  362. }
  363. toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
  364. ID: toolCall.ID,
  365. Name: toolCall.Name,
  366. Input: toolCall.Input,
  367. })
  368. if toolErr != nil {
  369. if errors.Is(toolErr, permission.ErrorPermissionDenied) {
  370. toolResults[i] = message.ToolResult{
  371. ToolCallID: toolCall.ID,
  372. Content: "Permission denied",
  373. IsError: true,
  374. }
  375. for j := i + 1; j < len(toolCalls); j++ {
  376. toolResults[j] = message.ToolResult{
  377. ToolCallID: toolCalls[j].ID,
  378. Content: "Tool execution canceled by user",
  379. IsError: true,
  380. }
  381. }
  382. a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
  383. break
  384. }
  385. }
  386. toolResults[i] = message.ToolResult{
  387. ToolCallID: toolCall.ID,
  388. Content: toolResult.Content,
  389. Metadata: toolResult.Metadata,
  390. IsError: toolResult.IsError,
  391. }
  392. }
  393. }
  394. out:
  395. if len(toolResults) == 0 {
  396. return assistantMsg, nil, nil
  397. }
  398. parts := make([]message.ContentPart, 0)
  399. for _, tr := range toolResults {
  400. parts = append(parts, tr)
  401. }
  402. msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
  403. Role: message.Tool,
  404. Parts: parts,
  405. })
  406. if err != nil {
  407. return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
  408. }
  409. return assistantMsg, &msg, err
  410. }
  411. func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
  412. msg.AddFinish(finishReson)
  413. _ = a.messages.Update(ctx, *msg)
  414. }
  415. func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
  416. select {
  417. case <-ctx.Done():
  418. return ctx.Err()
  419. default:
  420. // Continue processing.
  421. }
  422. // Check if session is paused - use RLock to allow concurrent reads but block during pause
  423. a.pauseLock.RLock()
  424. defer a.pauseLock.RUnlock()
  425. switch event.Type {
  426. case provider.EventThinkingDelta:
  427. assistantMsg.AppendReasoningContent(event.Content)
  428. return a.messages.Update(ctx, *assistantMsg)
  429. case provider.EventContentDelta:
  430. assistantMsg.AppendContent(event.Content)
  431. return a.messages.Update(ctx, *assistantMsg)
  432. case provider.EventToolUseStart:
  433. assistantMsg.AddToolCall(*event.ToolCall)
  434. return a.messages.Update(ctx, *assistantMsg)
  435. // TODO: see how to handle this
  436. // case provider.EventToolUseDelta:
  437. // tm := time.Unix(assistantMsg.UpdatedAt, 0)
  438. // assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
  439. // if time.Since(tm) > 1000*time.Millisecond {
  440. // err := a.messages.Update(ctx, *assistantMsg)
  441. // assistantMsg.UpdatedAt = time.Now().Unix()
  442. // return err
  443. // }
  444. case provider.EventToolUseStop:
  445. assistantMsg.FinishToolCall(event.ToolCall.ID)
  446. return a.messages.Update(ctx, *assistantMsg)
  447. case provider.EventError:
  448. if errors.Is(event.Error, context.Canceled) {
  449. logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
  450. return context.Canceled
  451. }
  452. logging.ErrorPersist(event.Error.Error())
  453. return event.Error
  454. case provider.EventComplete:
  455. assistantMsg.SetToolCalls(event.Response.ToolCalls)
  456. assistantMsg.AddFinish(event.Response.FinishReason)
  457. if err := a.messages.Update(ctx, *assistantMsg); err != nil {
  458. return fmt.Errorf("failed to update message: %w", err)
  459. }
  460. return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
  461. }
  462. return nil
  463. }
  464. func (a *agent) GetUsage(ctx context.Context, sessionID string) (*int64, error) {
  465. session, err := a.sessions.Get(ctx, sessionID)
  466. if err != nil {
  467. return nil, fmt.Errorf("failed to get session: %w", err)
  468. }
  469. usage := session.PromptTokens + session.CompletionTokens
  470. return &usage, nil
  471. }
  472. func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
  473. sess, err := a.sessions.Get(ctx, sessionID)
  474. if err != nil {
  475. return fmt.Errorf("failed to get session: %w", err)
  476. }
  477. cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  478. model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  479. model.CostPer1MIn/1e6*float64(usage.InputTokens) +
  480. model.CostPer1MOut/1e6*float64(usage.OutputTokens)
  481. sess.Cost += cost
  482. sess.CompletionTokens += usage.OutputTokens
  483. sess.PromptTokens += usage.InputTokens
  484. _, err = a.sessions.Save(ctx, sess)
  485. if err != nil {
  486. return fmt.Errorf("failed to save session: %w", err)
  487. }
  488. return nil
  489. }
  490. func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
  491. if a.IsBusy() {
  492. return models.Model{}, fmt.Errorf("cannot change model while processing requests")
  493. }
  494. if err := config.UpdateAgentModel(agentName, modelID); err != nil {
  495. return models.Model{}, fmt.Errorf("failed to update config: %w", err)
  496. }
  497. provider, err := createAgentProvider(agentName)
  498. if err != nil {
  499. return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
  500. }
  501. a.provider = provider
  502. return a.provider.Model(), nil
  503. }
  504. // PauseSession pauses message processing for a specific session
  505. // This should be called before performing operations that require exclusive access
  506. func (a *agent) PauseSession(sessionID string) error {
  507. if !a.IsSessionBusy(sessionID) {
  508. return nil // Session is not active, no need to pause
  509. }
  510. logging.InfoPersist(fmt.Sprintf("Pausing session: %s", sessionID))
  511. a.pauseLock.Lock() // Acquire write lock to block new operations
  512. return nil
  513. }
  514. // ResumeSession resumes message processing for a session
  515. // This should be called after completing operations that required exclusive access
  516. func (a *agent) ResumeSession(sessionID string) error {
  517. logging.InfoPersist(fmt.Sprintf("Resuming session: %s", sessionID))
  518. a.pauseLock.Unlock() // Release write lock to allow operations to continue
  519. return nil
  520. }
  521. func (a *agent) CompactSession(ctx context.Context, sessionID string) error {
  522. // Check if the session is busy
  523. if a.IsSessionBusy(sessionID) {
  524. // Pause the session before compaction
  525. if err := a.PauseSession(sessionID); err != nil {
  526. return fmt.Errorf("failed to pause session: %w", err)
  527. }
  528. // Make sure to resume the session when we're done
  529. defer a.ResumeSession(sessionID)
  530. logging.InfoPersist(fmt.Sprintf("Session %s paused for compaction", sessionID))
  531. }
  532. // Create a cancellable context
  533. ctx, cancel := context.WithCancel(ctx)
  534. defer cancel()
  535. // Mark the session as busy during compaction
  536. compactionCancelFunc := func() {}
  537. a.activeRequests.Store(sessionID+"-compact", compactionCancelFunc)
  538. defer a.activeRequests.Delete(sessionID + "-compact")
  539. // Fetch the session
  540. session, err := a.sessions.Get(ctx, sessionID)
  541. if err != nil {
  542. return fmt.Errorf("failed to get session: %w", err)
  543. }
  544. // Fetch all messages for the session
  545. sessionMessages, err := a.messages.List(ctx, sessionID)
  546. if err != nil {
  547. return fmt.Errorf("failed to list messages: %w", err)
  548. }
  549. var existingSummary string
  550. if session.Summary != "" && session.SummarizedAt > 0 {
  551. // Filter messages that were created after the last summarization
  552. var newMessages []message.Message
  553. for _, msg := range sessionMessages {
  554. if msg.CreatedAt > session.SummarizedAt {
  555. newMessages = append(newMessages, msg)
  556. }
  557. }
  558. sessionMessages = newMessages
  559. existingSummary = session.Summary
  560. }
  561. // If there are no messages to summarize and no existing summary, return early
  562. if len(sessionMessages) == 0 && existingSummary == "" {
  563. return nil
  564. }
  565. messages := []message.Message{
  566. message.Message{
  567. Role: message.System,
  568. Parts: []message.ContentPart{
  569. message.TextContent{
  570. Text: "You are a helpful AI assistant tasked with summarizing conversations.",
  571. },
  572. },
  573. },
  574. }
  575. // If there's an existing summary, include it
  576. if existingSummary != "" {
  577. messages = append(messages, message.Message{
  578. Role: message.Assistant, // TODO: should this be system or user instead?
  579. Parts: []message.ContentPart{
  580. message.TextContent{
  581. Text: existingSummary,
  582. },
  583. },
  584. })
  585. }
  586. // Add all messages since the last summarized message
  587. messages = append(messages, sessionMessages...)
  588. // Add a final user message requesting the summary
  589. messages = append(messages, message.Message{
  590. Role: message.User,
  591. Parts: []message.ContentPart{
  592. message.TextContent{
  593. Text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
  594. },
  595. },
  596. })
  597. // Call provider to get the summary
  598. response, err := a.provider.SendMessages(ctx, messages, a.tools)
  599. if err != nil {
  600. return fmt.Errorf("failed to get summary from the assistant: %w", err)
  601. }
  602. // Extract the summary text
  603. summaryText := strings.TrimSpace(response.Content)
  604. if summaryText == "" {
  605. return fmt.Errorf("received empty summary from the assistant")
  606. }
  607. // Update the session with the new summary
  608. currentTime := time.Now().UnixMilli()
  609. session.Summary = summaryText
  610. session.SummarizedAt = currentTime
  611. // Save the updated session
  612. _, err = a.sessions.Save(ctx, session)
  613. if err != nil {
  614. return fmt.Errorf("failed to save session with summary: %w", err)
  615. }
  616. // Track token usage
  617. err = a.TrackUsage(ctx, sessionID, a.provider.Model(), response.Usage)
  618. if err != nil {
  619. return fmt.Errorf("failed to track usage: %w", err)
  620. }
  621. return nil
  622. }
  623. func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
  624. cfg := config.Get()
  625. agentConfig, ok := cfg.Agents[agentName]
  626. if !ok {
  627. return nil, fmt.Errorf("agent %s not found", agentName)
  628. }
  629. model, ok := models.SupportedModels[agentConfig.Model]
  630. if !ok {
  631. return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
  632. }
  633. providerCfg, ok := cfg.Providers[model.Provider]
  634. if !ok {
  635. return nil, fmt.Errorf("provider %s not supported", model.Provider)
  636. }
  637. if providerCfg.Disabled {
  638. return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
  639. }
  640. maxTokens := model.DefaultMaxTokens
  641. if agentConfig.MaxTokens > 0 {
  642. maxTokens = agentConfig.MaxTokens
  643. }
  644. opts := []provider.ProviderClientOption{
  645. provider.WithAPIKey(providerCfg.APIKey),
  646. provider.WithModel(model),
  647. provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
  648. provider.WithMaxTokens(maxTokens),
  649. }
  650. if model.Provider == models.ProviderOpenAI && model.CanReason {
  651. opts = append(
  652. opts,
  653. provider.WithOpenAIOptions(
  654. provider.WithReasoningEffort(agentConfig.ReasoningEffort),
  655. ),
  656. )
  657. } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
  658. opts = append(
  659. opts,
  660. provider.WithAnthropicOptions(
  661. provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
  662. ),
  663. )
  664. }
  665. agentProvider, err := provider.NewProvider(
  666. model.Provider,
  667. opts...,
  668. )
  669. if err != nil {
  670. return nil, fmt.Errorf("could not create provider: %v", err)
  671. }
  672. return agentProvider, nil
  673. }