agent.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. package agent
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "strings"
  8. "sync"
  9. "github.com/kujtimiihoxha/termai/internal/app"
  10. "github.com/kujtimiihoxha/termai/internal/config"
  11. "github.com/kujtimiihoxha/termai/internal/llm/models"
  12. "github.com/kujtimiihoxha/termai/internal/llm/prompt"
  13. "github.com/kujtimiihoxha/termai/internal/llm/provider"
  14. "github.com/kujtimiihoxha/termai/internal/llm/tools"
  15. "github.com/kujtimiihoxha/termai/internal/message"
  16. "github.com/kujtimiihoxha/termai/internal/pubsub"
  17. "github.com/kujtimiihoxha/termai/internal/tui/util"
  18. )
  19. type Agent interface {
  20. Generate(sessionID string, content string) error
  21. }
  22. type agent struct {
  23. *app.App
  24. model models.Model
  25. tools []tools.BaseTool
  26. agent provider.Provider
  27. titleGenerator provider.Provider
  28. }
  29. func (c *agent) handleTitleGeneration(sessionID, content string) {
  30. response, err := c.titleGenerator.SendMessages(
  31. c.Context,
  32. []message.Message{
  33. {
  34. Role: message.User,
  35. Parts: []message.ContentPart{
  36. message.TextContent{
  37. Text: content,
  38. },
  39. },
  40. },
  41. },
  42. nil,
  43. )
  44. if err != nil {
  45. return
  46. }
  47. session, err := c.Sessions.Get(sessionID)
  48. if err != nil {
  49. return
  50. }
  51. if response.Content != "" {
  52. session.Title = response.Content
  53. session.Title = strings.TrimSpace(session.Title)
  54. session.Title = strings.ReplaceAll(session.Title, "\n", " ")
  55. c.Sessions.Save(session)
  56. }
  57. }
  58. func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
  59. session, err := c.Sessions.Get(sessionID)
  60. if err != nil {
  61. return err
  62. }
  63. cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  64. model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  65. model.CostPer1MIn/1e6*float64(usage.InputTokens) +
  66. model.CostPer1MOut/1e6*float64(usage.OutputTokens)
  67. session.Cost += cost
  68. session.CompletionTokens += usage.OutputTokens
  69. session.PromptTokens += usage.InputTokens
  70. _, err = c.Sessions.Save(session)
  71. return err
  72. }
  73. func (c *agent) processEvent(
  74. sessionID string,
  75. assistantMsg *message.Message,
  76. event provider.ProviderEvent,
  77. ) error {
  78. switch event.Type {
  79. case provider.EventThinkingDelta:
  80. assistantMsg.AppendReasoningContent(event.Content)
  81. return c.Messages.Update(*assistantMsg)
  82. case provider.EventContentDelta:
  83. assistantMsg.AppendContent(event.Content)
  84. return c.Messages.Update(*assistantMsg)
  85. case provider.EventError:
  86. // TODO: remove when realease
  87. log.Println("error", event.Error)
  88. c.App.Status.Publish(pubsub.UpdatedEvent, util.InfoMsg{
  89. Type: util.InfoTypeError,
  90. Msg: event.Error.Error(),
  91. })
  92. return event.Error
  93. case provider.EventWarning:
  94. c.App.Status.Publish(pubsub.UpdatedEvent, util.InfoMsg{
  95. Type: util.InfoTypeWarn,
  96. Msg: event.Info,
  97. })
  98. return nil
  99. case provider.EventInfo:
  100. c.App.Status.Publish(pubsub.UpdatedEvent, util.InfoMsg{
  101. Type: util.InfoTypeInfo,
  102. Msg: event.Info,
  103. })
  104. case provider.EventComplete:
  105. assistantMsg.SetToolCalls(event.Response.ToolCalls)
  106. assistantMsg.AddFinish(event.Response.FinishReason)
  107. err := c.Messages.Update(*assistantMsg)
  108. if err != nil {
  109. return err
  110. }
  111. return c.TrackUsage(sessionID, c.model, event.Response.Usage)
  112. }
  113. return nil
  114. }
  115. func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
  116. var wg sync.WaitGroup
  117. toolResults := make([]message.ToolResult, len(toolCalls))
  118. mutex := &sync.Mutex{}
  119. for i, tc := range toolCalls {
  120. wg.Add(1)
  121. go func(index int, toolCall message.ToolCall) {
  122. defer wg.Done()
  123. response := ""
  124. isError := false
  125. found := false
  126. for _, tool := range tls {
  127. if tool.Info().Name == toolCall.Name {
  128. found = true
  129. toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
  130. ID: toolCall.ID,
  131. Name: toolCall.Name,
  132. Input: toolCall.Input,
  133. })
  134. if toolErr != nil {
  135. response = fmt.Sprintf("error running tool: %s", toolErr)
  136. isError = true
  137. } else {
  138. response = toolResult.Content
  139. isError = toolResult.IsError
  140. }
  141. break
  142. }
  143. }
  144. if !found {
  145. response = fmt.Sprintf("tool not found: %s", toolCall.Name)
  146. isError = true
  147. }
  148. mutex.Lock()
  149. defer mutex.Unlock()
  150. toolResults[index] = message.ToolResult{
  151. ToolCallID: toolCall.ID,
  152. Content: response,
  153. IsError: isError,
  154. }
  155. }(i, tc)
  156. }
  157. wg.Wait()
  158. return toolResults, nil
  159. }
  160. func (c *agent) handleToolExecution(
  161. ctx context.Context,
  162. assistantMsg message.Message,
  163. ) (*message.Message, error) {
  164. if len(assistantMsg.ToolCalls()) == 0 {
  165. return nil, nil
  166. }
  167. toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
  168. if err != nil {
  169. return nil, err
  170. }
  171. parts := make([]message.ContentPart, 0)
  172. for _, toolResult := range toolResults {
  173. parts = append(parts, toolResult)
  174. }
  175. msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
  176. Role: message.Tool,
  177. Parts: parts,
  178. })
  179. return &msg, err
  180. }
  181. func (c *agent) generate(sessionID string, content string) error {
  182. messages, err := c.Messages.List(sessionID)
  183. if err != nil {
  184. return err
  185. }
  186. if len(messages) == 0 {
  187. go c.handleTitleGeneration(sessionID, content)
  188. }
  189. userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
  190. Role: message.User,
  191. Parts: []message.ContentPart{
  192. message.TextContent{
  193. Text: content,
  194. },
  195. },
  196. })
  197. if err != nil {
  198. return err
  199. }
  200. messages = append(messages, userMsg)
  201. for {
  202. eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
  203. if err != nil {
  204. return err
  205. }
  206. assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
  207. Role: message.Assistant,
  208. Parts: []message.ContentPart{},
  209. })
  210. if err != nil {
  211. return err
  212. }
  213. for event := range eventChan {
  214. err = c.processEvent(sessionID, &assistantMsg, event)
  215. if err != nil {
  216. assistantMsg.AddFinish("error:" + err.Error())
  217. c.Messages.Update(assistantMsg)
  218. return err
  219. }
  220. }
  221. msg, err := c.handleToolExecution(c.Context, assistantMsg)
  222. c.Messages.Update(assistantMsg)
  223. if err != nil {
  224. return err
  225. }
  226. if len(assistantMsg.ToolCalls()) == 0 {
  227. break
  228. }
  229. messages = append(messages, assistantMsg)
  230. if msg != nil {
  231. messages = append(messages, *msg)
  232. }
  233. }
  234. return nil
  235. }
  236. func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
  237. maxTokens := config.Get().Model.CoderMaxTokens
  238. providerConfig, ok := config.Get().Providers[model.Provider]
  239. if !ok || !providerConfig.Enabled {
  240. return nil, nil, errors.New("provider is not enabled")
  241. }
  242. var agentProvider provider.Provider
  243. var titleGenerator provider.Provider
  244. switch model.Provider {
  245. case models.ProviderOpenAI:
  246. var err error
  247. agentProvider, err = provider.NewOpenAIProvider(
  248. provider.WithOpenAISystemMessage(
  249. prompt.CoderOpenAISystemPrompt(),
  250. ),
  251. provider.WithOpenAIMaxTokens(maxTokens),
  252. provider.WithOpenAIModel(model),
  253. provider.WithOpenAIKey(providerConfig.APIKey),
  254. )
  255. if err != nil {
  256. return nil, nil, err
  257. }
  258. titleGenerator, err = provider.NewOpenAIProvider(
  259. provider.WithOpenAISystemMessage(
  260. prompt.TitlePrompt(),
  261. ),
  262. provider.WithOpenAIMaxTokens(80),
  263. provider.WithOpenAIModel(model),
  264. provider.WithOpenAIKey(providerConfig.APIKey),
  265. )
  266. if err != nil {
  267. return nil, nil, err
  268. }
  269. case models.ProviderAnthropic:
  270. var err error
  271. agentProvider, err = provider.NewAnthropicProvider(
  272. provider.WithAnthropicSystemMessage(
  273. prompt.CoderAnthropicSystemPrompt(),
  274. ),
  275. provider.WithAnthropicMaxTokens(maxTokens),
  276. provider.WithAnthropicKey(providerConfig.APIKey),
  277. provider.WithAnthropicModel(model),
  278. )
  279. if err != nil {
  280. return nil, nil, err
  281. }
  282. titleGenerator, err = provider.NewAnthropicProvider(
  283. provider.WithAnthropicSystemMessage(
  284. prompt.TitlePrompt(),
  285. ),
  286. provider.WithAnthropicMaxTokens(80),
  287. provider.WithAnthropicKey(providerConfig.APIKey),
  288. provider.WithAnthropicModel(model),
  289. )
  290. if err != nil {
  291. return nil, nil, err
  292. }
  293. case models.ProviderGemini:
  294. var err error
  295. agentProvider, err = provider.NewGeminiProvider(
  296. ctx,
  297. provider.WithGeminiSystemMessage(
  298. prompt.CoderOpenAISystemPrompt(),
  299. ),
  300. provider.WithGeminiMaxTokens(int32(maxTokens)),
  301. provider.WithGeminiKey(providerConfig.APIKey),
  302. provider.WithGeminiModel(model),
  303. )
  304. if err != nil {
  305. return nil, nil, err
  306. }
  307. titleGenerator, err = provider.NewGeminiProvider(
  308. ctx,
  309. provider.WithGeminiSystemMessage(
  310. prompt.TitlePrompt(),
  311. ),
  312. provider.WithGeminiMaxTokens(80),
  313. provider.WithGeminiKey(providerConfig.APIKey),
  314. provider.WithGeminiModel(model),
  315. )
  316. if err != nil {
  317. return nil, nil, err
  318. }
  319. case models.ProviderGROQ:
  320. var err error
  321. agentProvider, err = provider.NewOpenAIProvider(
  322. provider.WithOpenAISystemMessage(
  323. prompt.CoderAnthropicSystemPrompt(),
  324. ),
  325. provider.WithOpenAIMaxTokens(maxTokens),
  326. provider.WithOpenAIModel(model),
  327. provider.WithOpenAIKey(providerConfig.APIKey),
  328. provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
  329. )
  330. if err != nil {
  331. return nil, nil, err
  332. }
  333. titleGenerator, err = provider.NewOpenAIProvider(
  334. provider.WithOpenAISystemMessage(
  335. prompt.TitlePrompt(),
  336. ),
  337. provider.WithOpenAIMaxTokens(80),
  338. provider.WithOpenAIModel(model),
  339. provider.WithOpenAIKey(providerConfig.APIKey),
  340. provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
  341. )
  342. if err != nil {
  343. return nil, nil, err
  344. }
  345. }
  346. return agentProvider, titleGenerator, nil
  347. }