agent.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. package agent
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "strings"
  8. "sync"
  9. "github.com/kujtimiihoxha/termai/internal/app"
  10. "github.com/kujtimiihoxha/termai/internal/config"
  11. "github.com/kujtimiihoxha/termai/internal/llm/models"
  12. "github.com/kujtimiihoxha/termai/internal/llm/prompt"
  13. "github.com/kujtimiihoxha/termai/internal/llm/provider"
  14. "github.com/kujtimiihoxha/termai/internal/llm/tools"
  15. "github.com/kujtimiihoxha/termai/internal/message"
  16. )
  17. type Agent interface {
  18. Generate(sessionID string, content string) error
  19. }
  20. type agent struct {
  21. *app.App
  22. model models.Model
  23. tools []tools.BaseTool
  24. agent provider.Provider
  25. titleGenerator provider.Provider
  26. }
  27. func (c *agent) handleTitleGeneration(sessionID, content string) {
  28. response, err := c.titleGenerator.SendMessages(
  29. c.Context,
  30. []message.Message{
  31. {
  32. Role: message.User,
  33. Parts: []message.ContentPart{
  34. message.TextContent{
  35. Text: content,
  36. },
  37. },
  38. },
  39. },
  40. nil,
  41. )
  42. if err != nil {
  43. return
  44. }
  45. session, err := c.Sessions.Get(sessionID)
  46. if err != nil {
  47. return
  48. }
  49. if response.Content != "" {
  50. session.Title = response.Content
  51. session.Title = strings.TrimSpace(session.Title)
  52. session.Title = strings.ReplaceAll(session.Title, "\n", " ")
  53. c.Sessions.Save(session)
  54. }
  55. }
  56. func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
  57. session, err := c.Sessions.Get(sessionID)
  58. if err != nil {
  59. return err
  60. }
  61. cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  62. model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  63. model.CostPer1MIn/1e6*float64(usage.InputTokens) +
  64. model.CostPer1MOut/1e6*float64(usage.OutputTokens)
  65. session.Cost += cost
  66. session.CompletionTokens += usage.OutputTokens
  67. session.PromptTokens += usage.InputTokens
  68. _, err = c.Sessions.Save(session)
  69. return err
  70. }
  71. func (c *agent) processEvent(
  72. sessionID string,
  73. assistantMsg *message.Message,
  74. event provider.ProviderEvent,
  75. ) error {
  76. switch event.Type {
  77. case provider.EventThinkingDelta:
  78. assistantMsg.AppendReasoningContent(event.Content)
  79. return c.Messages.Update(*assistantMsg)
  80. case provider.EventContentDelta:
  81. assistantMsg.AppendContent(event.Content)
  82. return c.Messages.Update(*assistantMsg)
  83. case provider.EventError:
  84. log.Println("error", event.Error)
  85. return event.Error
  86. case provider.EventComplete:
  87. assistantMsg.SetToolCalls(event.Response.ToolCalls)
  88. assistantMsg.AddFinish(event.Response.FinishReason)
  89. err := c.Messages.Update(*assistantMsg)
  90. if err != nil {
  91. return err
  92. }
  93. return c.TrackUsage(sessionID, c.model, event.Response.Usage)
  94. }
  95. return nil
  96. }
  97. func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
  98. var wg sync.WaitGroup
  99. toolResults := make([]message.ToolResult, len(toolCalls))
  100. mutex := &sync.Mutex{}
  101. for i, tc := range toolCalls {
  102. wg.Add(1)
  103. go func(index int, toolCall message.ToolCall) {
  104. defer wg.Done()
  105. response := ""
  106. isError := false
  107. found := false
  108. for _, tool := range tls {
  109. if tool.Info().Name == toolCall.Name {
  110. found = true
  111. toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
  112. ID: toolCall.ID,
  113. Name: toolCall.Name,
  114. Input: toolCall.Input,
  115. })
  116. if toolErr != nil {
  117. response = fmt.Sprintf("error running tool: %s", toolErr)
  118. isError = true
  119. } else {
  120. response = toolResult.Content
  121. isError = toolResult.IsError
  122. }
  123. break
  124. }
  125. }
  126. if !found {
  127. response = fmt.Sprintf("tool not found: %s", toolCall.Name)
  128. isError = true
  129. }
  130. mutex.Lock()
  131. defer mutex.Unlock()
  132. toolResults[index] = message.ToolResult{
  133. ToolCallID: toolCall.ID,
  134. Content: response,
  135. IsError: isError,
  136. }
  137. }(i, tc)
  138. }
  139. wg.Wait()
  140. return toolResults, nil
  141. }
  142. func (c *agent) handleToolExecution(
  143. ctx context.Context,
  144. assistantMsg message.Message,
  145. ) (*message.Message, error) {
  146. if len(assistantMsg.ToolCalls()) == 0 {
  147. return nil, nil
  148. }
  149. toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
  150. if err != nil {
  151. return nil, err
  152. }
  153. parts := make([]message.ContentPart, 0)
  154. for _, toolResult := range toolResults {
  155. parts = append(parts, toolResult)
  156. }
  157. msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
  158. Role: message.Tool,
  159. Parts: parts,
  160. })
  161. return &msg, err
  162. }
  163. func (c *agent) generate(sessionID string, content string) error {
  164. messages, err := c.Messages.List(sessionID)
  165. if err != nil {
  166. return err
  167. }
  168. if len(messages) == 0 {
  169. go c.handleTitleGeneration(sessionID, content)
  170. }
  171. userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
  172. Role: message.User,
  173. Parts: []message.ContentPart{
  174. message.TextContent{
  175. Text: content,
  176. },
  177. },
  178. })
  179. if err != nil {
  180. return err
  181. }
  182. messages = append(messages, userMsg)
  183. for {
  184. eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
  185. if err != nil {
  186. return err
  187. }
  188. assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
  189. Role: message.Assistant,
  190. Parts: []message.ContentPart{},
  191. })
  192. if err != nil {
  193. return err
  194. }
  195. for event := range eventChan {
  196. err = c.processEvent(sessionID, &assistantMsg, event)
  197. if err != nil {
  198. assistantMsg.AddFinish("error:" + err.Error())
  199. c.Messages.Update(assistantMsg)
  200. return err
  201. }
  202. }
  203. msg, err := c.handleToolExecution(c.Context, assistantMsg)
  204. c.Messages.Update(assistantMsg)
  205. if err != nil {
  206. return err
  207. }
  208. if len(assistantMsg.ToolCalls()) == 0 {
  209. break
  210. }
  211. messages = append(messages, assistantMsg)
  212. if msg != nil {
  213. messages = append(messages, *msg)
  214. }
  215. }
  216. return nil
  217. }
  218. func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
  219. maxTokens := config.Get().Model.CoderMaxTokens
  220. providerConfig, ok := config.Get().Providers[model.Provider]
  221. if !ok || !providerConfig.Enabled {
  222. return nil, nil, errors.New("provider is not enabled")
  223. }
  224. var agentProvider provider.Provider
  225. var titleGenerator provider.Provider
  226. switch model.Provider {
  227. case models.ProviderOpenAI:
  228. var err error
  229. agentProvider, err = provider.NewOpenAIProvider(
  230. provider.WithOpenAISystemMessage(
  231. prompt.CoderOpenAISystemPrompt(),
  232. ),
  233. provider.WithOpenAIMaxTokens(maxTokens),
  234. provider.WithOpenAIModel(model),
  235. provider.WithOpenAIKey(providerConfig.APIKey),
  236. )
  237. if err != nil {
  238. return nil, nil, err
  239. }
  240. titleGenerator, err = provider.NewOpenAIProvider(
  241. provider.WithOpenAISystemMessage(
  242. prompt.TitlePrompt(),
  243. ),
  244. provider.WithOpenAIMaxTokens(80),
  245. provider.WithOpenAIModel(model),
  246. provider.WithOpenAIKey(providerConfig.APIKey),
  247. )
  248. if err != nil {
  249. return nil, nil, err
  250. }
  251. case models.ProviderAnthropic:
  252. var err error
  253. agentProvider, err = provider.NewAnthropicProvider(
  254. provider.WithAnthropicSystemMessage(
  255. prompt.CoderAnthropicSystemPrompt(),
  256. ),
  257. provider.WithAnthropicMaxTokens(maxTokens),
  258. provider.WithAnthropicKey(providerConfig.APIKey),
  259. provider.WithAnthropicModel(model),
  260. )
  261. if err != nil {
  262. return nil, nil, err
  263. }
  264. titleGenerator, err = provider.NewAnthropicProvider(
  265. provider.WithAnthropicSystemMessage(
  266. prompt.TitlePrompt(),
  267. ),
  268. provider.WithAnthropicMaxTokens(80),
  269. provider.WithAnthropicKey(providerConfig.APIKey),
  270. provider.WithAnthropicModel(model),
  271. )
  272. if err != nil {
  273. return nil, nil, err
  274. }
  275. case models.ProviderGemini:
  276. var err error
  277. agentProvider, err = provider.NewGeminiProvider(
  278. ctx,
  279. provider.WithGeminiSystemMessage(
  280. prompt.CoderOpenAISystemPrompt(),
  281. ),
  282. provider.WithGeminiMaxTokens(int32(maxTokens)),
  283. provider.WithGeminiKey(providerConfig.APIKey),
  284. provider.WithGeminiModel(model),
  285. )
  286. if err != nil {
  287. return nil, nil, err
  288. }
  289. titleGenerator, err = provider.NewGeminiProvider(
  290. ctx,
  291. provider.WithGeminiSystemMessage(
  292. prompt.TitlePrompt(),
  293. ),
  294. provider.WithGeminiMaxTokens(80),
  295. provider.WithGeminiKey(providerConfig.APIKey),
  296. provider.WithGeminiModel(model),
  297. )
  298. if err != nil {
  299. return nil, nil, err
  300. }
  301. case models.ProviderGROQ:
  302. var err error
  303. agentProvider, err = provider.NewOpenAIProvider(
  304. provider.WithOpenAISystemMessage(
  305. prompt.CoderAnthropicSystemPrompt(),
  306. ),
  307. provider.WithOpenAIMaxTokens(maxTokens),
  308. provider.WithOpenAIModel(model),
  309. provider.WithOpenAIKey(providerConfig.APIKey),
  310. provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
  311. )
  312. if err != nil {
  313. return nil, nil, err
  314. }
  315. titleGenerator, err = provider.NewOpenAIProvider(
  316. provider.WithOpenAISystemMessage(
  317. prompt.TitlePrompt(),
  318. ),
  319. provider.WithOpenAIMaxTokens(80),
  320. provider.WithOpenAIModel(model),
  321. provider.WithOpenAIKey(providerConfig.APIKey),
  322. provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
  323. )
  324. if err != nil {
  325. return nil, nil, err
  326. }
  327. }
  328. return agentProvider, titleGenerator, nil
  329. }