agent.go 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804
  1. package agent
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "log/slog"
  7. "strings"
  8. "sync"
  9. "time"
  10. "github.com/sst/opencode/internal/config"
  11. "github.com/sst/opencode/internal/llm/models"
  12. "github.com/sst/opencode/internal/llm/prompt"
  13. "github.com/sst/opencode/internal/llm/provider"
  14. "github.com/sst/opencode/internal/llm/tools"
  15. "github.com/sst/opencode/internal/logging"
  16. "github.com/sst/opencode/internal/message"
  17. "github.com/sst/opencode/internal/permission"
  18. "github.com/sst/opencode/internal/session"
  19. "github.com/sst/opencode/internal/status"
  20. )
  21. // Common errors
  22. var (
  23. ErrRequestCancelled = errors.New("request cancelled by user")
  24. ErrSessionBusy = errors.New("session is currently processing another request")
  25. )
  26. type AgentEvent struct {
  27. message message.Message
  28. err error
  29. }
  30. func (e *AgentEvent) Err() error {
  31. return e.err
  32. }
  33. func (e *AgentEvent) Response() message.Message {
  34. return e.message
  35. }
  36. type Service interface {
  37. Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  38. Cancel(sessionID string)
  39. IsSessionBusy(sessionID string) bool
  40. IsBusy() bool
  41. Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
  42. CompactSession(ctx context.Context, sessionID string, force bool) error
  43. GetUsage(ctx context.Context, sessionID string) (*int64, error)
  44. EstimateContextWindowUsage(ctx context.Context, sessionID string) (float64, bool, error)
  45. }
  46. type agent struct {
  47. sessions session.Service
  48. messages message.Service
  49. tools []tools.BaseTool
  50. provider provider.Provider
  51. titleProvider provider.Provider
  52. activeRequests sync.Map
  53. }
  54. func NewAgent(
  55. agentName config.AgentName,
  56. sessions session.Service,
  57. messages message.Service,
  58. agentTools []tools.BaseTool,
  59. ) (Service, error) {
  60. agentProvider, err := createAgentProvider(agentName)
  61. if err != nil {
  62. return nil, err
  63. }
  64. var titleProvider provider.Provider
  65. // Only generate titles for the primary agent
  66. if agentName == config.AgentPrimary {
  67. titleProvider, err = createAgentProvider(config.AgentTitle)
  68. if err != nil {
  69. return nil, err
  70. }
  71. }
  72. agent := &agent{
  73. provider: agentProvider,
  74. messages: messages,
  75. sessions: sessions,
  76. tools: agentTools,
  77. titleProvider: titleProvider,
  78. activeRequests: sync.Map{},
  79. }
  80. return agent, nil
  81. }
  82. func (a *agent) Cancel(sessionID string) {
  83. if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
  84. if cancel, ok := cancelFunc.(context.CancelFunc); ok {
  85. status.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
  86. cancel()
  87. }
  88. }
  89. }
  90. func (a *agent) IsBusy() bool {
  91. busy := false
  92. a.activeRequests.Range(func(key, value interface{}) bool {
  93. if cancelFunc, ok := value.(context.CancelFunc); ok {
  94. if cancelFunc != nil {
  95. busy = true
  96. return false // Stop iterating
  97. }
  98. }
  99. return true // Continue iterating
  100. })
  101. return busy
  102. }
  103. func (a *agent) IsSessionBusy(sessionID string) bool {
  104. _, busy := a.activeRequests.Load(sessionID)
  105. return busy
  106. }
  107. func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
  108. if content == "" {
  109. return nil
  110. }
  111. if a.titleProvider == nil {
  112. return nil
  113. }
  114. session, err := a.sessions.Get(ctx, sessionID)
  115. if err != nil {
  116. return err
  117. }
  118. parts := []message.ContentPart{message.TextContent{Text: content}}
  119. response, err := a.titleProvider.SendMessages(
  120. ctx,
  121. []message.Message{
  122. {
  123. Role: message.User,
  124. Parts: parts,
  125. },
  126. },
  127. make([]tools.BaseTool, 0),
  128. )
  129. if err != nil {
  130. return err
  131. }
  132. title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
  133. if title == "" {
  134. return nil
  135. }
  136. session.Title = title
  137. _, err = a.sessions.Update(ctx, session)
  138. return err
  139. }
  140. func (a *agent) err(err error) AgentEvent {
  141. return AgentEvent{
  142. err: err,
  143. }
  144. }
  145. func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
  146. if !a.provider.Model().SupportsAttachments && attachments != nil {
  147. attachments = nil
  148. }
  149. events := make(chan AgentEvent)
  150. if a.IsSessionBusy(sessionID) {
  151. return nil, ErrSessionBusy
  152. }
  153. genCtx, cancel := context.WithCancel(ctx)
  154. a.activeRequests.Store(sessionID, cancel)
  155. go func() {
  156. slog.Debug("Request started", "sessionID", sessionID)
  157. defer logging.RecoverPanic("agent.Run", func() {
  158. events <- a.err(fmt.Errorf("panic while running the agent"))
  159. })
  160. var attachmentParts []message.ContentPart
  161. for _, attachment := range attachments {
  162. attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
  163. }
  164. result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
  165. if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
  166. status.Error(result.Err().Error())
  167. }
  168. slog.Debug("Request completed", "sessionID", sessionID)
  169. a.activeRequests.Delete(sessionID)
  170. cancel()
  171. events <- result
  172. close(events)
  173. }()
  174. return events, nil
  175. }
  176. func (a *agent) prepareMessageHistory(ctx context.Context, sessionID string) (session.Session, []message.Message, error) {
  177. currentSession, err := a.sessions.Get(ctx, sessionID)
  178. if err != nil {
  179. return currentSession, nil, fmt.Errorf("failed to get session: %w", err)
  180. }
  181. var sessionMessages []message.Message
  182. if currentSession.Summary != "" && !currentSession.SummarizedAt.IsZero() {
  183. // If summary exists, only fetch messages after the summarization timestamp
  184. sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
  185. if err != nil {
  186. return currentSession, nil, fmt.Errorf("failed to list messages after summary: %w", err)
  187. }
  188. } else {
  189. // If no summary, fetch all messages
  190. sessionMessages, err = a.messages.List(ctx, sessionID)
  191. if err != nil {
  192. return currentSession, nil, fmt.Errorf("failed to list messages: %w", err)
  193. }
  194. }
  195. var messages []message.Message
  196. if currentSession.Summary != "" && !currentSession.SummarizedAt.IsZero() {
  197. // If summary exists, create a temporary message for the summary
  198. summaryMessage := message.Message{
  199. Role: message.Assistant,
  200. Parts: []message.ContentPart{
  201. message.TextContent{Text: currentSession.Summary},
  202. },
  203. }
  204. // Start with the summary, then add messages after the summary timestamp
  205. messages = append([]message.Message{summaryMessage}, sessionMessages...)
  206. } else {
  207. // If no summary, just use all messages
  208. messages = sessionMessages
  209. }
  210. return currentSession, messages, nil
  211. }
  212. func (a *agent) triggerTitleGeneration(sessionID string, content string) {
  213. go func() {
  214. defer logging.RecoverPanic("agent.Run", func() {
  215. status.Error("panic while generating title")
  216. })
  217. titleErr := a.generateTitle(context.Background(), sessionID, content)
  218. if titleErr != nil {
  219. status.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
  220. }
  221. }()
  222. }
  223. func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
  224. currentSession, sessionMessages, err := a.prepareMessageHistory(ctx, sessionID)
  225. if err != nil {
  226. return a.err(err)
  227. }
  228. // If this is a new session, start title generation asynchronously
  229. if len(sessionMessages) == 0 && currentSession.Summary == "" {
  230. a.triggerTitleGeneration(sessionID, content)
  231. }
  232. userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
  233. if err != nil {
  234. return a.err(fmt.Errorf("failed to create user message: %w", err))
  235. }
  236. messages := append(sessionMessages, userMsg)
  237. for {
  238. // Check for cancellation before each iteration
  239. select {
  240. case <-ctx.Done():
  241. return a.err(ctx.Err())
  242. default:
  243. // Continue processing
  244. }
  245. // Check if auto-compaction is needed before calling the provider
  246. usagePercentage, needsCompaction, errEstimate := a.EstimateContextWindowUsage(ctx, sessionID)
  247. if errEstimate != nil {
  248. slog.Warn("Failed to estimate context window usage for auto-compaction", "error", errEstimate, "sessionID", sessionID)
  249. } else if needsCompaction {
  250. status.Info(fmt.Sprintf("Context window usage is at %.2f%%. Auto-compacting conversation...", usagePercentage))
  251. // Run compaction synchronously
  252. compactCtx, cancelCompact := context.WithTimeout(ctx, 30*time.Second) // Use appropriate context
  253. errCompact := a.CompactSession(compactCtx, sessionID, true)
  254. cancelCompact()
  255. if errCompact != nil {
  256. status.Warn(fmt.Sprintf("Auto-compaction failed: %v. Context window usage may continue to grow.", errCompact))
  257. } else {
  258. status.Info("Auto-compaction completed successfully.")
  259. // After compaction, message history needs to be re-prepared.
  260. // The 'messages' slice needs to be updated with the new summary and subsequent messages,
  261. // ensuring the latest user message is correctly appended.
  262. _, sessionMessagesFromCompact, errPrepare := a.prepareMessageHistory(ctx, sessionID)
  263. if errPrepare != nil {
  264. return a.err(fmt.Errorf("failed to re-prepare message history after compaction: %w", errPrepare))
  265. }
  266. messages = sessionMessagesFromCompact
  267. // Ensure the user message that triggered this cycle is the last one.
  268. // 'userMsg' was created before this loop using a.createUserMessage.
  269. // It should be appended to the 'messages' slice if it's not already the last element.
  270. if len(messages) == 0 || (len(messages) > 0 && messages[len(messages)-1].ID != userMsg.ID) {
  271. messages = append(messages, userMsg)
  272. }
  273. }
  274. }
  275. agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, messages)
  276. if err != nil {
  277. if errors.Is(err, context.Canceled) {
  278. agentMessage.AddFinish(message.FinishReasonCanceled)
  279. a.messages.Update(context.Background(), agentMessage)
  280. return a.err(ErrRequestCancelled)
  281. }
  282. return a.err(fmt.Errorf("failed to process events: %w", err))
  283. }
  284. slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
  285. if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
  286. // We are not done, we need to respond with the tool response
  287. messages = append(messages, agentMessage, *toolResults)
  288. continue
  289. }
  290. return AgentEvent{
  291. message: agentMessage,
  292. }
  293. }
  294. }
  295. func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
  296. parts := []message.ContentPart{message.TextContent{Text: content}}
  297. parts = append(parts, attachmentParts...)
  298. return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  299. Role: message.User,
  300. Parts: parts,
  301. })
  302. }
  303. func (a *agent) createToolResponseMessage(ctx context.Context, sessionID string, toolResults []message.ToolResult) (*message.Message, error) {
  304. if len(toolResults) == 0 {
  305. return nil, nil
  306. }
  307. parts := make([]message.ContentPart, 0, len(toolResults))
  308. for _, tr := range toolResults {
  309. parts = append(parts, tr)
  310. }
  311. msg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  312. Role: message.Tool,
  313. Parts: parts,
  314. })
  315. if err != nil {
  316. return nil, fmt.Errorf("failed to create tool response message: %w", err)
  317. }
  318. return &msg, nil
  319. }
  320. func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
  321. eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
  322. assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  323. Role: message.Assistant,
  324. Parts: []message.ContentPart{},
  325. Model: a.provider.Model().ID,
  326. })
  327. if err != nil {
  328. return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
  329. }
  330. // Add the session and message ID into the context if needed by tools.
  331. ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
  332. ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
  333. // Process each event in the stream.
  334. for event := range eventChan {
  335. if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
  336. a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
  337. return assistantMsg, nil, processErr
  338. }
  339. if ctx.Err() != nil {
  340. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
  341. return assistantMsg, nil, ctx.Err()
  342. }
  343. }
  344. // If the assistant wants to use tools, execute them
  345. if assistantMsg.FinishReason() == message.FinishReasonToolUse {
  346. toolCalls := assistantMsg.ToolCalls()
  347. if len(toolCalls) > 0 {
  348. toolResults, err := a.executeToolCalls(ctx, toolCalls)
  349. if err != nil {
  350. if errors.Is(err, context.Canceled) {
  351. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
  352. }
  353. return assistantMsg, nil, err
  354. }
  355. // Create a message with the tool results
  356. toolResponseMsg, err := a.createToolResponseMessage(ctx, sessionID, toolResults)
  357. if err != nil {
  358. return assistantMsg, nil, err
  359. }
  360. return assistantMsg, toolResponseMsg, nil
  361. }
  362. }
  363. return assistantMsg, nil, nil
  364. }
  365. func (a *agent) executeToolCalls(ctx context.Context, toolCalls []message.ToolCall) ([]message.ToolResult, error) {
  366. toolResults := make([]message.ToolResult, len(toolCalls))
  367. for i, toolCall := range toolCalls {
  368. select {
  369. case <-ctx.Done():
  370. // Make all future tool calls cancelled
  371. for j := i; j < len(toolCalls); j++ {
  372. toolResults[j] = message.ToolResult{
  373. ToolCallID: toolCalls[j].ID,
  374. Content: "Tool execution canceled by user",
  375. IsError: true,
  376. }
  377. }
  378. return toolResults, ctx.Err()
  379. default:
  380. // Continue processing
  381. var tool tools.BaseTool
  382. for _, availableTools := range a.tools {
  383. if availableTools.Info().Name == toolCall.Name {
  384. tool = availableTools
  385. }
  386. }
  387. // Tool not found
  388. if tool == nil {
  389. toolResults[i] = message.ToolResult{
  390. ToolCallID: toolCall.ID,
  391. Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
  392. IsError: true,
  393. }
  394. continue
  395. }
  396. toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
  397. ID: toolCall.ID,
  398. Name: toolCall.Name,
  399. Input: toolCall.Input,
  400. })
  401. if toolErr != nil {
  402. if errors.Is(toolErr, permission.ErrorPermissionDenied) {
  403. toolResults[i] = message.ToolResult{
  404. ToolCallID: toolCall.ID,
  405. Content: "Permission denied",
  406. IsError: true,
  407. }
  408. // Cancel all remaining tool calls if permission is denied
  409. for j := i + 1; j < len(toolCalls); j++ {
  410. toolResults[j] = message.ToolResult{
  411. ToolCallID: toolCalls[j].ID,
  412. Content: "Tool execution canceled by user",
  413. IsError: true,
  414. }
  415. }
  416. return toolResults, nil
  417. }
  418. // Handle other errors
  419. toolResults[i] = message.ToolResult{
  420. ToolCallID: toolCall.ID,
  421. Content: toolErr.Error(),
  422. IsError: true,
  423. }
  424. continue
  425. }
  426. toolResults[i] = message.ToolResult{
  427. ToolCallID: toolCall.ID,
  428. Content: toolResult.Content,
  429. Metadata: toolResult.Metadata,
  430. IsError: toolResult.IsError,
  431. }
  432. }
  433. }
  434. return toolResults, nil
  435. }
  436. func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
  437. msg.AddFinish(finishReson)
  438. _, _ = a.messages.Update(ctx, *msg)
  439. }
  440. func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
  441. select {
  442. case <-ctx.Done():
  443. return ctx.Err()
  444. default:
  445. // Continue processing
  446. }
  447. switch event.Type {
  448. case provider.EventThinkingDelta:
  449. assistantMsg.AppendReasoningContent(event.Content)
  450. _, err := a.messages.Update(ctx, *assistantMsg)
  451. return err
  452. case provider.EventContentDelta:
  453. assistantMsg.AppendContent(event.Content)
  454. _, err := a.messages.Update(ctx, *assistantMsg)
  455. return err
  456. case provider.EventToolUseStart:
  457. assistantMsg.AddToolCall(*event.ToolCall)
  458. _, err := a.messages.Update(ctx, *assistantMsg)
  459. return err
  460. case provider.EventToolUseStop:
  461. assistantMsg.FinishToolCall(event.ToolCall.ID)
  462. _, err := a.messages.Update(ctx, *assistantMsg)
  463. return err
  464. case provider.EventError:
  465. if errors.Is(event.Error, context.Canceled) {
  466. status.Info(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
  467. return context.Canceled
  468. }
  469. status.Error(event.Error.Error())
  470. return event.Error
  471. case provider.EventComplete:
  472. assistantMsg.SetToolCalls(event.Response.ToolCalls)
  473. assistantMsg.AddFinish(event.Response.FinishReason)
  474. if _, err := a.messages.Update(ctx, *assistantMsg); err != nil {
  475. return fmt.Errorf("failed to update message: %w", err)
  476. }
  477. return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
  478. }
  479. return nil
  480. }
  481. func (a *agent) GetUsage(ctx context.Context, sessionID string) (*int64, error) {
  482. session, err := a.sessions.Get(ctx, sessionID)
  483. if err != nil {
  484. return nil, fmt.Errorf("failed to get session: %w", err)
  485. }
  486. usage := session.PromptTokens + session.CompletionTokens
  487. return &usage, nil
  488. }
  489. func (a *agent) EstimateContextWindowUsage(ctx context.Context, sessionID string) (float64, bool, error) {
  490. session, err := a.sessions.Get(ctx, sessionID)
  491. if err != nil {
  492. return 0, false, fmt.Errorf("failed to get session: %w", err)
  493. }
  494. // Get the model's context window size
  495. model := a.provider.Model()
  496. contextWindow := model.ContextWindow
  497. if contextWindow <= 0 {
  498. // Default to a reasonable size if not specified
  499. contextWindow = 100000
  500. }
  501. // Calculate current token usage
  502. currentTokens := session.PromptTokens + session.CompletionTokens
  503. // Get the max tokens setting for the agent
  504. maxTokens := a.provider.MaxTokens()
  505. // Calculate percentage of context window used
  506. usagePercentage := float64(currentTokens) / float64(contextWindow)
  507. // Check if we need to auto-compact
  508. // Auto-compact when:
  509. // 1. Usage exceeds 90% of context window, OR
  510. // 2. Current usage + maxTokens would exceed 100% of context window
  511. needsCompaction := usagePercentage >= 0.9 ||
  512. float64(currentTokens+maxTokens) > float64(contextWindow)
  513. return usagePercentage * 100, needsCompaction, nil
  514. }
  515. func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
  516. sess, err := a.sessions.Get(ctx, sessionID)
  517. if err != nil {
  518. return fmt.Errorf("failed to get session: %w", err)
  519. }
  520. cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  521. model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  522. model.CostPer1MIn/1e6*float64(usage.InputTokens) +
  523. model.CostPer1MOut/1e6*float64(usage.OutputTokens)
  524. sess.Cost += cost
  525. sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
  526. sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
  527. _, err = a.sessions.Update(ctx, sess)
  528. if err != nil {
  529. return fmt.Errorf("failed to save session: %w", err)
  530. }
  531. return nil
  532. }
  533. func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
  534. if a.IsBusy() {
  535. return models.Model{}, fmt.Errorf("cannot change model while processing requests")
  536. }
  537. if err := config.UpdateAgentModel(agentName, modelID); err != nil {
  538. return models.Model{}, fmt.Errorf("failed to update config: %w", err)
  539. }
  540. provider, err := createAgentProvider(agentName)
  541. if err != nil {
  542. return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
  543. }
  544. a.provider = provider
  545. return a.provider.Model(), nil
  546. }
  547. func (a *agent) CompactSession(ctx context.Context, sessionID string, force bool) error {
  548. // Check if the session is busy
  549. if a.IsSessionBusy(sessionID) && !force {
  550. return ErrSessionBusy
  551. }
  552. // Create a cancellable context
  553. ctx, cancel := context.WithCancel(ctx)
  554. defer cancel()
  555. // Mark the session as busy during compaction
  556. compactionCancelFunc := func() {}
  557. a.activeRequests.Store(sessionID+"-compact", compactionCancelFunc)
  558. defer a.activeRequests.Delete(sessionID + "-compact")
  559. // Fetch the session
  560. session, err := a.sessions.Get(ctx, sessionID)
  561. if err != nil {
  562. return fmt.Errorf("failed to get session: %w", err)
  563. }
  564. // Fetch all messages for the session
  565. sessionMessages, err := a.messages.List(ctx, sessionID)
  566. if err != nil {
  567. return fmt.Errorf("failed to list messages: %w", err)
  568. }
  569. var existingSummary string
  570. if session.Summary != "" && !session.SummarizedAt.IsZero() {
  571. // Filter messages that were created after the last summarization
  572. var newMessages []message.Message
  573. for _, msg := range sessionMessages {
  574. if msg.CreatedAt.After(session.SummarizedAt) {
  575. newMessages = append(newMessages, msg)
  576. }
  577. }
  578. sessionMessages = newMessages
  579. existingSummary = session.Summary
  580. }
  581. // If there are no messages to summarize and no existing summary, return early
  582. if len(sessionMessages) == 0 && existingSummary == "" {
  583. return nil
  584. }
  585. messages := []message.Message{
  586. message.Message{
  587. Role: message.System,
  588. Parts: []message.ContentPart{
  589. message.TextContent{
  590. Text: `You are a helpful AI assistant tasked with summarizing conversations.
  591. When asked to summarize, provide a detailed but concise summary of the conversation.
  592. Focus on information that would be helpful for continuing the conversation, including:
  593. - What was done
  594. - What is currently being worked on
  595. - Which files are being modified
  596. - What needs to be done next
  597. Your summary should be comprehensive enough to provide context but concise enough to be quickly understood.`,
  598. },
  599. },
  600. },
  601. }
  602. // If there's an existing summary, include it
  603. if existingSummary != "" {
  604. messages = append(messages, message.Message{
  605. Role: message.Assistant,
  606. Parts: []message.ContentPart{
  607. message.TextContent{
  608. Text: existingSummary,
  609. },
  610. },
  611. })
  612. }
  613. // Add all messages since the last summarized message
  614. messages = append(messages, sessionMessages...)
  615. // Add a final user message requesting the summary
  616. messages = append(messages, message.Message{
  617. Role: message.User,
  618. Parts: []message.ContentPart{
  619. message.TextContent{
  620. Text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
  621. },
  622. },
  623. })
  624. // Call provider to get the summary
  625. response, err := a.provider.SendMessages(ctx, messages, a.tools)
  626. if err != nil {
  627. return fmt.Errorf("failed to get summary from the assistant: %w", err)
  628. }
  629. // Extract the summary text
  630. summaryText := strings.TrimSpace(response.Content)
  631. if summaryText == "" {
  632. return fmt.Errorf("received empty summary from the assistant")
  633. }
  634. // Update the session with the new summary
  635. session.Summary = summaryText
  636. session.SummarizedAt = time.Now()
  637. // Save the updated session
  638. _, err = a.sessions.Update(ctx, session)
  639. if err != nil {
  640. return fmt.Errorf("failed to save session with summary: %w", err)
  641. }
  642. // Track token usage
  643. err = a.TrackUsage(ctx, sessionID, a.provider.Model(), response.Usage)
  644. if err != nil {
  645. return fmt.Errorf("failed to track usage: %w", err)
  646. }
  647. return nil
  648. }
  649. func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
  650. cfg := config.Get()
  651. agentConfig, ok := cfg.Agents[agentName]
  652. if !ok {
  653. return nil, fmt.Errorf("agent %s not found", agentName)
  654. }
  655. model, ok := models.SupportedModels[agentConfig.Model]
  656. if !ok {
  657. return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
  658. }
  659. providerCfg, ok := cfg.Providers[model.Provider]
  660. if !ok {
  661. return nil, fmt.Errorf("provider %s not supported", model.Provider)
  662. }
  663. if providerCfg.Disabled {
  664. return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
  665. }
  666. maxTokens := model.DefaultMaxTokens
  667. if agentConfig.MaxTokens > 0 {
  668. maxTokens = agentConfig.MaxTokens
  669. }
  670. opts := []provider.ProviderClientOption{
  671. provider.WithAPIKey(providerCfg.APIKey),
  672. provider.WithModel(model),
  673. provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
  674. provider.WithMaxTokens(maxTokens),
  675. }
  676. if model.Provider == models.ProviderOpenAI && model.CanReason {
  677. opts = append(
  678. opts,
  679. provider.WithOpenAIOptions(
  680. provider.WithReasoningEffort(agentConfig.ReasoningEffort),
  681. ),
  682. )
  683. } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentPrimary {
  684. opts = append(
  685. opts,
  686. provider.WithAnthropicOptions(
  687. provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
  688. ),
  689. )
  690. }
  691. agentProvider, err := provider.NewProvider(
  692. model.Provider,
  693. opts...,
  694. )
  695. if err != nil {
  696. return nil, fmt.Errorf("could not create provider: %v", err)
  697. }
  698. return agentProvider, nil
  699. }