agent.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. package agent
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "sync"
  8. "github.com/kujtimiihoxha/termai/internal/app"
  9. "github.com/kujtimiihoxha/termai/internal/config"
  10. "github.com/kujtimiihoxha/termai/internal/llm/models"
  11. "github.com/kujtimiihoxha/termai/internal/llm/prompt"
  12. "github.com/kujtimiihoxha/termai/internal/llm/provider"
  13. "github.com/kujtimiihoxha/termai/internal/llm/tools"
  14. "github.com/kujtimiihoxha/termai/internal/logging"
  15. "github.com/kujtimiihoxha/termai/internal/message"
  16. )
  // Agent is the interface through which user input is handed to a
// model-backed chat session.
//
// NOTE(review): the concrete implementation's Generate method is not part of
// this excerpt; presumably it delegates to agent.generate — confirm.
type Agent interface {
	// Generate handles content as user input for the session identified by
	// sessionID and returns a non-nil error if the turn could not be completed.
	Generate(ctx context.Context, sessionID, content string) error
}
  20. type agent struct {
  21. *app.App
  22. model models.Model
  23. tools []tools.BaseTool
  24. agent provider.Provider
  25. titleGenerator provider.Provider
  26. }
  27. func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
  28. response, err := c.titleGenerator.SendMessages(
  29. ctx,
  30. []message.Message{
  31. {
  32. Role: message.User,
  33. Parts: []message.ContentPart{
  34. message.TextContent{
  35. Text: content,
  36. },
  37. },
  38. },
  39. },
  40. nil,
  41. )
  42. if err != nil {
  43. return
  44. }
  45. session, err := c.Sessions.Get(sessionID)
  46. if err != nil {
  47. return
  48. }
  49. if response.Content != "" {
  50. session.Title = response.Content
  51. session.Title = strings.TrimSpace(session.Title)
  52. session.Title = strings.ReplaceAll(session.Title, "\n", " ")
  53. c.Sessions.Save(session)
  54. }
  55. }
  56. func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
  57. session, err := c.Sessions.Get(sessionID)
  58. if err != nil {
  59. return err
  60. }
  61. cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  62. model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  63. model.CostPer1MIn/1e6*float64(usage.InputTokens) +
  64. model.CostPer1MOut/1e6*float64(usage.OutputTokens)
  65. session.Cost += cost
  66. session.CompletionTokens += usage.OutputTokens
  67. session.PromptTokens += usage.InputTokens
  68. _, err = c.Sessions.Save(session)
  69. return err
  70. }
  71. func (c *agent) processEvent(
  72. sessionID string,
  73. assistantMsg *message.Message,
  74. event provider.ProviderEvent,
  75. ) error {
  76. switch event.Type {
  77. case provider.EventThinkingDelta:
  78. assistantMsg.AppendReasoningContent(event.Content)
  79. return c.Messages.Update(*assistantMsg)
  80. case provider.EventContentDelta:
  81. assistantMsg.AppendContent(event.Content)
  82. return c.Messages.Update(*assistantMsg)
  83. case provider.EventError:
  84. if errors.Is(event.Error, context.Canceled) {
  85. return nil
  86. }
  87. logging.ErrorPersist(event.Error.Error())
  88. return event.Error
  89. case provider.EventWarning:
  90. logging.WarnPersist(event.Info)
  91. return nil
  92. case provider.EventInfo:
  93. logging.InfoPersist(event.Info)
  94. case provider.EventComplete:
  95. assistantMsg.SetToolCalls(event.Response.ToolCalls)
  96. assistantMsg.AddFinish(event.Response.FinishReason)
  97. err := c.Messages.Update(*assistantMsg)
  98. if err != nil {
  99. return err
  100. }
  101. return c.TrackUsage(sessionID, c.model, event.Response.Usage)
  102. }
  103. return nil
  104. }
  105. func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
  106. var wg sync.WaitGroup
  107. toolResults := make([]message.ToolResult, len(toolCalls))
  108. mutex := &sync.Mutex{}
  109. errChan := make(chan error, 1)
  110. // Create a child context that can be canceled
  111. ctx, cancel := context.WithCancel(ctx)
  112. defer cancel()
  113. for i, tc := range toolCalls {
  114. wg.Add(1)
  115. go func(index int, toolCall message.ToolCall) {
  116. defer wg.Done()
  117. // Check if context is already canceled
  118. select {
  119. case <-ctx.Done():
  120. mutex.Lock()
  121. toolResults[index] = message.ToolResult{
  122. ToolCallID: toolCall.ID,
  123. Content: "Tool execution canceled",
  124. IsError: true,
  125. }
  126. mutex.Unlock()
  127. // Send cancellation error to error channel if it's empty
  128. select {
  129. case errChan <- ctx.Err():
  130. default:
  131. }
  132. return
  133. default:
  134. }
  135. response := ""
  136. isError := false
  137. found := false
  138. for _, tool := range tls {
  139. if tool.Info().Name == toolCall.Name {
  140. found = true
  141. toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
  142. ID: toolCall.ID,
  143. Name: toolCall.Name,
  144. Input: toolCall.Input,
  145. })
  146. if toolErr != nil {
  147. if errors.Is(toolErr, context.Canceled) {
  148. response = "Tool execution canceled"
  149. // Send cancellation error to error channel if it's empty
  150. select {
  151. case errChan <- ctx.Err():
  152. default:
  153. }
  154. } else {
  155. response = fmt.Sprintf("error running tool: %s", toolErr)
  156. }
  157. isError = true
  158. } else {
  159. response = toolResult.Content
  160. isError = toolResult.IsError
  161. }
  162. break
  163. }
  164. }
  165. if !found {
  166. response = fmt.Sprintf("tool not found: %s", toolCall.Name)
  167. isError = true
  168. }
  169. mutex.Lock()
  170. defer mutex.Unlock()
  171. toolResults[index] = message.ToolResult{
  172. ToolCallID: toolCall.ID,
  173. Content: response,
  174. IsError: isError,
  175. }
  176. }(i, tc)
  177. }
  178. // Wait for all goroutines to finish or context to be canceled
  179. done := make(chan struct{})
  180. go func() {
  181. wg.Wait()
  182. close(done)
  183. }()
  184. select {
  185. case <-done:
  186. // All tools completed successfully
  187. case err := <-errChan:
  188. // One of the tools encountered a cancellation
  189. return toolResults, err
  190. case <-ctx.Done():
  191. // Context was canceled externally
  192. return toolResults, ctx.Err()
  193. }
  194. return toolResults, nil
  195. }
  196. func (c *agent) handleToolExecution(
  197. ctx context.Context,
  198. assistantMsg message.Message,
  199. ) (*message.Message, error) {
  200. if len(assistantMsg.ToolCalls()) == 0 {
  201. return nil, nil
  202. }
  203. toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
  204. if err != nil {
  205. return nil, err
  206. }
  207. parts := make([]message.ContentPart, 0)
  208. for _, toolResult := range toolResults {
  209. parts = append(parts, toolResult)
  210. }
  211. msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
  212. Role: message.Tool,
  213. Parts: parts,
  214. })
  215. return &msg, err
  216. }
  217. func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
  218. ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
  219. messages, err := c.Messages.List(sessionID)
  220. if err != nil {
  221. return err
  222. }
  223. if len(messages) == 0 {
  224. go c.handleTitleGeneration(ctx, sessionID, content)
  225. }
  226. userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
  227. Role: message.User,
  228. Parts: []message.ContentPart{
  229. message.TextContent{
  230. Text: content,
  231. },
  232. },
  233. })
  234. if err != nil {
  235. return err
  236. }
  237. messages = append(messages, userMsg)
  238. for {
  239. select {
  240. case <-ctx.Done():
  241. assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
  242. Role: message.Assistant,
  243. Parts: []message.ContentPart{},
  244. })
  245. if err != nil {
  246. return err
  247. }
  248. assistantMsg.AddFinish("canceled")
  249. c.Messages.Update(assistantMsg)
  250. return context.Canceled
  251. default:
  252. // Continue processing
  253. }
  254. eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
  255. if err != nil {
  256. if errors.Is(err, context.Canceled) {
  257. assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
  258. Role: message.Assistant,
  259. Parts: []message.ContentPart{},
  260. })
  261. if err != nil {
  262. return err
  263. }
  264. assistantMsg.AddFinish("canceled")
  265. c.Messages.Update(assistantMsg)
  266. return context.Canceled
  267. }
  268. return err
  269. }
  270. assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
  271. Role: message.Assistant,
  272. Parts: []message.ContentPart{},
  273. Model: c.model.ID,
  274. })
  275. if err != nil {
  276. return err
  277. }
  278. ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
  279. for event := range eventChan {
  280. err = c.processEvent(sessionID, &assistantMsg, event)
  281. if err != nil {
  282. if errors.Is(err, context.Canceled) {
  283. assistantMsg.AddFinish("canceled")
  284. c.Messages.Update(assistantMsg)
  285. return context.Canceled
  286. }
  287. assistantMsg.AddFinish("error:" + err.Error())
  288. c.Messages.Update(assistantMsg)
  289. return err
  290. }
  291. select {
  292. case <-ctx.Done():
  293. assistantMsg.AddFinish("canceled")
  294. c.Messages.Update(assistantMsg)
  295. return context.Canceled
  296. default:
  297. }
  298. }
  299. // Check for context cancellation before tool execution
  300. select {
  301. case <-ctx.Done():
  302. assistantMsg.AddFinish("canceled")
  303. c.Messages.Update(assistantMsg)
  304. return context.Canceled
  305. default:
  306. // Continue processing
  307. }
  308. msg, err := c.handleToolExecution(ctx, assistantMsg)
  309. if err != nil {
  310. if errors.Is(err, context.Canceled) {
  311. assistantMsg.AddFinish("canceled")
  312. c.Messages.Update(assistantMsg)
  313. return context.Canceled
  314. }
  315. return err
  316. }
  317. c.Messages.Update(assistantMsg)
  318. if len(assistantMsg.ToolCalls()) == 0 {
  319. break
  320. }
  321. messages = append(messages, assistantMsg)
  322. if msg != nil {
  323. messages = append(messages, *msg)
  324. }
  325. // Check for context cancellation after tool execution
  326. select {
  327. case <-ctx.Done():
  328. assistantMsg.AddFinish("canceled")
  329. c.Messages.Update(assistantMsg)
  330. return context.Canceled
  331. default:
  332. // Continue processing
  333. }
  334. }
  335. return nil
  336. }
  337. func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
  338. maxTokens := config.Get().Model.CoderMaxTokens
  339. providerConfig, ok := config.Get().Providers[model.Provider]
  340. if !ok || !providerConfig.Enabled {
  341. return nil, nil, errors.New("provider is not enabled")
  342. }
  343. var agentProvider provider.Provider
  344. var titleGenerator provider.Provider
  345. switch model.Provider {
  346. case models.ProviderOpenAI:
  347. var err error
  348. agentProvider, err = provider.NewOpenAIProvider(
  349. provider.WithOpenAISystemMessage(
  350. prompt.CoderOpenAISystemPrompt(),
  351. ),
  352. provider.WithOpenAIMaxTokens(maxTokens),
  353. provider.WithOpenAIModel(model),
  354. provider.WithOpenAIKey(providerConfig.APIKey),
  355. )
  356. if err != nil {
  357. return nil, nil, err
  358. }
  359. titleGenerator, err = provider.NewOpenAIProvider(
  360. provider.WithOpenAISystemMessage(
  361. prompt.TitlePrompt(),
  362. ),
  363. provider.WithOpenAIMaxTokens(80),
  364. provider.WithOpenAIModel(model),
  365. provider.WithOpenAIKey(providerConfig.APIKey),
  366. )
  367. if err != nil {
  368. return nil, nil, err
  369. }
  370. case models.ProviderAnthropic:
  371. var err error
  372. agentProvider, err = provider.NewAnthropicProvider(
  373. provider.WithAnthropicSystemMessage(
  374. prompt.CoderAnthropicSystemPrompt(),
  375. ),
  376. provider.WithAnthropicMaxTokens(maxTokens),
  377. provider.WithAnthropicKey(providerConfig.APIKey),
  378. provider.WithAnthropicModel(model),
  379. )
  380. if err != nil {
  381. return nil, nil, err
  382. }
  383. titleGenerator, err = provider.NewAnthropicProvider(
  384. provider.WithAnthropicSystemMessage(
  385. prompt.TitlePrompt(),
  386. ),
  387. provider.WithAnthropicMaxTokens(80),
  388. provider.WithAnthropicKey(providerConfig.APIKey),
  389. provider.WithAnthropicModel(model),
  390. )
  391. if err != nil {
  392. return nil, nil, err
  393. }
  394. case models.ProviderGemini:
  395. var err error
  396. agentProvider, err = provider.NewGeminiProvider(
  397. ctx,
  398. provider.WithGeminiSystemMessage(
  399. prompt.CoderOpenAISystemPrompt(),
  400. ),
  401. provider.WithGeminiMaxTokens(int32(maxTokens)),
  402. provider.WithGeminiKey(providerConfig.APIKey),
  403. provider.WithGeminiModel(model),
  404. )
  405. if err != nil {
  406. return nil, nil, err
  407. }
  408. titleGenerator, err = provider.NewGeminiProvider(
  409. ctx,
  410. provider.WithGeminiSystemMessage(
  411. prompt.TitlePrompt(),
  412. ),
  413. provider.WithGeminiMaxTokens(80),
  414. provider.WithGeminiKey(providerConfig.APIKey),
  415. provider.WithGeminiModel(model),
  416. )
  417. if err != nil {
  418. return nil, nil, err
  419. }
  420. case models.ProviderGROQ:
  421. var err error
  422. agentProvider, err = provider.NewOpenAIProvider(
  423. provider.WithOpenAISystemMessage(
  424. prompt.CoderAnthropicSystemPrompt(),
  425. ),
  426. provider.WithOpenAIMaxTokens(maxTokens),
  427. provider.WithOpenAIModel(model),
  428. provider.WithOpenAIKey(providerConfig.APIKey),
  429. provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
  430. )
  431. if err != nil {
  432. return nil, nil, err
  433. }
  434. titleGenerator, err = provider.NewOpenAIProvider(
  435. provider.WithOpenAISystemMessage(
  436. prompt.TitlePrompt(),
  437. ),
  438. provider.WithOpenAIMaxTokens(80),
  439. provider.WithOpenAIModel(model),
  440. provider.WithOpenAIKey(providerConfig.APIKey),
  441. provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
  442. )
  443. if err != nil {
  444. return nil, nil, err
  445. }
  446. case models.ProviderBedrock:
  447. var err error
  448. agentProvider, err = provider.NewBedrockProvider(
  449. provider.WithBedrockSystemMessage(
  450. prompt.CoderAnthropicSystemPrompt(),
  451. ),
  452. provider.WithBedrockMaxTokens(maxTokens),
  453. provider.WithBedrockModel(model),
  454. )
  455. if err != nil {
  456. return nil, nil, err
  457. }
  458. titleGenerator, err = provider.NewBedrockProvider(
  459. provider.WithBedrockSystemMessage(
  460. prompt.TitlePrompt(),
  461. ),
  462. provider.WithBedrockMaxTokens(maxTokens),
  463. provider.WithBedrockModel(model),
  464. )
  465. if err != nil {
  466. return nil, nil, err
  467. }
  468. }
  469. return agentProvider, titleGenerator, nil
  470. }