| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555 |
- package provider
- import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "strings"
- "time"
- "github.com/google/uuid"
- "github.com/sst/opencode/internal/config"
- "github.com/sst/opencode/internal/llm/tools"
- "github.com/sst/opencode/internal/message"
- "github.com/sst/opencode/internal/status"
- "google.golang.org/genai"
- "log/slog"
- )
- type geminiOptions struct {
- disableCache bool
- }
- type GeminiOption func(*geminiOptions)
- type geminiClient struct {
- providerOptions providerClientOptions
- options geminiOptions
- client *genai.Client
- }
- type GeminiClient ProviderClient
- func newGeminiClient(opts providerClientOptions) GeminiClient {
- geminiOpts := geminiOptions{}
- for _, o := range opts.geminiOptions {
- o(&geminiOpts)
- }
- client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
- if err != nil {
- slog.Error("Failed to create Gemini client", "error", err)
- return nil
- }
- return &geminiClient{
- providerOptions: opts,
- options: geminiOpts,
- client: client,
- }
- }
- func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
- var history []*genai.Content
- for _, msg := range messages {
- switch msg.Role {
- case message.User:
- var parts []*genai.Part
- parts = append(parts, &genai.Part{Text: msg.Content().String()})
- for _, binaryContent := range msg.BinaryContent() {
- imageFormat := strings.Split(binaryContent.MIMEType, "/")
- parts = append(parts, &genai.Part{InlineData: &genai.Blob{
- MIMEType: imageFormat[1],
- Data: binaryContent.Data,
- }})
- }
- history = append(history, &genai.Content{
- Parts: parts,
- Role: "user",
- })
- case message.Assistant:
- content := &genai.Content{
- Role: "model",
- Parts: []*genai.Part{},
- }
- if msg.Content().String() != "" {
- content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()})
- }
- if len(msg.ToolCalls()) > 0 {
- for _, call := range msg.ToolCalls() {
- args, _ := parseJsonToMap(call.Input)
- content.Parts = append(content.Parts, &genai.Part{
- FunctionCall: &genai.FunctionCall{
- Name: call.Name,
- Args: args,
- },
- })
- }
- }
- history = append(history, content)
- case message.Tool:
- for _, result := range msg.ToolResults() {
- response := map[string]interface{}{"result": result.Content}
- parsed, err := parseJsonToMap(result.Content)
- if err == nil {
- response = parsed
- }
- var toolCall message.ToolCall
- for _, m := range messages {
- if m.Role == message.Assistant {
- for _, call := range m.ToolCalls() {
- if call.ID == result.ToolCallID {
- toolCall = call
- break
- }
- }
- }
- }
- history = append(history, &genai.Content{
- Parts: []*genai.Part{
- {
- FunctionResponse: &genai.FunctionResponse{
- Name: toolCall.Name,
- Response: response,
- },
- },
- },
- Role: "function",
- })
- }
- }
- }
- return history
- }
- func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
- geminiTool := &genai.Tool{}
- geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
- for _, tool := range tools {
- info := tool.Info()
- declaration := &genai.FunctionDeclaration{
- Name: info.Name,
- Description: info.Description,
- Parameters: &genai.Schema{
- Type: genai.TypeObject,
- Properties: convertSchemaProperties(info.Parameters),
- Required: info.Required,
- },
- }
- geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
- }
- return []*genai.Tool{geminiTool}
- }
- func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
- switch {
- case reason == genai.FinishReasonStop:
- return message.FinishReasonEndTurn
- case reason == genai.FinishReasonMaxTokens:
- return message.FinishReasonMaxTokens
- default:
- return message.FinishReasonUnknown
- }
- }
- func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- // Convert messages
- geminiMessages := g.convertMessages(messages)
- cfg := config.Get()
- if cfg.Debug {
- jsonData, _ := json.Marshal(geminiMessages)
- slog.Debug("Prepared messages", "messages", string(jsonData))
- }
- history := geminiMessages[:len(geminiMessages)-1] // All but last message
- lastMsg := geminiMessages[len(geminiMessages)-1]
- config := &genai.GenerateContentConfig{
- MaxOutputTokens: int32(g.providerOptions.maxTokens),
- SystemInstruction: &genai.Content{
- Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
- },
- }
- if len(tools) > 0 {
- config.Tools = g.convertTools(tools)
- }
- chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
- attempts := 0
- for {
- attempts++
- var toolCalls []message.ToolCall
- var lastMsgParts []genai.Part
- for _, part := range lastMsg.Parts {
- lastMsgParts = append(lastMsgParts, *part)
- }
- resp, err := chat.SendMessage(ctx, lastMsgParts...)
- // If there is an error we are going to see if we can retry the call
- if err != nil {
- retry, after, retryErr := g.shouldRetry(attempts, err)
- duration := time.Duration(after) * time.Millisecond
- if retryErr != nil {
- return nil, retryErr
- }
- if retry {
- status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-time.After(duration):
- continue
- }
- }
- return nil, retryErr
- }
- content := ""
- if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
- for _, part := range resp.Candidates[0].Content.Parts {
- switch {
- case part.Text != "":
- content = string(part.Text)
- case part.FunctionCall != nil:
- id := "call_" + uuid.New().String()
- args, _ := json.Marshal(part.FunctionCall.Args)
- toolCalls = append(toolCalls, message.ToolCall{
- ID: id,
- Name: part.FunctionCall.Name,
- Input: string(args),
- Type: "function",
- Finished: true,
- })
- }
- }
- }
- finishReason := message.FinishReasonEndTurn
- if len(resp.Candidates) > 0 {
- finishReason = g.finishReason(resp.Candidates[0].FinishReason)
- }
- if len(toolCalls) > 0 {
- finishReason = message.FinishReasonToolUse
- }
- return &ProviderResponse{
- Content: content,
- ToolCalls: toolCalls,
- Usage: g.usage(resp),
- FinishReason: finishReason,
- }, nil
- }
- }
- func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
- // Convert messages
- geminiMessages := g.convertMessages(messages)
- cfg := config.Get()
- if cfg.Debug {
- jsonData, _ := json.Marshal(geminiMessages)
- slog.Debug("Prepared messages", "messages", string(jsonData))
- }
- history := geminiMessages[:len(geminiMessages)-1] // All but last message
- lastMsg := geminiMessages[len(geminiMessages)-1]
- config := &genai.GenerateContentConfig{
- MaxOutputTokens: int32(g.providerOptions.maxTokens),
- SystemInstruction: &genai.Content{
- Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
- },
- }
- if len(tools) > 0 {
- config.Tools = g.convertTools(tools)
- }
- chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
- attempts := 0
- eventChan := make(chan ProviderEvent)
- go func() {
- defer close(eventChan)
- for {
- attempts++
- currentContent := ""
- toolCalls := []message.ToolCall{}
- var finalResp *genai.GenerateContentResponse
- eventChan <- ProviderEvent{Type: EventContentStart}
- var lastMsgParts []genai.Part
- for _, part := range lastMsg.Parts {
- lastMsgParts = append(lastMsgParts, *part)
- }
- for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
- if err != nil {
- retry, after, retryErr := g.shouldRetry(attempts, err)
- duration := time.Duration(after) * time.Millisecond
- if retryErr != nil {
- eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
- return
- }
- if retry {
- status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
- select {
- case <-ctx.Done():
- if ctx.Err() != nil {
- eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
- }
- return
- case <-time.After(duration):
- break
- }
- } else {
- eventChan <- ProviderEvent{Type: EventError, Error: err}
- return
- }
- }
- finalResp = resp
- if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
- for _, part := range resp.Candidates[0].Content.Parts {
- switch {
- case part.Text != "":
- delta := string(part.Text)
- if delta != "" {
- eventChan <- ProviderEvent{
- Type: EventContentDelta,
- Content: delta,
- }
- currentContent += delta
- }
- case part.FunctionCall != nil:
- id := "call_" + uuid.New().String()
- args, _ := json.Marshal(part.FunctionCall.Args)
- newCall := message.ToolCall{
- ID: id,
- Name: part.FunctionCall.Name,
- Input: string(args),
- Type: "function",
- Finished: true,
- }
- isNew := true
- for _, existing := range toolCalls {
- if existing.Name == newCall.Name && existing.Input == newCall.Input {
- isNew = false
- break
- }
- }
- if isNew {
- toolCalls = append(toolCalls, newCall)
- }
- }
- }
- }
- }
- eventChan <- ProviderEvent{Type: EventContentStop}
- if finalResp != nil {
- finishReason := message.FinishReasonEndTurn
- if len(finalResp.Candidates) > 0 {
- finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
- }
- if len(toolCalls) > 0 {
- finishReason = message.FinishReasonToolUse
- }
- eventChan <- ProviderEvent{
- Type: EventComplete,
- Response: &ProviderResponse{
- Content: currentContent,
- ToolCalls: toolCalls,
- Usage: g.usage(finalResp),
- FinishReason: finishReason,
- },
- }
- return
- }
- }
- }()
- return eventChan
- }
- func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
- // Check if error is a rate limit error
- if attempts > maxRetries {
- return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
- }
- // Gemini doesn't have a standard error type we can check against
- // So we'll check the error message for rate limit indicators
- if errors.Is(err, io.EOF) {
- return false, 0, err
- }
- errMsg := err.Error()
- isRateLimit := false
- // Check for common rate limit error messages
- if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
- isRateLimit = true
- }
- if !isRateLimit {
- return false, 0, err
- }
- // Calculate backoff with jitter
- backoffMs := 2000 * (1 << (attempts - 1))
- jitterMs := int(float64(backoffMs) * 0.2)
- retryMs := backoffMs + jitterMs
- return true, int64(retryMs), nil
- }
- func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
- var toolCalls []message.ToolCall
- if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
- for _, part := range resp.Candidates[0].Content.Parts {
- if part.FunctionCall != nil {
- id := "call_" + uuid.New().String()
- args, _ := json.Marshal(part.FunctionCall.Args)
- toolCalls = append(toolCalls, message.ToolCall{
- ID: id,
- Name: part.FunctionCall.Name,
- Input: string(args),
- Type: "function",
- })
- }
- }
- }
- return toolCalls
- }
- func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
- if resp == nil || resp.UsageMetadata == nil {
- return TokenUsage{}
- }
- return TokenUsage{
- InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
- OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
- CacheCreationTokens: 0, // Not directly provided by Gemini
- CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
- }
- }
- func WithGeminiDisableCache() GeminiOption {
- return func(options *geminiOptions) {
- options.disableCache = true
- }
- }
// Helper functions

// parseJsonToMap decodes jsonStr as a JSON object into a generic map.
// Any decoding error from json.Unmarshal is returned alongside the map
// (which is nil in that case). The JSON literal null decodes to a nil map
// with no error.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var decoded map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
- func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
- properties := make(map[string]*genai.Schema)
- for name, param := range parameters {
- properties[name] = convertToSchema(param)
- }
- return properties
- }
- func convertToSchema(param interface{}) *genai.Schema {
- schema := &genai.Schema{Type: genai.TypeString}
- paramMap, ok := param.(map[string]interface{})
- if !ok {
- return schema
- }
- if desc, ok := paramMap["description"].(string); ok {
- schema.Description = desc
- }
- typeVal, hasType := paramMap["type"]
- if !hasType {
- return schema
- }
- typeStr, ok := typeVal.(string)
- if !ok {
- return schema
- }
- schema.Type = mapJSONTypeToGenAI(typeStr)
- switch typeStr {
- case "array":
- schema.Items = processArrayItems(paramMap)
- case "object":
- if props, ok := paramMap["properties"].(map[string]interface{}); ok {
- schema.Properties = convertSchemaProperties(props)
- }
- }
- return schema
- }
- func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
- items, ok := paramMap["items"].(map[string]interface{})
- if !ok {
- return nil
- }
- return convertToSchema(items)
- }
- func mapJSONTypeToGenAI(jsonType string) genai.Type {
- switch jsonType {
- case "string":
- return genai.TypeString
- case "number":
- return genai.TypeNumber
- case "integer":
- return genai.TypeInteger
- case "boolean":
- return genai.TypeBoolean
- case "array":
- return genai.TypeArray
- case "object":
- return genai.TypeObject
- default:
- return genai.TypeString // Default to string for unknown types
- }
- }
// contains reports whether s contains any of substrs, ignoring case. As with
// strings.Contains, an empty substring always matches; with no substrs the
// result is false.
func contains(s string, substrs ...string) bool {
	// Lowercase the haystack once rather than once per candidate.
	lowered := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lowered, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}
|