|
|
@@ -9,14 +9,12 @@ import (
|
|
|
"strings"
|
|
|
"time"
|
|
|
|
|
|
- "github.com/google/generative-ai-go/genai"
|
|
|
"github.com/google/uuid"
|
|
|
"github.com/opencode-ai/opencode/internal/config"
|
|
|
"github.com/opencode-ai/opencode/internal/llm/tools"
|
|
|
"github.com/opencode-ai/opencode/internal/message"
|
|
|
"github.com/opencode-ai/opencode/internal/status"
|
|
|
- "google.golang.org/api/iterator"
|
|
|
- "google.golang.org/api/option"
|
|
|
+ "google.golang.org/genai"
|
|
|
"log/slog"
|
|
|
)
|
|
|
|
|
|
@@ -40,7 +38,7 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
|
|
|
o(&geminiOpts)
|
|
|
}
|
|
|
|
|
|
- client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
|
|
|
+ client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
|
|
|
if err != nil {
|
|
|
slog.Error("Failed to create Gemini client", "error", err)
|
|
|
return nil
|
|
|
@@ -58,11 +56,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
|
|
|
for _, msg := range messages {
|
|
|
switch msg.Role {
|
|
|
case message.User:
|
|
|
- var parts []genai.Part
|
|
|
- parts = append(parts, genai.Text(msg.Content().String()))
|
|
|
+ var parts []*genai.Part
|
|
|
+ parts = append(parts, &genai.Part{Text: msg.Content().String()})
|
|
|
for _, binaryContent := range msg.BinaryContent() {
|
|
|
imageFormat := strings.Split(binaryContent.MIMEType, "/")
|
|
|
- parts = append(parts, genai.ImageData(imageFormat[1], binaryContent.Data))
|
|
|
+ parts = append(parts, &genai.Part{InlineData: &genai.Blob{
|
|
|
+ MIMEType: imageFormat[1],
|
|
|
+ Data: binaryContent.Data,
|
|
|
+ }})
|
|
|
}
|
|
|
history = append(history, &genai.Content{
|
|
|
Parts: parts,
|
|
|
@@ -71,19 +72,21 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
|
|
|
case message.Assistant:
|
|
|
content := &genai.Content{
|
|
|
Role: "model",
|
|
|
- Parts: []genai.Part{},
|
|
|
+ Parts: []*genai.Part{},
|
|
|
}
|
|
|
|
|
|
if msg.Content().String() != "" {
|
|
|
- content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
|
|
|
+ content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()})
|
|
|
}
|
|
|
|
|
|
if len(msg.ToolCalls()) > 0 {
|
|
|
for _, call := range msg.ToolCalls() {
|
|
|
args, _ := parseJsonToMap(call.Input)
|
|
|
- content.Parts = append(content.Parts, genai.FunctionCall{
|
|
|
- Name: call.Name,
|
|
|
- Args: args,
|
|
|
+ content.Parts = append(content.Parts, &genai.Part{
|
|
|
+ FunctionCall: &genai.FunctionCall{
|
|
|
+ Name: call.Name,
|
|
|
+ Args: args,
|
|
|
+ },
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
@@ -111,10 +114,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
|
|
|
}
|
|
|
|
|
|
history = append(history, &genai.Content{
|
|
|
- Parts: []genai.Part{genai.FunctionResponse{
|
|
|
- Name: toolCall.Name,
|
|
|
- Response: response,
|
|
|
- }},
|
|
|
+ Parts: []*genai.Part{
|
|
|
+ {
|
|
|
+ FunctionResponse: &genai.FunctionResponse{
|
|
|
+ Name: toolCall.Name,
|
|
|
+ Response: response,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
Role: "function",
|
|
|
})
|
|
|
}
|
|
|
@@ -158,18 +165,6 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
|
|
|
}
|
|
|
|
|
|
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
|
|
|
- model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
|
|
|
- model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
|
|
|
- model.SystemInstruction = &genai.Content{
|
|
|
- Parts: []genai.Part{
|
|
|
- genai.Text(g.providerOptions.systemMessage),
|
|
|
- },
|
|
|
- }
|
|
|
- // Convert tools
|
|
|
- if len(tools) > 0 {
|
|
|
- model.Tools = g.convertTools(tools)
|
|
|
- }
|
|
|
-
|
|
|
// Convert messages
|
|
|
geminiMessages := g.convertMessages(messages)
|
|
|
|
|
|
@@ -179,16 +174,26 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
|
|
slog.Debug("Prepared messages", "messages", string(jsonData))
|
|
|
}
|
|
|
|
|
|
+ history := geminiMessages[:len(geminiMessages)-1] // All but last message
|
|
|
+ lastMsg := geminiMessages[len(geminiMessages)-1]
|
|
|
+ chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
|
|
|
+ MaxOutputTokens: int32(g.providerOptions.maxTokens),
|
|
|
+ SystemInstruction: &genai.Content{
|
|
|
+ Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
|
|
|
+ },
|
|
|
+ Tools: g.convertTools(tools),
|
|
|
+ }, history)
|
|
|
+
|
|
|
attempts := 0
|
|
|
for {
|
|
|
attempts++
|
|
|
var toolCalls []message.ToolCall
|
|
|
- chat := model.StartChat()
|
|
|
- chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
|
|
|
-
|
|
|
- lastMsg := geminiMessages[len(geminiMessages)-1]
|
|
|
|
|
|
- resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
|
|
|
+ var lastMsgParts []genai.Part
|
|
|
+ for _, part := range lastMsg.Parts {
|
|
|
+ lastMsgParts = append(lastMsgParts, *part)
|
|
|
+ }
|
|
|
+ resp, err := chat.SendMessage(ctx, lastMsgParts...)
|
|
|
// If there is an error we are going to see if we can retry the call
|
|
|
if err != nil {
|
|
|
retry, after, retryErr := g.shouldRetry(attempts, err)
|
|
|
@@ -211,15 +216,15 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
|
|
|
|
|
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
|
|
for _, part := range resp.Candidates[0].Content.Parts {
|
|
|
- switch p := part.(type) {
|
|
|
- case genai.Text:
|
|
|
- content = string(p)
|
|
|
- case genai.FunctionCall:
|
|
|
+ switch {
|
|
|
+ case part.Text != "":
|
|
|
+ content = string(part.Text)
|
|
|
+ case part.FunctionCall != nil:
|
|
|
id := "call_" + uuid.New().String()
|
|
|
- args, _ := json.Marshal(p.Args)
|
|
|
+ args, _ := json.Marshal(part.FunctionCall.Args)
|
|
|
toolCalls = append(toolCalls, message.ToolCall{
|
|
|
ID: id,
|
|
|
- Name: p.Name,
|
|
|
+ Name: part.FunctionCall.Name,
|
|
|
Input: string(args),
|
|
|
Type: "function",
|
|
|
Finished: true,
|
|
|
@@ -245,18 +250,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
|
|
}
|
|
|
|
|
|
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
|
|
|
- model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
|
|
|
- model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
|
|
|
- model.SystemInstruction = &genai.Content{
|
|
|
- Parts: []genai.Part{
|
|
|
- genai.Text(g.providerOptions.systemMessage),
|
|
|
- },
|
|
|
- }
|
|
|
- // Convert tools
|
|
|
- if len(tools) > 0 {
|
|
|
- model.Tools = g.convertTools(tools)
|
|
|
- }
|
|
|
-
|
|
|
// Convert messages
|
|
|
geminiMessages := g.convertMessages(messages)
|
|
|
|
|
|
@@ -266,6 +259,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
|
|
slog.Debug("Prepared messages", "messages", string(jsonData))
|
|
|
}
|
|
|
|
|
|
+ history := geminiMessages[:len(geminiMessages)-1] // All but last message
|
|
|
+ lastMsg := geminiMessages[len(geminiMessages)-1]
|
|
|
+ chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
|
|
|
+ MaxOutputTokens: int32(g.providerOptions.maxTokens),
|
|
|
+ SystemInstruction: &genai.Content{
|
|
|
+ Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
|
|
|
+ },
|
|
|
+ Tools: g.convertTools(tools),
|
|
|
+ }, history)
|
|
|
+
|
|
|
attempts := 0
|
|
|
eventChan := make(chan ProviderEvent)
|
|
|
|
|
|
@@ -274,11 +277,6 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
|
|
|
|
|
for {
|
|
|
attempts++
|
|
|
- chat := model.StartChat()
|
|
|
- chat.History = geminiMessages[:len(geminiMessages)-1]
|
|
|
- lastMsg := geminiMessages[len(geminiMessages)-1]
|
|
|
-
|
|
|
- iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
|
|
|
|
|
|
currentContent := ""
|
|
|
toolCalls := []message.ToolCall{}
|
|
|
@@ -286,11 +284,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
|
|
|
|
|
eventChan <- ProviderEvent{Type: EventContentStart}
|
|
|
|
|
|
- for {
|
|
|
- resp, err := iter.Next()
|
|
|
- if err == iterator.Done {
|
|
|
- break
|
|
|
- }
|
|
|
+ var lastMsgParts []genai.Part
|
|
|
+
|
|
|
+ for _, part := range lastMsg.Parts {
|
|
|
+ lastMsgParts = append(lastMsgParts, *part)
|
|
|
+ }
|
|
|
+ for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
|
|
|
if err != nil {
|
|
|
retry, after, retryErr := g.shouldRetry(attempts, err)
|
|
|
if retryErr != nil {
|
|
|
@@ -319,9 +318,9 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
|
|
|
|
|
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
|
|
for _, part := range resp.Candidates[0].Content.Parts {
|
|
|
- switch p := part.(type) {
|
|
|
- case genai.Text:
|
|
|
- delta := string(p)
|
|
|
+ switch {
|
|
|
+ case part.Text != "":
|
|
|
+ delta := string(part.Text)
|
|
|
if delta != "" {
|
|
|
eventChan <- ProviderEvent{
|
|
|
Type: EventContentDelta,
|
|
|
@@ -329,12 +328,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
|
|
}
|
|
|
currentContent += delta
|
|
|
}
|
|
|
- case genai.FunctionCall:
|
|
|
+ case part.FunctionCall != nil:
|
|
|
id := "call_" + uuid.New().String()
|
|
|
- args, _ := json.Marshal(p.Args)
|
|
|
+ args, _ := json.Marshal(part.FunctionCall.Args)
|
|
|
newCall := message.ToolCall{
|
|
|
ID: id,
|
|
|
- Name: p.Name,
|
|
|
+ Name: part.FunctionCall.Name,
|
|
|
Input: string(args),
|
|
|
Type: "function",
|
|
|
Finished: true,
|
|
|
@@ -422,12 +421,12 @@ func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.
|
|
|
|
|
|
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
|
|
for _, part := range resp.Candidates[0].Content.Parts {
|
|
|
- if funcCall, ok := part.(genai.FunctionCall); ok {
|
|
|
+ if part.FunctionCall != nil {
|
|
|
id := "call_" + uuid.New().String()
|
|
|
- args, _ := json.Marshal(funcCall.Args)
|
|
|
+ args, _ := json.Marshal(part.FunctionCall.Args)
|
|
|
toolCalls = append(toolCalls, message.ToolCall{
|
|
|
ID: id,
|
|
|
- Name: funcCall.Name,
|
|
|
+ Name: part.FunctionCall.Name,
|
|
|
Input: string(args),
|
|
|
Type: "function",
|
|
|
})
|