gemini.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. package provider
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "strings"
  9. "time"
  10. "github.com/google/uuid"
  11. "github.com/sst/opencode/internal/config"
  12. "github.com/sst/opencode/internal/llm/tools"
  13. "github.com/sst/opencode/internal/message"
  14. "github.com/sst/opencode/internal/status"
  15. "google.golang.org/genai"
  16. "log/slog"
  17. )
  18. type geminiOptions struct {
  19. disableCache bool
  20. }
  21. type GeminiOption func(*geminiOptions)
  22. type geminiClient struct {
  23. providerOptions providerClientOptions
  24. options geminiOptions
  25. client *genai.Client
  26. }
  27. type GeminiClient ProviderClient
  28. func newGeminiClient(opts providerClientOptions) GeminiClient {
  29. geminiOpts := geminiOptions{}
  30. for _, o := range opts.geminiOptions {
  31. o(&geminiOpts)
  32. }
  33. client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
  34. if err != nil {
  35. slog.Error("Failed to create Gemini client", "error", err)
  36. return nil
  37. }
  38. return &geminiClient{
  39. providerOptions: opts,
  40. options: geminiOpts,
  41. client: client,
  42. }
  43. }
  44. func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
  45. var history []*genai.Content
  46. for _, msg := range messages {
  47. switch msg.Role {
  48. case message.User:
  49. var parts []*genai.Part
  50. parts = append(parts, &genai.Part{Text: msg.Content().String()})
  51. for _, binaryContent := range msg.BinaryContent() {
  52. imageFormat := strings.Split(binaryContent.MIMEType, "/")
  53. parts = append(parts, &genai.Part{InlineData: &genai.Blob{
  54. MIMEType: imageFormat[1],
  55. Data: binaryContent.Data,
  56. }})
  57. }
  58. history = append(history, &genai.Content{
  59. Parts: parts,
  60. Role: "user",
  61. })
  62. case message.Assistant:
  63. content := &genai.Content{
  64. Role: "model",
  65. Parts: []*genai.Part{},
  66. }
  67. if msg.Content().String() != "" {
  68. content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()})
  69. }
  70. if len(msg.ToolCalls()) > 0 {
  71. for _, call := range msg.ToolCalls() {
  72. args, _ := parseJsonToMap(call.Input)
  73. content.Parts = append(content.Parts, &genai.Part{
  74. FunctionCall: &genai.FunctionCall{
  75. Name: call.Name,
  76. Args: args,
  77. },
  78. })
  79. }
  80. }
  81. history = append(history, content)
  82. case message.Tool:
  83. for _, result := range msg.ToolResults() {
  84. response := map[string]interface{}{"result": result.Content}
  85. parsed, err := parseJsonToMap(result.Content)
  86. if err == nil {
  87. response = parsed
  88. }
  89. var toolCall message.ToolCall
  90. for _, m := range messages {
  91. if m.Role == message.Assistant {
  92. for _, call := range m.ToolCalls() {
  93. if call.ID == result.ToolCallID {
  94. toolCall = call
  95. break
  96. }
  97. }
  98. }
  99. }
  100. history = append(history, &genai.Content{
  101. Parts: []*genai.Part{
  102. {
  103. FunctionResponse: &genai.FunctionResponse{
  104. Name: toolCall.Name,
  105. Response: response,
  106. },
  107. },
  108. },
  109. Role: "function",
  110. })
  111. }
  112. }
  113. }
  114. return history
  115. }
  116. func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
  117. geminiTool := &genai.Tool{}
  118. geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
  119. for _, tool := range tools {
  120. info := tool.Info()
  121. declaration := &genai.FunctionDeclaration{
  122. Name: info.Name,
  123. Description: info.Description,
  124. Parameters: &genai.Schema{
  125. Type: genai.TypeObject,
  126. Properties: convertSchemaProperties(info.Parameters),
  127. Required: info.Required,
  128. },
  129. }
  130. geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
  131. }
  132. return []*genai.Tool{geminiTool}
  133. }
  134. func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
  135. switch {
  136. case reason == genai.FinishReasonStop:
  137. return message.FinishReasonEndTurn
  138. case reason == genai.FinishReasonMaxTokens:
  139. return message.FinishReasonMaxTokens
  140. default:
  141. return message.FinishReasonUnknown
  142. }
  143. }
  144. func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
  145. // Convert messages
  146. geminiMessages := g.convertMessages(messages)
  147. cfg := config.Get()
  148. if cfg.Debug {
  149. jsonData, _ := json.Marshal(geminiMessages)
  150. slog.Debug("Prepared messages", "messages", string(jsonData))
  151. }
  152. history := geminiMessages[:len(geminiMessages)-1] // All but last message
  153. lastMsg := geminiMessages[len(geminiMessages)-1]
  154. config := &genai.GenerateContentConfig{
  155. MaxOutputTokens: int32(g.providerOptions.maxTokens),
  156. SystemInstruction: &genai.Content{
  157. Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
  158. },
  159. }
  160. if len(tools) > 0 {
  161. config.Tools = g.convertTools(tools)
  162. }
  163. chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
  164. attempts := 0
  165. for {
  166. attempts++
  167. var toolCalls []message.ToolCall
  168. var lastMsgParts []genai.Part
  169. for _, part := range lastMsg.Parts {
  170. lastMsgParts = append(lastMsgParts, *part)
  171. }
  172. resp, err := chat.SendMessage(ctx, lastMsgParts...)
  173. // If there is an error we are going to see if we can retry the call
  174. if err != nil {
  175. retry, after, retryErr := g.shouldRetry(attempts, err)
  176. duration := time.Duration(after) * time.Millisecond
  177. if retryErr != nil {
  178. return nil, retryErr
  179. }
  180. if retry {
  181. status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
  182. select {
  183. case <-ctx.Done():
  184. return nil, ctx.Err()
  185. case <-time.After(duration):
  186. continue
  187. }
  188. }
  189. return nil, retryErr
  190. }
  191. content := ""
  192. if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
  193. for _, part := range resp.Candidates[0].Content.Parts {
  194. switch {
  195. case part.Text != "":
  196. content = string(part.Text)
  197. case part.FunctionCall != nil:
  198. id := "call_" + uuid.New().String()
  199. args, _ := json.Marshal(part.FunctionCall.Args)
  200. toolCalls = append(toolCalls, message.ToolCall{
  201. ID: id,
  202. Name: part.FunctionCall.Name,
  203. Input: string(args),
  204. Type: "function",
  205. Finished: true,
  206. })
  207. }
  208. }
  209. }
  210. finishReason := message.FinishReasonEndTurn
  211. if len(resp.Candidates) > 0 {
  212. finishReason = g.finishReason(resp.Candidates[0].FinishReason)
  213. }
  214. if len(toolCalls) > 0 {
  215. finishReason = message.FinishReasonToolUse
  216. }
  217. return &ProviderResponse{
  218. Content: content,
  219. ToolCalls: toolCalls,
  220. Usage: g.usage(resp),
  221. FinishReason: finishReason,
  222. }, nil
  223. }
  224. }
  225. func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
  226. // Convert messages
  227. geminiMessages := g.convertMessages(messages)
  228. cfg := config.Get()
  229. if cfg.Debug {
  230. jsonData, _ := json.Marshal(geminiMessages)
  231. slog.Debug("Prepared messages", "messages", string(jsonData))
  232. }
  233. history := geminiMessages[:len(geminiMessages)-1] // All but last message
  234. lastMsg := geminiMessages[len(geminiMessages)-1]
  235. config := &genai.GenerateContentConfig{
  236. MaxOutputTokens: int32(g.providerOptions.maxTokens),
  237. SystemInstruction: &genai.Content{
  238. Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
  239. },
  240. }
  241. if len(tools) > 0 {
  242. config.Tools = g.convertTools(tools)
  243. }
  244. chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
  245. attempts := 0
  246. eventChan := make(chan ProviderEvent)
  247. go func() {
  248. defer close(eventChan)
  249. for {
  250. attempts++
  251. currentContent := ""
  252. toolCalls := []message.ToolCall{}
  253. var finalResp *genai.GenerateContentResponse
  254. eventChan <- ProviderEvent{Type: EventContentStart}
  255. var lastMsgParts []genai.Part
  256. for _, part := range lastMsg.Parts {
  257. lastMsgParts = append(lastMsgParts, *part)
  258. }
  259. for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
  260. if err != nil {
  261. retry, after, retryErr := g.shouldRetry(attempts, err)
  262. duration := time.Duration(after) * time.Millisecond
  263. if retryErr != nil {
  264. eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
  265. return
  266. }
  267. if retry {
  268. status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
  269. select {
  270. case <-ctx.Done():
  271. if ctx.Err() != nil {
  272. eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
  273. }
  274. return
  275. case <-time.After(duration):
  276. break
  277. }
  278. } else {
  279. eventChan <- ProviderEvent{Type: EventError, Error: err}
  280. return
  281. }
  282. }
  283. finalResp = resp
  284. if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
  285. for _, part := range resp.Candidates[0].Content.Parts {
  286. switch {
  287. case part.Text != "":
  288. delta := string(part.Text)
  289. if delta != "" {
  290. eventChan <- ProviderEvent{
  291. Type: EventContentDelta,
  292. Content: delta,
  293. }
  294. currentContent += delta
  295. }
  296. case part.FunctionCall != nil:
  297. id := "call_" + uuid.New().String()
  298. args, _ := json.Marshal(part.FunctionCall.Args)
  299. newCall := message.ToolCall{
  300. ID: id,
  301. Name: part.FunctionCall.Name,
  302. Input: string(args),
  303. Type: "function",
  304. Finished: true,
  305. }
  306. isNew := true
  307. for _, existing := range toolCalls {
  308. if existing.Name == newCall.Name && existing.Input == newCall.Input {
  309. isNew = false
  310. break
  311. }
  312. }
  313. if isNew {
  314. toolCalls = append(toolCalls, newCall)
  315. }
  316. }
  317. }
  318. }
  319. }
  320. eventChan <- ProviderEvent{Type: EventContentStop}
  321. if finalResp != nil {
  322. finishReason := message.FinishReasonEndTurn
  323. if len(finalResp.Candidates) > 0 {
  324. finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
  325. }
  326. if len(toolCalls) > 0 {
  327. finishReason = message.FinishReasonToolUse
  328. }
  329. eventChan <- ProviderEvent{
  330. Type: EventComplete,
  331. Response: &ProviderResponse{
  332. Content: currentContent,
  333. ToolCalls: toolCalls,
  334. Usage: g.usage(finalResp),
  335. FinishReason: finishReason,
  336. },
  337. }
  338. return
  339. }
  340. }
  341. }()
  342. return eventChan
  343. }
  344. func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
  345. // Check if error is a rate limit error
  346. if attempts > maxRetries {
  347. return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
  348. }
  349. // Gemini doesn't have a standard error type we can check against
  350. // So we'll check the error message for rate limit indicators
  351. if errors.Is(err, io.EOF) {
  352. return false, 0, err
  353. }
  354. errMsg := err.Error()
  355. isRateLimit := false
  356. // Check for common rate limit error messages
  357. if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
  358. isRateLimit = true
  359. }
  360. if !isRateLimit {
  361. return false, 0, err
  362. }
  363. // Calculate backoff with jitter
  364. backoffMs := 2000 * (1 << (attempts - 1))
  365. jitterMs := int(float64(backoffMs) * 0.2)
  366. retryMs := backoffMs + jitterMs
  367. return true, int64(retryMs), nil
  368. }
  369. func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
  370. var toolCalls []message.ToolCall
  371. if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
  372. for _, part := range resp.Candidates[0].Content.Parts {
  373. if part.FunctionCall != nil {
  374. id := "call_" + uuid.New().String()
  375. args, _ := json.Marshal(part.FunctionCall.Args)
  376. toolCalls = append(toolCalls, message.ToolCall{
  377. ID: id,
  378. Name: part.FunctionCall.Name,
  379. Input: string(args),
  380. Type: "function",
  381. })
  382. }
  383. }
  384. }
  385. return toolCalls
  386. }
  387. func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
  388. if resp == nil || resp.UsageMetadata == nil {
  389. return TokenUsage{}
  390. }
  391. return TokenUsage{
  392. InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
  393. OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
  394. CacheCreationTokens: 0, // Not directly provided by Gemini
  395. CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
  396. }
  397. }
  398. func WithGeminiDisableCache() GeminiOption {
  399. return func(options *geminiOptions) {
  400. options.disableCache = true
  401. }
  402. }
  403. // Helper functions
  404. func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
  405. var result map[string]interface{}
  406. err := json.Unmarshal([]byte(jsonStr), &result)
  407. return result, err
  408. }
  409. func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
  410. properties := make(map[string]*genai.Schema)
  411. for name, param := range parameters {
  412. properties[name] = convertToSchema(param)
  413. }
  414. return properties
  415. }
  416. func convertToSchema(param interface{}) *genai.Schema {
  417. schema := &genai.Schema{Type: genai.TypeString}
  418. paramMap, ok := param.(map[string]interface{})
  419. if !ok {
  420. return schema
  421. }
  422. if desc, ok := paramMap["description"].(string); ok {
  423. schema.Description = desc
  424. }
  425. typeVal, hasType := paramMap["type"]
  426. if !hasType {
  427. return schema
  428. }
  429. typeStr, ok := typeVal.(string)
  430. if !ok {
  431. return schema
  432. }
  433. schema.Type = mapJSONTypeToGenAI(typeStr)
  434. switch typeStr {
  435. case "array":
  436. schema.Items = processArrayItems(paramMap)
  437. case "object":
  438. if props, ok := paramMap["properties"].(map[string]interface{}); ok {
  439. schema.Properties = convertSchemaProperties(props)
  440. }
  441. }
  442. return schema
  443. }
  444. func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
  445. items, ok := paramMap["items"].(map[string]interface{})
  446. if !ok {
  447. return nil
  448. }
  449. return convertToSchema(items)
  450. }
  451. func mapJSONTypeToGenAI(jsonType string) genai.Type {
  452. switch jsonType {
  453. case "string":
  454. return genai.TypeString
  455. case "number":
  456. return genai.TypeNumber
  457. case "integer":
  458. return genai.TypeInteger
  459. case "boolean":
  460. return genai.TypeBoolean
  461. case "array":
  462. return genai.TypeArray
  463. case "object":
  464. return genai.TypeObject
  465. default:
  466. return genai.TypeString // Default to string for unknown types
  467. }
  468. }
  469. func contains(s string, substrs ...string) bool {
  470. for _, substr := range substrs {
  471. if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
  472. return true
  473. }
  474. }
  475. return false
  476. }