Просмотр исходного кода

feat: batch patch image url to base64 (#152)

* feat: batch patch image url to base64

* fix: ci
zijiren 8 месяцев назад
Родитель
Коммит
607db88d58

+ 54 - 12
core/relay/adaptor/anthropic/main.go

@@ -4,9 +4,11 @@ import (
 	"bufio"
 	"bytes"
 	"context"
+	"errors"
 	"io"
 	"net/http"
 	"strings"
+	"sync"
 
 	"github.com/bytedance/sonic"
 	"github.com/bytedance/sonic/ast"
@@ -19,6 +21,7 @@ import (
 	"github.com/labring/aiproxy/core/relay/adaptor/openai"
 	"github.com/labring/aiproxy/core/relay/meta"
 	relaymodel "github.com/labring/aiproxy/core/relay/model"
+	"golang.org/x/sync/semaphore"
 )
 
 func ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io.Reader, error) {
@@ -56,7 +59,8 @@ func ConvertImage2Base64(ctx context.Context, node *ast.Node) error {
 		return nil
 	}
 
-	return messagesNode.ForEach(func(_ ast.Sequence, msgNode *ast.Node) bool {
+	var imageItems []*ast.Node
+	err := messagesNode.ForEach(func(_ ast.Sequence, msgNode *ast.Node) bool {
 		contentNode := msgNode.Get("content")
 		if contentNode == nil || contentNode.TypeSafe() != ast.V_ARRAY {
 			return true
@@ -65,34 +69,70 @@ func ConvertImage2Base64(ctx context.Context, node *ast.Node) error {
 		err := contentNode.ForEach(func(_ ast.Sequence, contentItem *ast.Node) bool {
 			contentType, err := contentItem.Get("type").String()
 			if err == nil && contentType == conetentTypeImage {
-				convertImageURLToBase64(ctx, contentItem)
+				sourceNode := contentItem.Get("source")
+				if sourceNode != nil {
+					imageType, err := sourceNode.Get("type").String()
+					if err == nil && imageType == "url" {
+						imageItems = append(imageItems, contentItem)
+					}
+				}
 			}
 			return true
 		})
 		return err == nil
 	})
+	if err != nil {
+		return err
+	}
+
+	if len(imageItems) == 0 {
+		return nil
+	}
+
+	sem := semaphore.NewWeighted(3)
+	var wg sync.WaitGroup
+	var mu sync.Mutex
+	var processErrs []error
+
+	for _, item := range imageItems {
+		wg.Add(1)
+		go func(contentItem *ast.Node) {
+			defer wg.Done()
+			_ = sem.Acquire(ctx, 1)
+			defer sem.Release(1)
+
+			err := convertImageURLToBase64(ctx, contentItem)
+			if err != nil {
+				mu.Lock()
+				processErrs = append(processErrs, err)
+				mu.Unlock()
+			}
+		}(item)
+	}
+
+	wg.Wait()
+
+	if len(processErrs) != 0 {
+		return errors.Join(processErrs...)
+	}
+	return nil
 }
 
 // convertImageURLToBase64 converts an image URL to base64 encoded data
-func convertImageURLToBase64(ctx context.Context, contentItem *ast.Node) {
+func convertImageURLToBase64(ctx context.Context, contentItem *ast.Node) error {
 	sourceNode := contentItem.Get("source")
 	if sourceNode == nil {
-		return
-	}
-
-	imageType, err := sourceNode.Get("type").String()
-	if err != nil || imageType != "url" {
-		return
+		return nil
 	}
 
 	url, err := sourceNode.Get("url").String()
 	if err != nil {
-		return
+		return nil
 	}
 
 	mimeType, data, err := image.GetImageFromURL(ctx, url)
 	if err != nil {
-		return
+		return nil
 	}
 
 	patches := []func() (bool, error){
@@ -104,9 +144,11 @@ func convertImageURLToBase64(ctx context.Context, contentItem *ast.Node) {
 
 	for _, patch := range patches {
 		if _, err := patch(); err != nil {
-			return
+			return err
 		}
 	}
+
+	return nil
 }
 
 func StreamHandler(m *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage, *relaymodel.ErrorWithStatusCode) {

+ 5 - 4
core/relay/adaptor/anthropic/model.go

@@ -38,8 +38,9 @@ type Metadata struct {
 
 type ImageSource struct {
 	Type      string `json:"type"`
-	MediaType string `json:"media_type"`
-	Data      string `json:"data"`
+	MediaType string `json:"media_type,omitempty"`
+	Data      string `json:"data,omitempty"`
+	URL       string `json:"url,omitempty"`
 }
 
 type Content struct {
@@ -56,8 +57,8 @@ type Content struct {
 }
 
 type Message struct {
-	Role    string    `json:"role"`
-	Content []Content `json:"content"`
+	Role    string     `json:"role"`
+	Content []*Content `json:"content"`
 }
 
 type Tool struct {

+ 61 - 12
core/relay/adaptor/anthropic/openai.go

@@ -2,8 +2,11 @@ package anthropic
 
 import (
 	"bufio"
+	"context"
+	"errors"
 	"net/http"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/bytedance/sonic"
@@ -17,6 +20,7 @@ import (
 	"github.com/labring/aiproxy/core/relay/adaptor/openai"
 	"github.com/labring/aiproxy/core/relay/meta"
 	relaymodel "github.com/labring/aiproxy/core/relay/model"
+	"golang.org/x/sync/semaphore"
 )
 
 const (
@@ -128,7 +132,7 @@ func OpenAIConvertRequest(meta *meta.Meta, req *http.Request) (*Request, error)
 		claudeToolChoice := struct {
 			Type string `json:"type"`
 			Name string `json:"name,omitempty"`
-		}{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output
+		}{Type: "auto"}
 		if choice, ok := textRequest.ToolChoice.(map[string]any); ok {
 			if function, ok := choice["function"].(map[string]any); ok {
 				claudeToolChoice.Type = "tool"
@@ -143,6 +147,8 @@ func OpenAIConvertRequest(meta *meta.Meta, req *http.Request) (*Request, error)
 		claudeRequest.ToolChoice = claudeToolChoice
 	}
 
+	var imageTasks []*Content
+
 	for _, message := range textRequest.Messages {
 		if message.Role == "system" {
 			claudeRequest.System = append(claudeRequest.System, Content{
@@ -167,9 +173,9 @@ func OpenAIConvertRequest(meta *meta.Meta, req *http.Request) (*Request, error)
 				content.Text = ""
 				content.ToolUseID = message.ToolCallID
 			}
-			claudeMessage.Content = append(claudeMessage.Content, content)
+			claudeMessage.Content = append(claudeMessage.Content, &content)
 		} else {
-			var contents []Content
+			var contents []*Content
 			openaiContent := message.ParseContent()
 			for _, part := range openaiContent {
 				var content Content
@@ -180,16 +186,12 @@ func OpenAIConvertRequest(meta *meta.Meta, req *http.Request) (*Request, error)
 				case relaymodel.ContentTypeImageURL:
 					content.Type = conetentTypeImage
 					content.Source = &ImageSource{
-						Type: "base64",
-					}
-					mimeType, data, err := image.GetImageFromURL(req.Context(), part.ImageURL.URL)
-					if err != nil {
-						return nil, err
+						Type: "url",
+						URL:  part.ImageURL.URL,
 					}
-					content.Source.MediaType = mimeType
-					content.Source.Data = data
+					imageTasks = append(imageTasks, &content)
 				}
-				contents = append(contents, content)
+				contents = append(contents, &content)
 			}
 			claudeMessage.Content = contents
 		}
@@ -197,7 +199,7 @@ func OpenAIConvertRequest(meta *meta.Meta, req *http.Request) (*Request, error)
 		for _, toolCall := range message.ToolCalls {
 			inputParam := make(map[string]any)
 			_ = sonic.UnmarshalString(toolCall.Function.Arguments, &inputParam)
-			claudeMessage.Content = append(claudeMessage.Content, Content{
+			claudeMessage.Content = append(claudeMessage.Content, &Content{
 				Type:  toolUseType,
 				ID:    toolCall.ID,
 				Name:  toolCall.Function.Name,
@@ -207,9 +209,56 @@ func OpenAIConvertRequest(meta *meta.Meta, req *http.Request) (*Request, error)
 		claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
 	}
 
+	if len(imageTasks) > 0 {
+		err := batchPatchImage2Base64(req.Context(), imageTasks)
+		if err != nil {
+			return nil, err
+		}
+	}
+
 	return &claudeRequest, nil
 }
 
+// batchPatchImage2Base64 downloads every image referenced by URL in imageTasks
+// and rewrites its Source in place: Type becomes "base64", MediaType and Data
+// are filled from the download, and URL is cleared. At most three downloads
+// run at once.
+//
+// Tasks with a nil Source or an empty URL are skipped. Every failure,
+// including ctx being done while a worker waits for a download slot, is
+// collected and returned joined via errors.Join; nil is returned when no
+// worker reported an error.
+func batchPatchImage2Base64(ctx context.Context, imageTasks []*Content) error {
+	sem := semaphore.NewWeighted(3)
+	var wg sync.WaitGroup
+	var mu sync.Mutex
+	var processErrs []error
+
+	// recordErr appends to processErrs under mu; it is called from workers.
+	recordErr := func(err error) {
+		mu.Lock()
+		processErrs = append(processErrs, err)
+		mu.Unlock()
+	}
+
+	for _, task := range imageTasks {
+		if task == nil || task.Source == nil || task.Source.URL == "" {
+			continue
+		}
+		wg.Add(1)
+		// Pass task explicitly so the worker does not depend on per-iteration
+		// loop variable semantics.
+		go func(task *Content) {
+			defer wg.Done()
+
+			// A failed Acquire leaves the semaphore unchanged, so Release may
+			// only be deferred after a successful Acquire; otherwise it would
+			// hand back a slot this worker never held.
+			if err := sem.Acquire(ctx, 1); err != nil {
+				recordErr(err)
+				return
+			}
+			defer sem.Release(1)
+
+			mimeType, data, err := image.GetImageFromURL(ctx, task.Source.URL)
+			if err != nil {
+				recordErr(err)
+				return
+			}
+
+			task.Source.Type = "base64"
+			task.Source.URL = ""
+			task.Source.MediaType = mimeType
+			task.Source.Data = data
+		}(task)
+	}
+
+	wg.Wait()
+
+	if len(processErrs) != 0 {
+		return errors.Join(processErrs...)
+	}
+	return nil
+}
+
 // https://docs.anthropic.com/claude/reference/messages-streaming
 func StreamResponse2OpenAI(meta *meta.Meta, respData []byte) (*relaymodel.ChatCompletionsStreamResponse, *relaymodel.ErrorWithStatusCode) {
 	var usage *relaymodel.Usage

+ 1 - 1
core/relay/adaptor/gemini/embeddings.go

@@ -29,7 +29,7 @@ func ConvertEmbeddingRequest(meta *meta.Meta, req *http.Request) (string, http.H
 		requests[i] = EmbeddingRequest{
 			Model: model,
 			Content: ChatContent{
-				Parts: []Part{
+				Parts: []*Part{
 					{
 						Text: input,
 					},

+ 69 - 43
core/relay/adaptor/gemini/main.go

@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"bytes"
 	"context"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
@@ -25,14 +26,11 @@ import (
 	relaymodel "github.com/labring/aiproxy/core/relay/model"
 	"github.com/labring/aiproxy/core/relay/utils"
 	log "github.com/sirupsen/logrus"
+	"golang.org/x/sync/semaphore"
 )
 
 // https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
 
-const (
-	VisionMaxImageNum = 16
-)
-
 var toolChoiceTypeMap = map[string]string{
 	"none":     "NONE",
 	"auto":     "AUTO",
@@ -146,37 +144,27 @@ func buildToolConfig(textRequest *relaymodel.GeneralOpenAIRequest) *ToolConfig {
 	return &toolConfig
 }
 
-func buildMessageParts(ctx context.Context, part relaymodel.MessageContent) ([]Part, error) {
-	if part.Type == relaymodel.ContentTypeText {
-		return []Part{{Text: part.Text}}, nil
+func buildMessageParts(message relaymodel.MessageContent) *Part {
+	part := &Part{
+		Text: message.Text,
 	}
-
-	if part.Type == relaymodel.ContentTypeImageURL {
-		mimeType, data, err := image.GetImageFromURL(ctx, part.ImageURL.URL)
-		if err != nil {
-			return nil, err
+	if message.ImageURL != nil {
+		part.InlineData = &InlineData{
+			Data: message.ImageURL.URL,
 		}
-		return []Part{{
-			InlineData: &InlineData{
-				MimeType: mimeType,
-				Data:     data,
-			},
-		}}, nil
 	}
-
-	return nil, nil
+	return part
 }
 
-func buildContents(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest) (*ChatContent, []*ChatContent, error) {
+func buildContents(textRequest *relaymodel.GeneralOpenAIRequest) (*ChatContent, []*ChatContent, []*Part) {
 	contents := make([]*ChatContent, 0, len(textRequest.Messages))
-	imageNum := 0
+	var imageTasks []*Part
 
 	var systemContent *ChatContent
 
 	for _, message := range textRequest.Messages {
 		content := ChatContent{
-			Role:  message.Role,
-			Parts: make([]Part, 0),
+			Role: message.Role,
 		}
 
 		switch {
@@ -190,7 +178,7 @@ func buildContents(ctx context.Context, textRequest *relaymodel.GeneralOpenAIReq
 				} else {
 					args = make(map[string]any)
 				}
-				content.Parts = append(content.Parts, Part{
+				content.Parts = append(content.Parts, &Part{
 					FunctionCall: &FunctionCall{
 						Name: toolCall.Function.Name,
 						Args: args,
@@ -211,7 +199,7 @@ func buildContents(ctx context.Context, textRequest *relaymodel.GeneralOpenAIReq
 			} else {
 				contentMap = make(map[string]any)
 			}
-			content.Parts = append(content.Parts, Part{
+			content.Parts = append(content.Parts, &Part{
 				FunctionResponse: &FunctionResponse{
 					Name: *message.Name,
 					Response: struct {
@@ -226,18 +214,11 @@ func buildContents(ctx context.Context, textRequest *relaymodel.GeneralOpenAIReq
 		default:
 			openaiContent := message.ParseContent()
 			for _, part := range openaiContent {
-				if part.Type == relaymodel.ContentTypeImageURL {
-					imageNum++
-					if imageNum > VisionMaxImageNum {
-						continue
-					}
-				}
-
-				parts, err := buildMessageParts(ctx, part)
-				if err != nil {
-					return nil, nil, err
+				part := buildMessageParts(part)
+				if part.InlineData != nil {
+					imageTasks = append(imageTasks, part)
 				}
-				content.Parts = append(content.Parts, parts...)
+				content.Parts = append(content.Parts, part)
 			}
 		}
 
@@ -253,7 +234,48 @@ func buildContents(ctx context.Context, textRequest *relaymodel.GeneralOpenAIReq
 		contents = append(contents, &content)
 	}
 
-	return systemContent, contents, nil
+	return systemContent, contents, imageTasks
+}
+
+// processImageTasks resolves the image parts collected while building the
+// request contents. Each task's InlineData.Data initially holds an image URL;
+// the URL is fetched with image.GetImageFromURL and InlineData.MimeType and
+// InlineData.Data are overwritten in place with the fetched values. At most
+// three fetches run at once.
+//
+// Tasks with nil InlineData or empty Data are skipped. Every failure,
+// including ctx being done while a worker waits for a fetch slot, is
+// collected and returned joined via errors.Join; nil is returned when no
+// worker reported an error.
+func processImageTasks(ctx context.Context, imageTasks []*Part) error {
+	if len(imageTasks) == 0 {
+		return nil
+	}
+
+	sem := semaphore.NewWeighted(3)
+	var wg sync.WaitGroup
+	var mu sync.Mutex
+	var processErrs []error
+
+	// recordErr appends to processErrs under mu; it is called from workers.
+	recordErr := func(err error) {
+		mu.Lock()
+		processErrs = append(processErrs, err)
+		mu.Unlock()
+	}
+
+	for _, task := range imageTasks {
+		if task == nil || task.InlineData == nil || task.InlineData.Data == "" {
+			continue
+		}
+		wg.Add(1)
+		// Pass task explicitly so the worker does not depend on per-iteration
+		// loop variable semantics.
+		go func(task *Part) {
+			defer wg.Done()
+
+			// A failed Acquire leaves the semaphore unchanged, so Release may
+			// only be deferred after a successful Acquire; otherwise it would
+			// hand back a slot this worker never held.
+			if err := sem.Acquire(ctx, 1); err != nil {
+				recordErr(err)
+				return
+			}
+			defer sem.Release(1)
+
+			mimeType, data, err := image.GetImageFromURL(ctx, task.InlineData.Data)
+			if err != nil {
+				recordErr(err)
+				return
+			}
+
+			task.InlineData.MimeType = mimeType
+			task.InlineData.Data = data
+		}(task)
+	}
+
+	wg.Wait()
+
+	if len(processErrs) != 0 {
+		return errors.Join(processErrs...)
+	}
+	return nil
 }
 
 // Setting safety to the lowest possible values since Gemini is already powerless enough
@@ -266,9 +288,13 @@ func ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io
 	textRequest.Model = meta.ActualModel
 	meta.Set("stream", textRequest.Stream)
 
-	systemContent, contents, err := buildContents(req.Context(), textRequest)
-	if err != nil {
-		return "", nil, nil, err
+	systemContent, contents, imageTasks := buildContents(textRequest)
+
+	// Process image tasks concurrently
+	if len(imageTasks) > 0 {
+		if err := processImageTasks(req.Context(), imageTasks); err != nil {
+			return "", nil, nil, err
+		}
 	}
 
 	config, err := buildGenerationConfig(meta, req, textRequest)
@@ -418,7 +444,7 @@ func responseChat2OpenAI(meta *meta.Meta, response *ChatResponse) *relaymodel.Te
 			}
 			for _, part := range candidate.Content.Parts {
 				if part.FunctionCall != nil {
-					toolCall, err := getToolCall(&part)
+					toolCall, err := getToolCall(part)
 					if err != nil {
 						log.Error("get tool call failed: " + err.Error())
 					}
@@ -496,7 +522,7 @@ func streamResponseChat2OpenAI(meta *meta.Meta, geminiResponse *ChatResponse) *r
 			}
 			for _, part := range candidate.Content.Parts {
 				if part.FunctionCall != nil {
-					toolCall, err := getToolCall(&part)
+					toolCall, err := getToolCall(part)
 					if err != nil {
 						log.Error("get tool call failed: " + err.Error())
 					}

+ 2 - 2
core/relay/adaptor/gemini/model.go

@@ -66,8 +66,8 @@ type Part struct {
 }
 
 type ChatContent struct {
-	Role  string `json:"role,omitempty"`
-	Parts []Part `json:"parts"`
+	Role  string  `json:"role,omitempty"`
+	Parts []*Part `json:"parts"`
 }
 
 type ChatSafetySettings struct {