|
@@ -54,19 +54,6 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
|
|
|
|
|
|
|
|
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
|
|
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
|
|
|
var history []*genai.Content
|
|
var history []*genai.Content
|
|
|
-
|
|
|
|
|
- // Add system message first
|
|
|
|
|
- history = append(history, &genai.Content{
|
|
|
|
|
- Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)},
|
|
|
|
|
- Role: "user",
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- // Add a system response to acknowledge the system message
|
|
|
|
|
- history = append(history, &genai.Content{
|
|
|
|
|
- Parts: []genai.Part{genai.Text("I'll help you with that.")},
|
|
|
|
|
- Role: "model",
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
for _, msg := range messages {
|
|
for _, msg := range messages {
|
|
|
switch msg.Role {
|
|
switch msg.Role {
|
|
|
case message.User:
|
|
case message.User:
|
|
@@ -154,14 +141,11 @@ func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
|
|
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
|
|
|
- reasonStr := reason.String()
|
|
|
|
|
switch {
|
|
switch {
|
|
|
- case reasonStr == "STOP":
|
|
|
|
|
|
|
+ case reason == genai.FinishReasonStop:
|
|
|
return message.FinishReasonEndTurn
|
|
return message.FinishReasonEndTurn
|
|
|
- case reasonStr == "MAX_TOKENS":
|
|
|
|
|
|
|
+ case reason == genai.FinishReasonMaxTokens:
|
|
|
return message.FinishReasonMaxTokens
|
|
return message.FinishReasonMaxTokens
|
|
|
- case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"):
|
|
|
|
|
- return message.FinishReasonToolUse
|
|
|
|
|
default:
|
|
default:
|
|
|
return message.FinishReasonUnknown
|
|
return message.FinishReasonUnknown
|
|
|
}
|
|
}
|
|
@@ -170,7 +154,11 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
|
|
|
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
|
|
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
|
|
|
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
|
|
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
|
|
|
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
|
|
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
|
|
|
-
|
|
|
|
|
|
|
+ model.SystemInstruction = &genai.Content{
|
|
|
|
|
+ Parts: []genai.Part{
|
|
|
|
|
+ genai.Text(g.providerOptions.systemMessage),
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
// Convert tools
|
|
// Convert tools
|
|
|
if len(tools) > 0 {
|
|
if len(tools) > 0 {
|
|
|
model.Tools = g.convertTools(tools)
|
|
model.Tools = g.convertTools(tools)
|
|
@@ -188,19 +176,13 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
|
|
attempts := 0
|
|
attempts := 0
|
|
|
for {
|
|
for {
|
|
|
attempts++
|
|
attempts++
|
|
|
|
|
+ var toolCalls []message.ToolCall
|
|
|
chat := model.StartChat()
|
|
chat := model.StartChat()
|
|
|
chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
|
|
chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
|
|
|
|
|
|
|
|
lastMsg := geminiMessages[len(geminiMessages)-1]
|
|
lastMsg := geminiMessages[len(geminiMessages)-1]
|
|
|
- var lastText string
|
|
|
|
|
- for _, part := range lastMsg.Parts {
|
|
|
|
|
- if text, ok := part.(genai.Text); ok {
|
|
|
|
|
- lastText = string(text)
|
|
|
|
|
- break
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
|
|
|
- resp, err := chat.SendMessage(ctx, genai.Text(lastText))
|
|
|
|
|
|
|
+ resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
|
|
|
// If there is an error we are going to see if we can retry the call
|
|
// If there is an error we are going to see if we can retry the call
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
retry, after, retryErr := g.shouldRetry(attempts, err)
|
|
retry, after, retryErr := g.shouldRetry(attempts, err)
|
|
@@ -220,7 +202,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
content := ""
|
|
content := ""
|
|
|
- var toolCalls []message.ToolCall
|
|
|
|
|
|
|
|
|
|
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
|
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
|
|
for _, part := range resp.Candidates[0].Content.Parts {
|
|
for _, part := range resp.Candidates[0].Content.Parts {
|
|
@@ -231,20 +212,25 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
|
|
id := "call_" + uuid.New().String()
|
|
id := "call_" + uuid.New().String()
|
|
|
args, _ := json.Marshal(p.Args)
|
|
args, _ := json.Marshal(p.Args)
|
|
|
toolCalls = append(toolCalls, message.ToolCall{
|
|
toolCalls = append(toolCalls, message.ToolCall{
|
|
|
- ID: id,
|
|
|
|
|
- Name: p.Name,
|
|
|
|
|
- Input: string(args),
|
|
|
|
|
- Type: "function",
|
|
|
|
|
|
|
+ ID: id,
|
|
|
|
|
+ Name: p.Name,
|
|
|
|
|
+ Input: string(args),
|
|
|
|
|
+ Type: "function",
|
|
|
|
|
+ Finished: true,
|
|
|
})
|
|
})
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+ finishReason := g.finishReason(resp.Candidates[0].FinishReason)
|
|
|
|
|
+ if len(toolCalls) > 0 {
|
|
|
|
|
+ finishReason = message.FinishReasonToolUse
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
return &ProviderResponse{
|
|
return &ProviderResponse{
|
|
|
Content: content,
|
|
Content: content,
|
|
|
ToolCalls: toolCalls,
|
|
ToolCalls: toolCalls,
|
|
|
Usage: g.usage(resp),
|
|
Usage: g.usage(resp),
|
|
|
- FinishReason: g.finishReason(resp.Candidates[0].FinishReason),
|
|
|
|
|
|
|
+ FinishReason: finishReason,
|
|
|
}, nil
|
|
}, nil
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -252,7 +238,11 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
|
|
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
|
|
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
|
|
|
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
|
|
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
|
|
|
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
|
|
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
|
|
|
-
|
|
|
|
|
|
|
+ model.SystemInstruction = &genai.Content{
|
|
|
|
|
+ Parts: []genai.Part{
|
|
|
|
|
+ genai.Text(g.providerOptions.systemMessage),
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
// Convert tools
|
|
// Convert tools
|
|
|
if len(tools) > 0 {
|
|
if len(tools) > 0 {
|
|
|
model.Tools = g.convertTools(tools)
|
|
model.Tools = g.convertTools(tools)
|
|
@@ -276,18 +266,10 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
|
|
for {
|
|
for {
|
|
|
attempts++
|
|
attempts++
|
|
|
chat := model.StartChat()
|
|
chat := model.StartChat()
|
|
|
- chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
|
|
|
|
|
-
|
|
|
|
|
|
|
+ chat.History = geminiMessages[:len(geminiMessages)-1]
|
|
|
lastMsg := geminiMessages[len(geminiMessages)-1]
|
|
lastMsg := geminiMessages[len(geminiMessages)-1]
|
|
|
- var lastText string
|
|
|
|
|
- for _, part := range lastMsg.Parts {
|
|
|
|
|
- if text, ok := part.(genai.Text); ok {
|
|
|
|
|
- lastText = string(text)
|
|
|
|
|
- break
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
|
|
|
- iter := chat.SendMessageStream(ctx, genai.Text(lastText))
|
|
|
|
|
|
|
+ iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
|
|
|
|
|
|
|
|
currentContent := ""
|
|
currentContent := ""
|
|
|
toolCalls := []message.ToolCall{}
|
|
toolCalls := []message.ToolCall{}
|
|
@@ -330,23 +312,23 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
|
|
for _, part := range resp.Candidates[0].Content.Parts {
|
|
for _, part := range resp.Candidates[0].Content.Parts {
|
|
|
switch p := part.(type) {
|
|
switch p := part.(type) {
|
|
|
case genai.Text:
|
|
case genai.Text:
|
|
|
- newText := string(p)
|
|
|
|
|
- delta := newText[len(currentContent):]
|
|
|
|
|
|
|
+ delta := string(p)
|
|
|
if delta != "" {
|
|
if delta != "" {
|
|
|
eventChan <- ProviderEvent{
|
|
eventChan <- ProviderEvent{
|
|
|
Type: EventContentDelta,
|
|
Type: EventContentDelta,
|
|
|
Content: delta,
|
|
Content: delta,
|
|
|
}
|
|
}
|
|
|
- currentContent = newText
|
|
|
|
|
|
|
+ currentContent += delta
|
|
|
}
|
|
}
|
|
|
case genai.FunctionCall:
|
|
case genai.FunctionCall:
|
|
|
id := "call_" + uuid.New().String()
|
|
id := "call_" + uuid.New().String()
|
|
|
args, _ := json.Marshal(p.Args)
|
|
args, _ := json.Marshal(p.Args)
|
|
|
newCall := message.ToolCall{
|
|
newCall := message.ToolCall{
|
|
|
- ID: id,
|
|
|
|
|
- Name: p.Name,
|
|
|
|
|
- Input: string(args),
|
|
|
|
|
- Type: "function",
|
|
|
|
|
|
|
+ ID: id,
|
|
|
|
|
+ Name: p.Name,
|
|
|
|
|
+ Input: string(args),
|
|
|
|
|
+ Type: "function",
|
|
|
|
|
+ Finished: true,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
isNew := true
|
|
isNew := true
|
|
@@ -368,37 +350,22 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
|
|
eventChan <- ProviderEvent{Type: EventContentStop}
|
|
eventChan <- ProviderEvent{Type: EventContentStop}
|
|
|
|
|
|
|
|
if finalResp != nil {
|
|
if finalResp != nil {
|
|
|
|
|
+ finishReason := g.finishReason(finalResp.Candidates[0].FinishReason)
|
|
|
|
|
+ if len(toolCalls) > 0 {
|
|
|
|
|
+ finishReason = message.FinishReasonToolUse
|
|
|
|
|
+ }
|
|
|
eventChan <- ProviderEvent{
|
|
eventChan <- ProviderEvent{
|
|
|
Type: EventComplete,
|
|
Type: EventComplete,
|
|
|
Response: &ProviderResponse{
|
|
Response: &ProviderResponse{
|
|
|
Content: currentContent,
|
|
Content: currentContent,
|
|
|
ToolCalls: toolCalls,
|
|
ToolCalls: toolCalls,
|
|
|
Usage: g.usage(finalResp),
|
|
Usage: g.usage(finalResp),
|
|
|
- FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason),
|
|
|
|
|
|
|
+ FinishReason: finishReason,
|
|
|
},
|
|
},
|
|
|
}
|
|
}
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // If we get here, we need to retry
|
|
|
|
|
- if attempts > maxRetries {
|
|
|
|
|
- eventChan <- ProviderEvent{
|
|
|
|
|
- Type: EventError,
|
|
|
|
|
- Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries),
|
|
|
|
|
- }
|
|
|
|
|
- return
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Wait before retrying
|
|
|
|
|
- select {
|
|
|
|
|
- case <-ctx.Done():
|
|
|
|
|
- if ctx.Err() != nil {
|
|
|
|
|
- eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
|
|
|
|
|
- }
|
|
|
|
|
- return
|
|
|
|
|
- case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond):
|
|
|
|
|
- continue
|
|
|
|
|
- }
|
|
|
|
|
}
|
|
}
|
|
|
}()
|
|
}()
|
|
|
|
|
|