gemini.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. package provider
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "strings"
  9. "time"
  10. "github.com/google/generative-ai-go/genai"
  11. "github.com/google/uuid"
  12. "github.com/opencode-ai/opencode/internal/config"
  13. "github.com/opencode-ai/opencode/internal/llm/tools"
  14. "github.com/opencode-ai/opencode/internal/logging"
  15. "github.com/opencode-ai/opencode/internal/message"
  16. "google.golang.org/api/iterator"
  17. "google.golang.org/api/option"
  18. )
  19. type geminiOptions struct {
  20. disableCache bool
  21. }
  22. type GeminiOption func(*geminiOptions)
  23. type geminiClient struct {
  24. providerOptions providerClientOptions
  25. options geminiOptions
  26. client *genai.Client
  27. }
  28. type GeminiClient ProviderClient
  29. func newGeminiClient(opts providerClientOptions) GeminiClient {
  30. geminiOpts := geminiOptions{}
  31. for _, o := range opts.geminiOptions {
  32. o(&geminiOpts)
  33. }
  34. client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
  35. if err != nil {
  36. logging.Error("Failed to create Gemini client", "error", err)
  37. return nil
  38. }
  39. return &geminiClient{
  40. providerOptions: opts,
  41. options: geminiOpts,
  42. client: client,
  43. }
  44. }
  45. func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
  46. var history []*genai.Content
  47. for _, msg := range messages {
  48. switch msg.Role {
  49. case message.User:
  50. history = append(history, &genai.Content{
  51. Parts: []genai.Part{genai.Text(msg.Content().String())},
  52. Role: "user",
  53. })
  54. case message.Assistant:
  55. content := &genai.Content{
  56. Role: "model",
  57. Parts: []genai.Part{},
  58. }
  59. if msg.Content().String() != "" {
  60. content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
  61. }
  62. if len(msg.ToolCalls()) > 0 {
  63. for _, call := range msg.ToolCalls() {
  64. args, _ := parseJsonToMap(call.Input)
  65. content.Parts = append(content.Parts, genai.FunctionCall{
  66. Name: call.Name,
  67. Args: args,
  68. })
  69. }
  70. }
  71. history = append(history, content)
  72. case message.Tool:
  73. for _, result := range msg.ToolResults() {
  74. response := map[string]interface{}{"result": result.Content}
  75. parsed, err := parseJsonToMap(result.Content)
  76. if err == nil {
  77. response = parsed
  78. }
  79. var toolCall message.ToolCall
  80. for _, m := range messages {
  81. if m.Role == message.Assistant {
  82. for _, call := range m.ToolCalls() {
  83. if call.ID == result.ToolCallID {
  84. toolCall = call
  85. break
  86. }
  87. }
  88. }
  89. }
  90. history = append(history, &genai.Content{
  91. Parts: []genai.Part{genai.FunctionResponse{
  92. Name: toolCall.Name,
  93. Response: response,
  94. }},
  95. Role: "function",
  96. })
  97. }
  98. }
  99. }
  100. return history
  101. }
  102. func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
  103. geminiTool := &genai.Tool{}
  104. geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
  105. for _, tool := range tools {
  106. info := tool.Info()
  107. declaration := &genai.FunctionDeclaration{
  108. Name: info.Name,
  109. Description: info.Description,
  110. Parameters: &genai.Schema{
  111. Type: genai.TypeObject,
  112. Properties: convertSchemaProperties(info.Parameters),
  113. Required: info.Required,
  114. },
  115. }
  116. geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
  117. }
  118. return []*genai.Tool{geminiTool}
  119. }
  120. func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
  121. switch {
  122. case reason == genai.FinishReasonStop:
  123. return message.FinishReasonEndTurn
  124. case reason == genai.FinishReasonMaxTokens:
  125. return message.FinishReasonMaxTokens
  126. default:
  127. return message.FinishReasonUnknown
  128. }
  129. }
  130. func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
  131. model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
  132. model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
  133. model.SystemInstruction = &genai.Content{
  134. Parts: []genai.Part{
  135. genai.Text(g.providerOptions.systemMessage),
  136. },
  137. }
  138. // Convert tools
  139. if len(tools) > 0 {
  140. model.Tools = g.convertTools(tools)
  141. }
  142. // Convert messages
  143. geminiMessages := g.convertMessages(messages)
  144. cfg := config.Get()
  145. if cfg.Debug {
  146. jsonData, _ := json.Marshal(geminiMessages)
  147. logging.Debug("Prepared messages", "messages", string(jsonData))
  148. }
  149. attempts := 0
  150. for {
  151. attempts++
  152. var toolCalls []message.ToolCall
  153. chat := model.StartChat()
  154. chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
  155. lastMsg := geminiMessages[len(geminiMessages)-1]
  156. resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
  157. // If there is an error we are going to see if we can retry the call
  158. if err != nil {
  159. retry, after, retryErr := g.shouldRetry(attempts, err)
  160. if retryErr != nil {
  161. return nil, retryErr
  162. }
  163. if retry {
  164. logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
  165. select {
  166. case <-ctx.Done():
  167. return nil, ctx.Err()
  168. case <-time.After(time.Duration(after) * time.Millisecond):
  169. continue
  170. }
  171. }
  172. return nil, retryErr
  173. }
  174. content := ""
  175. if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
  176. for _, part := range resp.Candidates[0].Content.Parts {
  177. switch p := part.(type) {
  178. case genai.Text:
  179. content = string(p)
  180. case genai.FunctionCall:
  181. id := "call_" + uuid.New().String()
  182. args, _ := json.Marshal(p.Args)
  183. toolCalls = append(toolCalls, message.ToolCall{
  184. ID: id,
  185. Name: p.Name,
  186. Input: string(args),
  187. Type: "function",
  188. Finished: true,
  189. })
  190. }
  191. }
  192. }
  193. finishReason := message.FinishReasonEndTurn
  194. if len(resp.Candidates) > 0 {
  195. finishReason = g.finishReason(resp.Candidates[0].FinishReason)
  196. }
  197. if len(toolCalls) > 0 {
  198. finishReason = message.FinishReasonToolUse
  199. }
  200. return &ProviderResponse{
  201. Content: content,
  202. ToolCalls: toolCalls,
  203. Usage: g.usage(resp),
  204. FinishReason: finishReason,
  205. }, nil
  206. }
  207. }
  208. func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
  209. model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
  210. model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
  211. model.SystemInstruction = &genai.Content{
  212. Parts: []genai.Part{
  213. genai.Text(g.providerOptions.systemMessage),
  214. },
  215. }
  216. // Convert tools
  217. if len(tools) > 0 {
  218. model.Tools = g.convertTools(tools)
  219. }
  220. // Convert messages
  221. geminiMessages := g.convertMessages(messages)
  222. cfg := config.Get()
  223. if cfg.Debug {
  224. jsonData, _ := json.Marshal(geminiMessages)
  225. logging.Debug("Prepared messages", "messages", string(jsonData))
  226. }
  227. attempts := 0
  228. eventChan := make(chan ProviderEvent)
  229. go func() {
  230. defer close(eventChan)
  231. for {
  232. attempts++
  233. chat := model.StartChat()
  234. chat.History = geminiMessages[:len(geminiMessages)-1]
  235. lastMsg := geminiMessages[len(geminiMessages)-1]
  236. iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
  237. currentContent := ""
  238. toolCalls := []message.ToolCall{}
  239. var finalResp *genai.GenerateContentResponse
  240. eventChan <- ProviderEvent{Type: EventContentStart}
  241. for {
  242. resp, err := iter.Next()
  243. if err == iterator.Done {
  244. break
  245. }
  246. if err != nil {
  247. retry, after, retryErr := g.shouldRetry(attempts, err)
  248. if retryErr != nil {
  249. eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
  250. return
  251. }
  252. if retry {
  253. logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
  254. select {
  255. case <-ctx.Done():
  256. if ctx.Err() != nil {
  257. eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
  258. }
  259. return
  260. case <-time.After(time.Duration(after) * time.Millisecond):
  261. break
  262. }
  263. } else {
  264. eventChan <- ProviderEvent{Type: EventError, Error: err}
  265. return
  266. }
  267. }
  268. finalResp = resp
  269. if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
  270. for _, part := range resp.Candidates[0].Content.Parts {
  271. switch p := part.(type) {
  272. case genai.Text:
  273. delta := string(p)
  274. if delta != "" {
  275. eventChan <- ProviderEvent{
  276. Type: EventContentDelta,
  277. Content: delta,
  278. }
  279. currentContent += delta
  280. }
  281. case genai.FunctionCall:
  282. id := "call_" + uuid.New().String()
  283. args, _ := json.Marshal(p.Args)
  284. newCall := message.ToolCall{
  285. ID: id,
  286. Name: p.Name,
  287. Input: string(args),
  288. Type: "function",
  289. Finished: true,
  290. }
  291. isNew := true
  292. for _, existing := range toolCalls {
  293. if existing.Name == newCall.Name && existing.Input == newCall.Input {
  294. isNew = false
  295. break
  296. }
  297. }
  298. if isNew {
  299. toolCalls = append(toolCalls, newCall)
  300. }
  301. }
  302. }
  303. }
  304. }
  305. eventChan <- ProviderEvent{Type: EventContentStop}
  306. if finalResp != nil {
  307. finishReason := message.FinishReasonEndTurn
  308. if len(finalResp.Candidates) > 0 {
  309. finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
  310. }
  311. if len(toolCalls) > 0 {
  312. finishReason = message.FinishReasonToolUse
  313. }
  314. eventChan <- ProviderEvent{
  315. Type: EventComplete,
  316. Response: &ProviderResponse{
  317. Content: currentContent,
  318. ToolCalls: toolCalls,
  319. Usage: g.usage(finalResp),
  320. FinishReason: finishReason,
  321. },
  322. }
  323. return
  324. }
  325. }
  326. }()
  327. return eventChan
  328. }
  329. func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
  330. // Check if error is a rate limit error
  331. if attempts > maxRetries {
  332. return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
  333. }
  334. // Gemini doesn't have a standard error type we can check against
  335. // So we'll check the error message for rate limit indicators
  336. if errors.Is(err, io.EOF) {
  337. return false, 0, err
  338. }
  339. errMsg := err.Error()
  340. isRateLimit := false
  341. // Check for common rate limit error messages
  342. if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
  343. isRateLimit = true
  344. }
  345. if !isRateLimit {
  346. return false, 0, err
  347. }
  348. // Calculate backoff with jitter
  349. backoffMs := 2000 * (1 << (attempts - 1))
  350. jitterMs := int(float64(backoffMs) * 0.2)
  351. retryMs := backoffMs + jitterMs
  352. return true, int64(retryMs), nil
  353. }
  354. func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
  355. var toolCalls []message.ToolCall
  356. if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
  357. for _, part := range resp.Candidates[0].Content.Parts {
  358. if funcCall, ok := part.(genai.FunctionCall); ok {
  359. id := "call_" + uuid.New().String()
  360. args, _ := json.Marshal(funcCall.Args)
  361. toolCalls = append(toolCalls, message.ToolCall{
  362. ID: id,
  363. Name: funcCall.Name,
  364. Input: string(args),
  365. Type: "function",
  366. })
  367. }
  368. }
  369. }
  370. return toolCalls
  371. }
  372. func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
  373. if resp == nil || resp.UsageMetadata == nil {
  374. return TokenUsage{}
  375. }
  376. return TokenUsage{
  377. InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
  378. OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
  379. CacheCreationTokens: 0, // Not directly provided by Gemini
  380. CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
  381. }
  382. }
  383. func WithGeminiDisableCache() GeminiOption {
  384. return func(options *geminiOptions) {
  385. options.disableCache = true
  386. }
  387. }
  388. // Helper functions
  // parseJsonToMap decodes jsonStr as a JSON object.
//
// The map is returned together with json.Unmarshal's error, unmodified. The
// literal "null" decodes to a nil map with a nil error, so callers that need
// a non-nil map must check for that.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var out map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &out); err != nil {
		return out, err
	}
	return out, nil
}
  394. func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
  395. properties := make(map[string]*genai.Schema)
  396. for name, param := range parameters {
  397. properties[name] = convertToSchema(param)
  398. }
  399. return properties
  400. }
  401. func convertToSchema(param interface{}) *genai.Schema {
  402. schema := &genai.Schema{Type: genai.TypeString}
  403. paramMap, ok := param.(map[string]interface{})
  404. if !ok {
  405. return schema
  406. }
  407. if desc, ok := paramMap["description"].(string); ok {
  408. schema.Description = desc
  409. }
  410. typeVal, hasType := paramMap["type"]
  411. if !hasType {
  412. return schema
  413. }
  414. typeStr, ok := typeVal.(string)
  415. if !ok {
  416. return schema
  417. }
  418. schema.Type = mapJSONTypeToGenAI(typeStr)
  419. switch typeStr {
  420. case "array":
  421. schema.Items = processArrayItems(paramMap)
  422. case "object":
  423. if props, ok := paramMap["properties"].(map[string]interface{}); ok {
  424. schema.Properties = convertSchemaProperties(props)
  425. }
  426. }
  427. return schema
  428. }
  429. func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
  430. items, ok := paramMap["items"].(map[string]interface{})
  431. if !ok {
  432. return nil
  433. }
  434. return convertToSchema(items)
  435. }
  436. func mapJSONTypeToGenAI(jsonType string) genai.Type {
  437. switch jsonType {
  438. case "string":
  439. return genai.TypeString
  440. case "number":
  441. return genai.TypeNumber
  442. case "integer":
  443. return genai.TypeInteger
  444. case "boolean":
  445. return genai.TypeBoolean
  446. case "array":
  447. return genai.TypeArray
  448. case "object":
  449. return genai.TypeObject
  450. default:
  451. return genai.TypeString // Default to string for unknown types
  452. }
  453. }
  // contains reports whether s contains any of substrs, ignoring case.
//
// Both sides are lowercased with strings.ToLower. An empty substring always
// matches; with no substrings the result is false.
func contains(s string, substrs ...string) bool {
	// The haystack is the same for every candidate, so lowercase it once.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}