Przeglądaj źródła

feat: add multi-file type support for Gemini and Claude

- Add file data DTO for structured file handling
- Implement file decoder service
- Update Claude and Gemini relay channels to handle various file types
- Reorganize worker service to cf_worker for clarity
- Update token counter and image service for new file types
CalciumIon 1 rok temu
rodzic
commit
2b38e8ed8d

+ 1 - 0
README.en.md

@@ -82,6 +82,7 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
 - `GEMINI_MODEL_MAP`: Specify Gemini model versions (v1/v1beta), format: "model:version", comma-separated
 - `COHERE_SAFETY_SETTING`: Cohere model [safety settings](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE`
 - `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable
+- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
 
 ## Deployment
 > [!TIP]

+ 1 - 0
README.md

@@ -88,6 +88,7 @@
 - `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
 - `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`,`STRICT`,默认为 `NONE`。
 - `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。
+- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。
 ## 部署
 > [!TIP]
 > 最新版Docker镜像:`calciumion/new-api:latest`  

+ 2 - 0
constant/env.go

@@ -10,6 +10,8 @@ import (
 var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
 var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
 
+var MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
+
 // ForceStreamOption 覆盖请求参数,强制返回usage信息
 var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
 

+ 8 - 0
dto/file_data.go

@@ -0,0 +1,8 @@
+package dto
+
+type LocalFileData struct {
+	MimeType   string
+	Base64Data string
+	Url        string
+	Size       int64
+}

+ 6 - 3
relay/channel/claude/relay-claude.go

@@ -225,9 +225,12 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 						// 判断是否是url
 						if strings.HasPrefix(imageUrl.Url, "http") {
 							// 是url,获取图片的类型和base64编码的数据
-							mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
-							claudeMediaMessage.Source.MediaType = mimeType
-							claudeMediaMessage.Source.Data = data
+							fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
+							if err != nil {
+								return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
+							}
+							claudeMediaMessage.Source.MediaType = fileData.MimeType
+							claudeMediaMessage.Source.Data = fileData.Base64Data
 						} else {
 							_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
 							if err != nil {

+ 6 - 3
relay/channel/gemini/relay-gemini.go

@@ -192,11 +192,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 				// 判断是否是url
 				if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
 					// 是url,获取图片的类型和base64编码的数据
-					mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
+					fileData, err := service.GetFileBase64FromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
+					if err != nil {
+						return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
+					}
 					parts = append(parts, GeminiPart{
 						InlineData: &GeminiInlineData{
-							MimeType: mimeType,
-							Data:     data,
+							MimeType: fileData.MimeType,
+							Data:     fileData.Base64Data,
 						},
 					})
 				} else {

+ 1 - 1
relay/relay-text.go

@@ -230,7 +230,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
 	var err error
 	switch info.RelayMode {
 	case relayconstant.RelayModeChatCompletions:
-		promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
+		promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
 	case relayconstant.RelayModeCompletions:
 		promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
 	case relayconstant.RelayModeModerations:

+ 6 - 3
service/worker.go → service/cf_worker.go

@@ -9,9 +9,12 @@ import (
 	"strings"
 )
 
-func DoImageRequest(originUrl string) (resp *http.Response, err error) {
+func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
 	if setting.EnableWorker() {
-		common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
+		common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
+		if !strings.HasPrefix(originUrl, "https") {
+			return nil, fmt.Errorf("only support https url")
+		}
 		workerUrl := setting.WorkerUrl
 		if !strings.HasSuffix(workerUrl, "/") {
 			workerUrl += "/"
@@ -20,7 +23,7 @@ func DoImageRequest(originUrl string) (resp *http.Response, err error) {
 		data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
 		return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
 	} else {
-		common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
+		common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
 		return http.Get(originUrl)
 	}
 }

+ 39 - 0
service/file_decoder.go

@@ -0,0 +1,39 @@
+package service
+
+import (
+	"encoding/base64"
+	"fmt"
+	"io"
+	"one-api/constant"
+	"one-api/dto"
+)
+
+var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
+
+func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
+	resp, err := DoDownloadRequest(url)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	// Always use LimitReader to prevent oversized downloads
+	fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
+	if err != nil {
+		return nil, err
+	}
+
+	// Check actual size after reading
+	if len(fileBytes) > maxFileSize {
+		return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
+	}
+
+	// Convert to base64
+	base64Data := base64.StdEncoding.EncodeToString(fileBytes)
+
+	return &dto.LocalFileData{
+		Base64Data: base64Data,
+		MimeType:   resp.Header.Get("Content-Type"),
+		Size:       int64(len(fileBytes)),
+	}, nil
+}

+ 10 - 4
service/image.go

@@ -33,12 +33,12 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
 
 // GetImageFromUrl 获取图片的类型和base64编码的数据
 func GetImageFromUrl(url string) (mimeType string, data string, err error) {
-	resp, err := DoImageRequest(url)
+	resp, err := DoDownloadRequest(url)
 	if err != nil {
-		return
+		return "", "", err
 	}
 	if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
-		return
+		return "", "", fmt.Errorf("invalid content type: %s, required image/*", resp.Header.Get("Content-Type"))
 	}
 	defer resp.Body.Close()
 	buffer := bytes.NewBuffer(nil)
@@ -52,7 +52,7 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
 }
 
 func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
-	response, err := DoImageRequest(imageUrl)
+	response, err := DoDownloadRequest(imageUrl)
 	if err != nil {
 		common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
 		return image.Config{}, "", err
@@ -64,6 +64,12 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
 		return image.Config{}, "", err
 	}
 
+	mimeType := response.Header.Get("Content-Type")
+
+	if !strings.HasPrefix(mimeType, "image/") {
+		return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
+	}
+
 	var readData []byte
 	for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
 		common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))

+ 9 - 6
service/token_counter.go

@@ -80,7 +80,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
 
-func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
+func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
 	baseTokens := 85
 	if model == "glm-4v" {
 		return 1047, nil
@@ -96,6 +96,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
 	if !constant.GetMediaToken {
 		return 256, nil
 	}
+	if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
+		return 256, nil
+	}
 	// 同步One API的图片计费逻辑
 	if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
 		imageUrl.Detail = "high"
@@ -155,9 +158,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
 	return tiles*tileTokens + baseTokens, nil
 }
 
-func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {
+func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
 	tkm := 0
-	msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream)
+	msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
 	if err != nil {
 		return 0, err
 	}
@@ -179,7 +182,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
 				countStr += fmt.Sprintf("%v", tool.Function.Parameters)
 			}
 		}
-		toolTokens, err := CountTokenInput(countStr, model)
+		toolTokens, err := CountTokenInput(countStr, request.Model)
 		if err != nil {
 			return 0, err
 		}
@@ -256,7 +259,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
 	return textToken, audioToken, nil
 }
 
-func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
+func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
 	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
@@ -290,7 +293,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
 				for _, m := range arrayContent {
 					if m.Type == dto.ContentTypeImageURL {
 						imageUrl := m.ImageUrl.(dto.MessageImageUrl)
-						imageTokenNum, err := getImageToken(&imageUrl, model, stream)
+						imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
 						if err != nil {
 							return 0, err
 						}