|
|
@@ -4,6 +4,7 @@ import (
|
|
|
"bufio"
|
|
|
"bytes"
|
|
|
"context"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
@@ -25,14 +26,11 @@ import (
|
|
|
relaymodel "github.com/labring/aiproxy/core/relay/model"
|
|
|
"github.com/labring/aiproxy/core/relay/utils"
|
|
|
log "github.com/sirupsen/logrus"
|
|
|
+ "golang.org/x/sync/semaphore"
|
|
|
)
|
|
|
|
|
|
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
|
|
|
|
|
|
-const (
|
|
|
- VisionMaxImageNum = 16
|
|
|
-)
|
|
|
-
|
|
|
var toolChoiceTypeMap = map[string]string{
|
|
|
"none": "NONE",
|
|
|
"auto": "AUTO",
|
|
|
@@ -146,37 +144,27 @@ func buildToolConfig(textRequest *relaymodel.GeneralOpenAIRequest) *ToolConfig {
|
|
|
return &toolConfig
|
|
|
}
|
|
|
|
|
|
-func buildMessageParts(ctx context.Context, part relaymodel.MessageContent) ([]Part, error) {
|
|
|
- if part.Type == relaymodel.ContentTypeText {
|
|
|
- return []Part{{Text: part.Text}}, nil
|
|
|
+func buildMessageParts(message relaymodel.MessageContent) *Part {
|
|
|
+ part := &Part{
|
|
|
+ Text: message.Text,
|
|
|
}
|
|
|
-
|
|
|
- if part.Type == relaymodel.ContentTypeImageURL {
|
|
|
- mimeType, data, err := image.GetImageFromURL(ctx, part.ImageURL.URL)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
+ if message.ImageURL != nil {
|
|
|
+ part.InlineData = &InlineData{
|
|
|
+ Data: message.ImageURL.URL,
|
|
|
}
|
|
|
- return []Part{{
|
|
|
- InlineData: &InlineData{
|
|
|
- MimeType: mimeType,
|
|
|
- Data: data,
|
|
|
- },
|
|
|
- }}, nil
|
|
|
}
|
|
|
-
|
|
|
- return nil, nil
|
|
|
+ return part
|
|
|
}
|
|
|
|
|
|
-func buildContents(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest) (*ChatContent, []*ChatContent, error) {
|
|
|
+func buildContents(textRequest *relaymodel.GeneralOpenAIRequest) (*ChatContent, []*ChatContent, []*Part) {
|
|
|
contents := make([]*ChatContent, 0, len(textRequest.Messages))
|
|
|
- imageNum := 0
|
|
|
+ var imageTasks []*Part
|
|
|
|
|
|
var systemContent *ChatContent
|
|
|
|
|
|
for _, message := range textRequest.Messages {
|
|
|
content := ChatContent{
|
|
|
- Role: message.Role,
|
|
|
- Parts: make([]Part, 0),
|
|
|
+ Role: message.Role,
|
|
|
}
|
|
|
|
|
|
switch {
|
|
|
@@ -190,7 +178,7 @@ func buildContents(ctx context.Context, textRequest *relaymodel.GeneralOpenAIReq
|
|
|
} else {
|
|
|
args = make(map[string]any)
|
|
|
}
|
|
|
- content.Parts = append(content.Parts, Part{
|
|
|
+ content.Parts = append(content.Parts, &Part{
|
|
|
FunctionCall: &FunctionCall{
|
|
|
Name: toolCall.Function.Name,
|
|
|
Args: args,
|
|
|
@@ -211,7 +199,7 @@ func buildContents(ctx context.Context, textRequest *relaymodel.GeneralOpenAIReq
|
|
|
} else {
|
|
|
contentMap = make(map[string]any)
|
|
|
}
|
|
|
- content.Parts = append(content.Parts, Part{
|
|
|
+ content.Parts = append(content.Parts, &Part{
|
|
|
FunctionResponse: &FunctionResponse{
|
|
|
Name: *message.Name,
|
|
|
Response: struct {
|
|
|
@@ -226,18 +214,11 @@ func buildContents(ctx context.Context, textRequest *relaymodel.GeneralOpenAIReq
|
|
|
default:
|
|
|
openaiContent := message.ParseContent()
|
|
|
for _, part := range openaiContent {
|
|
|
- if part.Type == relaymodel.ContentTypeImageURL {
|
|
|
- imageNum++
|
|
|
- if imageNum > VisionMaxImageNum {
|
|
|
- continue
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- parts, err := buildMessageParts(ctx, part)
|
|
|
- if err != nil {
|
|
|
- return nil, nil, err
|
|
|
+ part := buildMessageParts(part)
|
|
|
+ if part.InlineData != nil {
|
|
|
+ imageTasks = append(imageTasks, part)
|
|
|
}
|
|
|
- content.Parts = append(content.Parts, parts...)
|
|
|
+ content.Parts = append(content.Parts, part)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -253,7 +234,48 @@ func buildContents(ctx context.Context, textRequest *relaymodel.GeneralOpenAIReq
|
|
|
contents = append(contents, &content)
|
|
|
}
|
|
|
|
|
|
- return systemContent, contents, nil
|
|
|
+ return systemContent, contents, imageTasks
|
|
|
+}
|
|
|
+
|
|
|
+func processImageTasks(ctx context.Context, imageTasks []*Part) error {
|
|
|
+ if len(imageTasks) == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ sem := semaphore.NewWeighted(3)
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ var mu sync.Mutex
|
|
|
+ var processErrs []error
|
|
|
+
|
|
|
+ for _, task := range imageTasks {
|
|
|
+ if task.InlineData == nil || task.InlineData.Data == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ wg.Add(1)
|
|
|
+ go func() {
|
|
|
+ defer wg.Done()
|
|
|
+ _ = sem.Acquire(ctx, 1)
|
|
|
+ defer sem.Release(1)
|
|
|
+
|
|
|
+ mimeType, data, err := image.GetImageFromURL(ctx, task.InlineData.Data)
|
|
|
+ if err != nil {
|
|
|
+ mu.Lock()
|
|
|
+ processErrs = append(processErrs, err)
|
|
|
+ mu.Unlock()
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ task.InlineData.MimeType = mimeType
|
|
|
+ task.InlineData.Data = data
|
|
|
+ }()
|
|
|
+ }
|
|
|
+
|
|
|
+ wg.Wait()
|
|
|
+
|
|
|
+ if len(processErrs) != 0 {
|
|
|
+ return errors.Join(processErrs...)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
|
|
@@ -266,9 +288,13 @@ func ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io
|
|
|
textRequest.Model = meta.ActualModel
|
|
|
meta.Set("stream", textRequest.Stream)
|
|
|
|
|
|
- systemContent, contents, err := buildContents(req.Context(), textRequest)
|
|
|
- if err != nil {
|
|
|
- return "", nil, nil, err
|
|
|
+ systemContent, contents, imageTasks := buildContents(textRequest)
|
|
|
+
|
|
|
+ // Process image tasks concurrently
|
|
|
+ if len(imageTasks) > 0 {
|
|
|
+ if err := processImageTasks(req.Context(), imageTasks); err != nil {
|
|
|
+ return "", nil, nil, err
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
config, err := buildGenerationConfig(meta, req, textRequest)
|
|
|
@@ -418,7 +444,7 @@ func responseChat2OpenAI(meta *meta.Meta, response *ChatResponse) *relaymodel.Te
|
|
|
}
|
|
|
for _, part := range candidate.Content.Parts {
|
|
|
if part.FunctionCall != nil {
|
|
|
- toolCall, err := getToolCall(&part)
|
|
|
+ toolCall, err := getToolCall(part)
|
|
|
if err != nil {
|
|
|
log.Error("get tool call failed: " + err.Error())
|
|
|
}
|
|
|
@@ -496,7 +522,7 @@ func streamResponseChat2OpenAI(meta *meta.Meta, geminiResponse *ChatResponse) *r
|
|
|
}
|
|
|
for _, part := range candidate.Content.Parts {
|
|
|
if part.FunctionCall != nil {
|
|
|
- toolCall, err := getToolCall(&part)
|
|
|
+ toolCall, err := getToolCall(part)
|
|
|
if err != nil {
|
|
|
log.Error("get tool call failed: " + err.Error())
|
|
|
}
|