agent.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. package agent
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "sync"
  8. "github.com/opencode-ai/opencode/internal/config"
  9. "github.com/opencode-ai/opencode/internal/llm/models"
  10. "github.com/opencode-ai/opencode/internal/llm/prompt"
  11. "github.com/opencode-ai/opencode/internal/llm/provider"
  12. "github.com/opencode-ai/opencode/internal/llm/tools"
  13. "github.com/opencode-ai/opencode/internal/logging"
  14. "github.com/opencode-ai/opencode/internal/message"
  15. "github.com/opencode-ai/opencode/internal/permission"
  16. "github.com/opencode-ai/opencode/internal/session"
  17. )
  18. // Common errors
  19. var (
  20. ErrRequestCancelled = errors.New("request cancelled by user")
  21. ErrSessionBusy = errors.New("session is currently processing another request")
  22. )
  23. type AgentEvent struct {
  24. message message.Message
  25. err error
  26. }
  27. func (e *AgentEvent) Err() error {
  28. return e.err
  29. }
  30. func (e *AgentEvent) Response() message.Message {
  31. return e.message
  32. }
  33. type Service interface {
  34. Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
  35. Cancel(sessionID string)
  36. IsSessionBusy(sessionID string) bool
  37. IsBusy() bool
  38. Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
  39. }
  40. type agent struct {
  41. sessions session.Service
  42. messages message.Service
  43. tools []tools.BaseTool
  44. provider provider.Provider
  45. titleProvider provider.Provider
  46. activeRequests sync.Map
  47. }
  48. func NewAgent(
  49. agentName config.AgentName,
  50. sessions session.Service,
  51. messages message.Service,
  52. agentTools []tools.BaseTool,
  53. ) (Service, error) {
  54. agentProvider, err := createAgentProvider(agentName)
  55. if err != nil {
  56. return nil, err
  57. }
  58. var titleProvider provider.Provider
  59. // Only generate titles for the coder agent
  60. if agentName == config.AgentCoder {
  61. titleProvider, err = createAgentProvider(config.AgentTitle)
  62. if err != nil {
  63. return nil, err
  64. }
  65. }
  66. agent := &agent{
  67. provider: agentProvider,
  68. messages: messages,
  69. sessions: sessions,
  70. tools: agentTools,
  71. titleProvider: titleProvider,
  72. activeRequests: sync.Map{},
  73. }
  74. return agent, nil
  75. }
  76. func (a *agent) Cancel(sessionID string) {
  77. if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
  78. if cancel, ok := cancelFunc.(context.CancelFunc); ok {
  79. logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
  80. cancel()
  81. }
  82. }
  83. }
  84. func (a *agent) IsBusy() bool {
  85. busy := false
  86. a.activeRequests.Range(func(key, value interface{}) bool {
  87. if cancelFunc, ok := value.(context.CancelFunc); ok {
  88. if cancelFunc != nil {
  89. busy = true
  90. return false // Stop iterating
  91. }
  92. }
  93. return true // Continue iterating
  94. })
  95. return busy
  96. }
  97. func (a *agent) IsSessionBusy(sessionID string) bool {
  98. _, busy := a.activeRequests.Load(sessionID)
  99. return busy
  100. }
  101. func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
  102. if a.titleProvider == nil {
  103. return nil
  104. }
  105. session, err := a.sessions.Get(ctx, sessionID)
  106. if err != nil {
  107. return err
  108. }
  109. response, err := a.titleProvider.SendMessages(
  110. ctx,
  111. []message.Message{
  112. {
  113. Role: message.User,
  114. Parts: []message.ContentPart{
  115. message.TextContent{
  116. Text: content,
  117. },
  118. },
  119. },
  120. },
  121. make([]tools.BaseTool, 0),
  122. )
  123. if err != nil {
  124. return err
  125. }
  126. title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
  127. if title == "" {
  128. return nil
  129. }
  130. session.Title = title
  131. _, err = a.sessions.Save(ctx, session)
  132. return err
  133. }
  134. func (a *agent) err(err error) AgentEvent {
  135. return AgentEvent{
  136. err: err,
  137. }
  138. }
  139. func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
  140. events := make(chan AgentEvent)
  141. if a.IsSessionBusy(sessionID) {
  142. return nil, ErrSessionBusy
  143. }
  144. genCtx, cancel := context.WithCancel(ctx)
  145. a.activeRequests.Store(sessionID, cancel)
  146. go func() {
  147. logging.Debug("Request started", "sessionID", sessionID)
  148. defer logging.RecoverPanic("agent.Run", func() {
  149. events <- a.err(fmt.Errorf("panic while running the agent"))
  150. })
  151. result := a.processGeneration(genCtx, sessionID, content)
  152. if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
  153. logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
  154. }
  155. logging.Debug("Request completed", "sessionID", sessionID)
  156. a.activeRequests.Delete(sessionID)
  157. cancel()
  158. events <- result
  159. close(events)
  160. }()
  161. return events, nil
  162. }
  163. func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
  164. // List existing messages; if none, start title generation asynchronously.
  165. msgs, err := a.messages.List(ctx, sessionID)
  166. if err != nil {
  167. return a.err(fmt.Errorf("failed to list messages: %w", err))
  168. }
  169. if len(msgs) == 0 {
  170. go func() {
  171. defer logging.RecoverPanic("agent.Run", func() {
  172. logging.ErrorPersist("panic while generating title")
  173. })
  174. titleErr := a.generateTitle(context.Background(), sessionID, content)
  175. if titleErr != nil {
  176. logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
  177. }
  178. }()
  179. }
  180. userMsg, err := a.createUserMessage(ctx, sessionID, content)
  181. if err != nil {
  182. return a.err(fmt.Errorf("failed to create user message: %w", err))
  183. }
  184. // Append the new user message to the conversation history.
  185. msgHistory := append(msgs, userMsg)
  186. for {
  187. // Check for cancellation before each iteration
  188. select {
  189. case <-ctx.Done():
  190. return a.err(ctx.Err())
  191. default:
  192. // Continue processing
  193. }
  194. agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
  195. if err != nil {
  196. if errors.Is(err, context.Canceled) {
  197. agentMessage.AddFinish(message.FinishReasonCanceled)
  198. a.messages.Update(context.Background(), agentMessage)
  199. return a.err(ErrRequestCancelled)
  200. }
  201. return a.err(fmt.Errorf("failed to process events: %w", err))
  202. }
  203. logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
  204. if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
  205. // We are not done, we need to respond with the tool response
  206. msgHistory = append(msgHistory, agentMessage, *toolResults)
  207. continue
  208. }
  209. return AgentEvent{
  210. message: agentMessage,
  211. }
  212. }
  213. }
  214. func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
  215. return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  216. Role: message.User,
  217. Parts: []message.ContentPart{
  218. message.TextContent{Text: content},
  219. },
  220. })
  221. }
  222. func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
  223. eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
  224. assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  225. Role: message.Assistant,
  226. Parts: []message.ContentPart{},
  227. Model: a.provider.Model().ID,
  228. })
  229. if err != nil {
  230. return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
  231. }
  232. // Add the session and message ID into the context if needed by tools.
  233. ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
  234. ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
  235. // Process each event in the stream.
  236. for event := range eventChan {
  237. if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
  238. a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
  239. return assistantMsg, nil, processErr
  240. }
  241. if ctx.Err() != nil {
  242. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
  243. return assistantMsg, nil, ctx.Err()
  244. }
  245. }
  246. toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
  247. toolCalls := assistantMsg.ToolCalls()
  248. for i, toolCall := range toolCalls {
  249. select {
  250. case <-ctx.Done():
  251. a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
  252. // Make all future tool calls cancelled
  253. for j := i; j < len(toolCalls); j++ {
  254. toolResults[j] = message.ToolResult{
  255. ToolCallID: toolCalls[j].ID,
  256. Content: "Tool execution canceled by user",
  257. IsError: true,
  258. }
  259. }
  260. goto out
  261. default:
  262. // Continue processing
  263. var tool tools.BaseTool
  264. for _, availableTools := range a.tools {
  265. if availableTools.Info().Name == toolCall.Name {
  266. tool = availableTools
  267. }
  268. }
  269. // Tool not found
  270. if tool == nil {
  271. toolResults[i] = message.ToolResult{
  272. ToolCallID: toolCall.ID,
  273. Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
  274. IsError: true,
  275. }
  276. continue
  277. }
  278. toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
  279. ID: toolCall.ID,
  280. Name: toolCall.Name,
  281. Input: toolCall.Input,
  282. })
  283. if toolErr != nil {
  284. if errors.Is(toolErr, permission.ErrorPermissionDenied) {
  285. toolResults[i] = message.ToolResult{
  286. ToolCallID: toolCall.ID,
  287. Content: "Permission denied",
  288. IsError: true,
  289. }
  290. for j := i + 1; j < len(toolCalls); j++ {
  291. toolResults[j] = message.ToolResult{
  292. ToolCallID: toolCalls[j].ID,
  293. Content: "Tool execution canceled by user",
  294. IsError: true,
  295. }
  296. }
  297. a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
  298. break
  299. }
  300. }
  301. toolResults[i] = message.ToolResult{
  302. ToolCallID: toolCall.ID,
  303. Content: toolResult.Content,
  304. Metadata: toolResult.Metadata,
  305. IsError: toolResult.IsError,
  306. }
  307. }
  308. }
  309. out:
  310. if len(toolResults) == 0 {
  311. return assistantMsg, nil, nil
  312. }
  313. parts := make([]message.ContentPart, 0)
  314. for _, tr := range toolResults {
  315. parts = append(parts, tr)
  316. }
  317. msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
  318. Role: message.Tool,
  319. Parts: parts,
  320. })
  321. if err != nil {
  322. return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
  323. }
  324. return assistantMsg, &msg, err
  325. }
  326. func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
  327. msg.AddFinish(finishReson)
  328. _ = a.messages.Update(ctx, *msg)
  329. }
  330. func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
  331. select {
  332. case <-ctx.Done():
  333. return ctx.Err()
  334. default:
  335. // Continue processing.
  336. }
  337. switch event.Type {
  338. case provider.EventThinkingDelta:
  339. assistantMsg.AppendReasoningContent(event.Content)
  340. return a.messages.Update(ctx, *assistantMsg)
  341. case provider.EventContentDelta:
  342. assistantMsg.AppendContent(event.Content)
  343. return a.messages.Update(ctx, *assistantMsg)
  344. case provider.EventToolUseStart:
  345. assistantMsg.AddToolCall(*event.ToolCall)
  346. return a.messages.Update(ctx, *assistantMsg)
  347. // TODO: see how to handle this
  348. // case provider.EventToolUseDelta:
  349. // tm := time.Unix(assistantMsg.UpdatedAt, 0)
  350. // assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
  351. // if time.Since(tm) > 1000*time.Millisecond {
  352. // err := a.messages.Update(ctx, *assistantMsg)
  353. // assistantMsg.UpdatedAt = time.Now().Unix()
  354. // return err
  355. // }
  356. case provider.EventToolUseStop:
  357. assistantMsg.FinishToolCall(event.ToolCall.ID)
  358. return a.messages.Update(ctx, *assistantMsg)
  359. case provider.EventError:
  360. if errors.Is(event.Error, context.Canceled) {
  361. logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
  362. return context.Canceled
  363. }
  364. logging.ErrorPersist(event.Error.Error())
  365. return event.Error
  366. case provider.EventComplete:
  367. assistantMsg.SetToolCalls(event.Response.ToolCalls)
  368. assistantMsg.AddFinish(event.Response.FinishReason)
  369. if err := a.messages.Update(ctx, *assistantMsg); err != nil {
  370. return fmt.Errorf("failed to update message: %w", err)
  371. }
  372. return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
  373. }
  374. return nil
  375. }
  376. func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
  377. sess, err := a.sessions.Get(ctx, sessionID)
  378. if err != nil {
  379. return fmt.Errorf("failed to get session: %w", err)
  380. }
  381. cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  382. model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  383. model.CostPer1MIn/1e6*float64(usage.InputTokens) +
  384. model.CostPer1MOut/1e6*float64(usage.OutputTokens)
  385. sess.Cost += cost
  386. sess.CompletionTokens += usage.OutputTokens
  387. sess.PromptTokens += usage.InputTokens
  388. _, err = a.sessions.Save(ctx, sess)
  389. if err != nil {
  390. return fmt.Errorf("failed to save session: %w", err)
  391. }
  392. return nil
  393. }
  394. func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
  395. if a.IsBusy() {
  396. return models.Model{}, fmt.Errorf("cannot change model while processing requests")
  397. }
  398. if err := config.UpdateAgentModel(agentName, modelID); err != nil {
  399. return models.Model{}, fmt.Errorf("failed to update config: %w", err)
  400. }
  401. provider, err := createAgentProvider(agentName)
  402. if err != nil {
  403. return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
  404. }
  405. a.provider = provider
  406. return a.provider.Model(), nil
  407. }
  408. func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
  409. cfg := config.Get()
  410. agentConfig, ok := cfg.Agents[agentName]
  411. if !ok {
  412. return nil, fmt.Errorf("agent %s not found", agentName)
  413. }
  414. model, ok := models.SupportedModels[agentConfig.Model]
  415. if !ok {
  416. return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
  417. }
  418. providerCfg, ok := cfg.Providers[model.Provider]
  419. if !ok {
  420. return nil, fmt.Errorf("provider %s not supported", model.Provider)
  421. }
  422. if providerCfg.Disabled {
  423. return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
  424. }
  425. maxTokens := model.DefaultMaxTokens
  426. if agentConfig.MaxTokens > 0 {
  427. maxTokens = agentConfig.MaxTokens
  428. }
  429. opts := []provider.ProviderClientOption{
  430. provider.WithAPIKey(providerCfg.APIKey),
  431. provider.WithModel(model),
  432. provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
  433. provider.WithMaxTokens(maxTokens),
  434. }
  435. if model.Provider == models.ProviderOpenAI && model.CanReason {
  436. opts = append(
  437. opts,
  438. provider.WithOpenAIOptions(
  439. provider.WithReasoningEffort(agentConfig.ReasoningEffort),
  440. ),
  441. )
  442. } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
  443. opts = append(
  444. opts,
  445. provider.WithAnthropicOptions(
  446. provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
  447. ),
  448. )
  449. }
  450. agentProvider, err := provider.NewProvider(
  451. model.Provider,
  452. opts...,
  453. )
  454. if err != nil {
  455. return nil, fmt.Errorf("could not create provider: %v", err)
  456. }
  457. return agentProvider, nil
  458. }