gemini.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. package provider
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "github.com/google/generative-ai-go/genai"
  7. "github.com/google/uuid"
  8. "github.com/kujtimiihoxha/termai/internal/llm/models"
  9. "github.com/kujtimiihoxha/termai/internal/llm/tools"
  10. "github.com/kujtimiihoxha/termai/internal/message"
  11. "google.golang.org/api/iterator"
  12. "google.golang.org/api/option"
  13. )
  14. type geminiProvider struct {
  15. client *genai.Client
  16. model models.Model
  17. maxTokens int32
  18. apiKey string
  19. systemMessage string
  20. }
  21. type GeminiOption func(*geminiProvider)
  22. func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
  23. provider := &geminiProvider{
  24. maxTokens: 5000,
  25. }
  26. for _, opt := range opts {
  27. opt(provider)
  28. }
  29. if provider.systemMessage == "" {
  30. return nil, errors.New("system message is required")
  31. }
  32. client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
  33. if err != nil {
  34. return nil, err
  35. }
  36. provider.client = client
  37. return provider, nil
  38. }
  39. func WithGeminiSystemMessage(message string) GeminiOption {
  40. return func(p *geminiProvider) {
  41. p.systemMessage = message
  42. }
  43. }
  44. func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
  45. return func(p *geminiProvider) {
  46. p.maxTokens = maxTokens
  47. }
  48. }
  49. func WithGeminiModel(model models.Model) GeminiOption {
  50. return func(p *geminiProvider) {
  51. p.model = model
  52. }
  53. }
  54. func WithGeminiKey(apiKey string) GeminiOption {
  55. return func(p *geminiProvider) {
  56. p.apiKey = apiKey
  57. }
  58. }
  59. func (p *geminiProvider) Close() {
  60. if p.client != nil {
  61. p.client.Close()
  62. }
  63. }
  64. func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
  65. var history []*genai.Content
  66. for _, msg := range messages {
  67. switch msg.Role {
  68. case message.User:
  69. history = append(history, &genai.Content{
  70. Parts: []genai.Part{genai.Text(msg.Content().String())},
  71. Role: "user",
  72. })
  73. case message.Assistant:
  74. content := &genai.Content{
  75. Role: "model",
  76. Parts: []genai.Part{},
  77. }
  78. if msg.Content().String() != "" {
  79. content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
  80. }
  81. if len(msg.ToolCalls()) > 0 {
  82. for _, call := range msg.ToolCalls() {
  83. args, _ := parseJsonToMap(call.Input)
  84. content.Parts = append(content.Parts, genai.FunctionCall{
  85. Name: call.Name,
  86. Args: args,
  87. })
  88. }
  89. }
  90. history = append(history, content)
  91. case message.Tool:
  92. for _, result := range msg.ToolResults() {
  93. response := map[string]interface{}{"result": result.Content}
  94. parsed, err := parseJsonToMap(result.Content)
  95. if err == nil {
  96. response = parsed
  97. }
  98. var toolCall message.ToolCall
  99. for _, msg := range messages {
  100. if msg.Role == message.Assistant {
  101. for _, call := range msg.ToolCalls() {
  102. if call.ID == result.ToolCallID {
  103. toolCall = call
  104. break
  105. }
  106. }
  107. }
  108. }
  109. history = append(history, &genai.Content{
  110. Parts: []genai.Part{genai.FunctionResponse{
  111. Name: toolCall.Name,
  112. Response: response,
  113. }},
  114. Role: "function",
  115. })
  116. }
  117. }
  118. }
  119. return history
  120. }
  121. func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
  122. if resp == nil || resp.UsageMetadata == nil {
  123. return TokenUsage{}
  124. }
  125. return TokenUsage{
  126. InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
  127. OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
  128. CacheCreationTokens: 0, // Not directly provided by Gemini
  129. CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
  130. }
  131. }
  132. func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
  133. messages = cleanupMessages(messages)
  134. model := p.client.GenerativeModel(p.model.APIModel)
  135. model.SetMaxOutputTokens(p.maxTokens)
  136. model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
  137. if len(tools) > 0 {
  138. declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
  139. for _, declaration := range declarations {
  140. model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
  141. }
  142. }
  143. chat := model.StartChat()
  144. chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
  145. lastUserMsg := messages[len(messages)-1]
  146. resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String()))
  147. if err != nil {
  148. return nil, err
  149. }
  150. var content string
  151. var toolCalls []message.ToolCall
  152. if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
  153. for _, part := range resp.Candidates[0].Content.Parts {
  154. switch p := part.(type) {
  155. case genai.Text:
  156. content = string(p)
  157. case genai.FunctionCall:
  158. id := "call_" + uuid.New().String()
  159. args, _ := json.Marshal(p.Args)
  160. toolCalls = append(toolCalls, message.ToolCall{
  161. ID: id,
  162. Name: p.Name,
  163. Input: string(args),
  164. Type: "function",
  165. })
  166. }
  167. }
  168. }
  169. tokenUsage := p.extractTokenUsage(resp)
  170. return &ProviderResponse{
  171. Content: content,
  172. ToolCalls: toolCalls,
  173. Usage: tokenUsage,
  174. }, nil
  175. }
  176. func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
  177. messages = cleanupMessages(messages)
  178. model := p.client.GenerativeModel(p.model.APIModel)
  179. model.SetMaxOutputTokens(p.maxTokens)
  180. model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
  181. if len(tools) > 0 {
  182. declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
  183. for _, declaration := range declarations {
  184. model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
  185. }
  186. }
  187. chat := model.StartChat()
  188. chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
  189. lastUserMsg := messages[len(messages)-1]
  190. iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String()))
  191. eventChan := make(chan ProviderEvent)
  192. go func() {
  193. defer close(eventChan)
  194. var finalResp *genai.GenerateContentResponse
  195. currentContent := ""
  196. toolCalls := []message.ToolCall{}
  197. for {
  198. resp, err := iter.Next()
  199. if err == iterator.Done {
  200. break
  201. }
  202. if err != nil {
  203. eventChan <- ProviderEvent{
  204. Type: EventError,
  205. Error: err,
  206. }
  207. return
  208. }
  209. finalResp = resp
  210. if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
  211. for _, part := range resp.Candidates[0].Content.Parts {
  212. switch p := part.(type) {
  213. case genai.Text:
  214. newText := string(p)
  215. eventChan <- ProviderEvent{
  216. Type: EventContentDelta,
  217. Content: newText,
  218. }
  219. currentContent += newText
  220. case genai.FunctionCall:
  221. id := "call_" + uuid.New().String()
  222. args, _ := json.Marshal(p.Args)
  223. newCall := message.ToolCall{
  224. ID: id,
  225. Name: p.Name,
  226. Input: string(args),
  227. Type: "function",
  228. }
  229. isNew := true
  230. for _, existing := range toolCalls {
  231. if existing.Name == newCall.Name && existing.Input == newCall.Input {
  232. isNew = false
  233. break
  234. }
  235. }
  236. if isNew {
  237. toolCalls = append(toolCalls, newCall)
  238. }
  239. }
  240. }
  241. }
  242. }
  243. tokenUsage := p.extractTokenUsage(finalResp)
  244. eventChan <- ProviderEvent{
  245. Type: EventComplete,
  246. Response: &ProviderResponse{
  247. Content: currentContent,
  248. ToolCalls: toolCalls,
  249. Usage: tokenUsage,
  250. FinishReason: string(finalResp.Candidates[0].FinishReason.String()),
  251. },
  252. }
  253. }()
  254. return eventChan, nil
  255. }
  256. func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
  257. declarations := make([]*genai.FunctionDeclaration, len(tools))
  258. for i, tool := range tools {
  259. info := tool.Info()
  260. declarations[i] = &genai.FunctionDeclaration{
  261. Name: info.Name,
  262. Description: info.Description,
  263. Parameters: &genai.Schema{
  264. Type: genai.TypeObject,
  265. Properties: convertSchemaProperties(info.Parameters),
  266. Required: info.Required,
  267. },
  268. }
  269. }
  270. return declarations
  271. }
  272. func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
  273. properties := make(map[string]*genai.Schema)
  274. for name, param := range parameters {
  275. properties[name] = convertToSchema(param)
  276. }
  277. return properties
  278. }
  279. func convertToSchema(param interface{}) *genai.Schema {
  280. schema := &genai.Schema{Type: genai.TypeString}
  281. paramMap, ok := param.(map[string]interface{})
  282. if !ok {
  283. return schema
  284. }
  285. if desc, ok := paramMap["description"].(string); ok {
  286. schema.Description = desc
  287. }
  288. typeVal, hasType := paramMap["type"]
  289. if !hasType {
  290. return schema
  291. }
  292. typeStr, ok := typeVal.(string)
  293. if !ok {
  294. return schema
  295. }
  296. schema.Type = mapJSONTypeToGenAI(typeStr)
  297. switch typeStr {
  298. case "array":
  299. schema.Items = processArrayItems(paramMap)
  300. case "object":
  301. if props, ok := paramMap["properties"].(map[string]interface{}); ok {
  302. schema.Properties = convertSchemaProperties(props)
  303. }
  304. }
  305. return schema
  306. }
  307. func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
  308. items, ok := paramMap["items"].(map[string]interface{})
  309. if !ok {
  310. return nil
  311. }
  312. return convertToSchema(items)
  313. }
  314. func mapJSONTypeToGenAI(jsonType string) genai.Type {
  315. switch jsonType {
  316. case "string":
  317. return genai.TypeString
  318. case "number":
  319. return genai.TypeNumber
  320. case "integer":
  321. return genai.TypeInteger
  322. case "boolean":
  323. return genai.TypeBoolean
  324. case "array":
  325. return genai.TypeArray
  326. case "object":
  327. return genai.TypeObject
  328. default:
  329. return genai.TypeString // Default to string for unknown types
  330. }
  331. }
  332. func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
  333. var result map[string]interface{}
  334. err := json.Unmarshal([]byte(jsonStr), &result)
  335. return result, err
  336. }