agent.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. package agent
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "sync"
  8. "github.com/kujtimiihoxha/termai/internal/app"
  9. "github.com/kujtimiihoxha/termai/internal/config"
  10. "github.com/kujtimiihoxha/termai/internal/llm/models"
  11. "github.com/kujtimiihoxha/termai/internal/llm/prompt"
  12. "github.com/kujtimiihoxha/termai/internal/llm/provider"
  13. "github.com/kujtimiihoxha/termai/internal/llm/tools"
  14. "github.com/kujtimiihoxha/termai/internal/logging"
  15. "github.com/kujtimiihoxha/termai/internal/message"
  16. )
  17. type Agent interface {
  18. Generate(ctx context.Context, sessionID string, content string) error
  19. }
  20. type agent struct {
  21. *app.App
  22. model models.Model
  23. tools []tools.BaseTool
  24. agent provider.Provider
  25. titleGenerator provider.Provider
  26. }
  27. func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
  28. response, err := c.titleGenerator.SendMessages(
  29. ctx,
  30. []message.Message{
  31. {
  32. Role: message.User,
  33. Parts: []message.ContentPart{
  34. message.TextContent{
  35. Text: content,
  36. },
  37. },
  38. },
  39. },
  40. nil,
  41. )
  42. if err != nil {
  43. return
  44. }
  45. session, err := c.Sessions.Get(sessionID)
  46. if err != nil {
  47. return
  48. }
  49. if response.Content != "" {
  50. session.Title = response.Content
  51. session.Title = strings.TrimSpace(session.Title)
  52. session.Title = strings.ReplaceAll(session.Title, "\n", " ")
  53. c.Sessions.Save(session)
  54. }
  55. }
  56. func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
  57. session, err := c.Sessions.Get(sessionID)
  58. if err != nil {
  59. return err
  60. }
  61. cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  62. model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  63. model.CostPer1MIn/1e6*float64(usage.InputTokens) +
  64. model.CostPer1MOut/1e6*float64(usage.OutputTokens)
  65. session.Cost += cost
  66. session.CompletionTokens += usage.OutputTokens
  67. session.PromptTokens += usage.InputTokens
  68. _, err = c.Sessions.Save(session)
  69. return err
  70. }
  71. func (c *agent) processEvent(
  72. sessionID string,
  73. assistantMsg *message.Message,
  74. event provider.ProviderEvent,
  75. ) error {
  76. switch event.Type {
  77. case provider.EventThinkingDelta:
  78. assistantMsg.AppendReasoningContent(event.Content)
  79. return c.Messages.Update(*assistantMsg)
  80. case provider.EventContentDelta:
  81. assistantMsg.AppendContent(event.Content)
  82. return c.Messages.Update(*assistantMsg)
  83. case provider.EventError:
  84. if errors.Is(event.Error, context.Canceled) {
  85. return nil
  86. }
  87. logging.ErrorPersist(event.Error.Error())
  88. return event.Error
  89. case provider.EventWarning:
  90. logging.WarnPersist(event.Info)
  91. return nil
  92. case provider.EventInfo:
  93. logging.InfoPersist(event.Info)
  94. case provider.EventComplete:
  95. assistantMsg.SetToolCalls(event.Response.ToolCalls)
  96. assistantMsg.AddFinish(event.Response.FinishReason)
  97. err := c.Messages.Update(*assistantMsg)
  98. if err != nil {
  99. return err
  100. }
  101. return c.TrackUsage(sessionID, c.model, event.Response.Usage)
  102. }
  103. return nil
  104. }
  105. func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
  106. var wg sync.WaitGroup
  107. toolResults := make([]message.ToolResult, len(toolCalls))
  108. mutex := &sync.Mutex{}
  109. errChan := make(chan error, 1)
  110. // Create a child context that can be canceled
  111. ctx, cancel := context.WithCancel(ctx)
  112. defer cancel()
  113. for i, tc := range toolCalls {
  114. wg.Add(1)
  115. go func(index int, toolCall message.ToolCall) {
  116. defer wg.Done()
  117. // Check if context is already canceled
  118. select {
  119. case <-ctx.Done():
  120. mutex.Lock()
  121. toolResults[index] = message.ToolResult{
  122. ToolCallID: toolCall.ID,
  123. Content: "Tool execution canceled",
  124. IsError: true,
  125. }
  126. mutex.Unlock()
  127. // Send cancellation error to error channel if it's empty
  128. select {
  129. case errChan <- ctx.Err():
  130. default:
  131. }
  132. return
  133. default:
  134. }
  135. response := ""
  136. isError := false
  137. found := false
  138. for _, tool := range tls {
  139. if tool.Info().Name == toolCall.Name {
  140. found = true
  141. toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
  142. ID: toolCall.ID,
  143. Name: toolCall.Name,
  144. Input: toolCall.Input,
  145. })
  146. if toolErr != nil {
  147. if errors.Is(toolErr, context.Canceled) {
  148. response = "Tool execution canceled"
  149. // Send cancellation error to error channel if it's empty
  150. select {
  151. case errChan <- ctx.Err():
  152. default:
  153. }
  154. } else {
  155. response = fmt.Sprintf("error running tool: %s", toolErr)
  156. }
  157. isError = true
  158. } else {
  159. response = toolResult.Content
  160. isError = toolResult.IsError
  161. }
  162. break
  163. }
  164. }
  165. if !found {
  166. response = fmt.Sprintf("tool not found: %s", toolCall.Name)
  167. isError = true
  168. }
  169. mutex.Lock()
  170. defer mutex.Unlock()
  171. toolResults[index] = message.ToolResult{
  172. ToolCallID: toolCall.ID,
  173. Content: response,
  174. IsError: isError,
  175. }
  176. }(i, tc)
  177. }
  178. // Wait for all goroutines to finish or context to be canceled
  179. done := make(chan struct{})
  180. go func() {
  181. wg.Wait()
  182. close(done)
  183. }()
  184. select {
  185. case <-done:
  186. // All tools completed successfully
  187. case err := <-errChan:
  188. // One of the tools encountered a cancellation
  189. return toolResults, err
  190. case <-ctx.Done():
  191. // Context was canceled externally
  192. return toolResults, ctx.Err()
  193. }
  194. return toolResults, nil
  195. }
  196. func (c *agent) handleToolExecution(
  197. ctx context.Context,
  198. assistantMsg message.Message,
  199. ) (*message.Message, error) {
  200. if len(assistantMsg.ToolCalls()) == 0 {
  201. return nil, nil
  202. }
  203. toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
  204. if err != nil {
  205. return nil, err
  206. }
  207. parts := make([]message.ContentPart, 0)
  208. for _, toolResult := range toolResults {
  209. parts = append(parts, toolResult)
  210. }
  211. msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
  212. Role: message.Tool,
  213. Parts: parts,
  214. })
  215. return &msg, err
  216. }
  217. func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
  218. messages, err := c.Messages.List(sessionID)
  219. if err != nil {
  220. return err
  221. }
  222. if len(messages) == 0 {
  223. go c.handleTitleGeneration(ctx, sessionID, content)
  224. }
  225. userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
  226. Role: message.User,
  227. Parts: []message.ContentPart{
  228. message.TextContent{
  229. Text: content,
  230. },
  231. },
  232. })
  233. if err != nil {
  234. return err
  235. }
  236. messages = append(messages, userMsg)
  237. for {
  238. select {
  239. case <-ctx.Done():
  240. assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
  241. Role: message.Assistant,
  242. Parts: []message.ContentPart{},
  243. })
  244. if err != nil {
  245. return err
  246. }
  247. assistantMsg.AddFinish("canceled")
  248. c.Messages.Update(assistantMsg)
  249. return context.Canceled
  250. default:
  251. // Continue processing
  252. }
  253. eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
  254. if err != nil {
  255. if errors.Is(err, context.Canceled) {
  256. assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
  257. Role: message.Assistant,
  258. Parts: []message.ContentPart{},
  259. })
  260. if err != nil {
  261. return err
  262. }
  263. assistantMsg.AddFinish("canceled")
  264. c.Messages.Update(assistantMsg)
  265. return context.Canceled
  266. }
  267. return err
  268. }
  269. assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
  270. Role: message.Assistant,
  271. Parts: []message.ContentPart{},
  272. })
  273. if err != nil {
  274. return err
  275. }
  276. for event := range eventChan {
  277. err = c.processEvent(sessionID, &assistantMsg, event)
  278. if err != nil {
  279. if errors.Is(err, context.Canceled) {
  280. assistantMsg.AddFinish("canceled")
  281. c.Messages.Update(assistantMsg)
  282. return context.Canceled
  283. }
  284. assistantMsg.AddFinish("error:" + err.Error())
  285. c.Messages.Update(assistantMsg)
  286. return err
  287. }
  288. select {
  289. case <-ctx.Done():
  290. assistantMsg.AddFinish("canceled")
  291. c.Messages.Update(assistantMsg)
  292. return context.Canceled
  293. default:
  294. }
  295. }
  296. // Check for context cancellation before tool execution
  297. select {
  298. case <-ctx.Done():
  299. assistantMsg.AddFinish("canceled")
  300. c.Messages.Update(assistantMsg)
  301. return context.Canceled
  302. default:
  303. // Continue processing
  304. }
  305. msg, err := c.handleToolExecution(ctx, assistantMsg)
  306. if err != nil {
  307. if errors.Is(err, context.Canceled) {
  308. assistantMsg.AddFinish("canceled")
  309. c.Messages.Update(assistantMsg)
  310. return context.Canceled
  311. }
  312. return err
  313. }
  314. c.Messages.Update(assistantMsg)
  315. if len(assistantMsg.ToolCalls()) == 0 {
  316. break
  317. }
  318. messages = append(messages, assistantMsg)
  319. if msg != nil {
  320. messages = append(messages, *msg)
  321. }
  322. // Check for context cancellation after tool execution
  323. select {
  324. case <-ctx.Done():
  325. assistantMsg.AddFinish("canceled")
  326. c.Messages.Update(assistantMsg)
  327. return context.Canceled
  328. default:
  329. // Continue processing
  330. }
  331. }
  332. return nil
  333. }
  334. func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
  335. maxTokens := config.Get().Model.CoderMaxTokens
  336. providerConfig, ok := config.Get().Providers[model.Provider]
  337. if !ok || !providerConfig.Enabled {
  338. return nil, nil, errors.New("provider is not enabled")
  339. }
  340. var agentProvider provider.Provider
  341. var titleGenerator provider.Provider
  342. switch model.Provider {
  343. case models.ProviderOpenAI:
  344. var err error
  345. agentProvider, err = provider.NewOpenAIProvider(
  346. provider.WithOpenAISystemMessage(
  347. prompt.CoderOpenAISystemPrompt(),
  348. ),
  349. provider.WithOpenAIMaxTokens(maxTokens),
  350. provider.WithOpenAIModel(model),
  351. provider.WithOpenAIKey(providerConfig.APIKey),
  352. )
  353. if err != nil {
  354. return nil, nil, err
  355. }
  356. titleGenerator, err = provider.NewOpenAIProvider(
  357. provider.WithOpenAISystemMessage(
  358. prompt.TitlePrompt(),
  359. ),
  360. provider.WithOpenAIMaxTokens(80),
  361. provider.WithOpenAIModel(model),
  362. provider.WithOpenAIKey(providerConfig.APIKey),
  363. )
  364. if err != nil {
  365. return nil, nil, err
  366. }
  367. case models.ProviderAnthropic:
  368. var err error
  369. agentProvider, err = provider.NewAnthropicProvider(
  370. provider.WithAnthropicSystemMessage(
  371. prompt.CoderAnthropicSystemPrompt(),
  372. ),
  373. provider.WithAnthropicMaxTokens(maxTokens),
  374. provider.WithAnthropicKey(providerConfig.APIKey),
  375. provider.WithAnthropicModel(model),
  376. )
  377. if err != nil {
  378. return nil, nil, err
  379. }
  380. titleGenerator, err = provider.NewAnthropicProvider(
  381. provider.WithAnthropicSystemMessage(
  382. prompt.TitlePrompt(),
  383. ),
  384. provider.WithAnthropicMaxTokens(80),
  385. provider.WithAnthropicKey(providerConfig.APIKey),
  386. provider.WithAnthropicModel(model),
  387. )
  388. if err != nil {
  389. return nil, nil, err
  390. }
  391. case models.ProviderGemini:
  392. var err error
  393. agentProvider, err = provider.NewGeminiProvider(
  394. ctx,
  395. provider.WithGeminiSystemMessage(
  396. prompt.CoderOpenAISystemPrompt(),
  397. ),
  398. provider.WithGeminiMaxTokens(int32(maxTokens)),
  399. provider.WithGeminiKey(providerConfig.APIKey),
  400. provider.WithGeminiModel(model),
  401. )
  402. if err != nil {
  403. return nil, nil, err
  404. }
  405. titleGenerator, err = provider.NewGeminiProvider(
  406. ctx,
  407. provider.WithGeminiSystemMessage(
  408. prompt.TitlePrompt(),
  409. ),
  410. provider.WithGeminiMaxTokens(80),
  411. provider.WithGeminiKey(providerConfig.APIKey),
  412. provider.WithGeminiModel(model),
  413. )
  414. if err != nil {
  415. return nil, nil, err
  416. }
  417. case models.ProviderGROQ:
  418. var err error
  419. agentProvider, err = provider.NewOpenAIProvider(
  420. provider.WithOpenAISystemMessage(
  421. prompt.CoderAnthropicSystemPrompt(),
  422. ),
  423. provider.WithOpenAIMaxTokens(maxTokens),
  424. provider.WithOpenAIModel(model),
  425. provider.WithOpenAIKey(providerConfig.APIKey),
  426. provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
  427. )
  428. if err != nil {
  429. return nil, nil, err
  430. }
  431. titleGenerator, err = provider.NewOpenAIProvider(
  432. provider.WithOpenAISystemMessage(
  433. prompt.TitlePrompt(),
  434. ),
  435. provider.WithOpenAIMaxTokens(80),
  436. provider.WithOpenAIModel(model),
  437. provider.WithOpenAIKey(providerConfig.APIKey),
  438. provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
  439. )
  440. if err != nil {
  441. return nil, nil, err
  442. }
  443. case models.ProviderBedrock:
  444. var err error
  445. agentProvider, err = provider.NewBedrockProvider(
  446. provider.WithBedrockSystemMessage(
  447. prompt.CoderAnthropicSystemPrompt(),
  448. ),
  449. provider.WithBedrockMaxTokens(maxTokens),
  450. provider.WithBedrockModel(model),
  451. )
  452. if err != nil {
  453. return nil, nil, err
  454. }
  455. titleGenerator, err = provider.NewBedrockProvider(
  456. provider.WithBedrockSystemMessage(
  457. prompt.TitlePrompt(),
  458. ),
  459. provider.WithBedrockMaxTokens(maxTokens),
  460. provider.WithBedrockModel(model),
  461. )
  462. if err != nil {
  463. return nil, nil, err
  464. }
  465. }
  466. return agentProvider, titleGenerator, nil
  467. }