Преглед изворни кода

refactor: update function signatures to include context and improve file handling #1599

CaIon пре 4 месеци
родитељ
комит
0bb43aa464

+ 2 - 1
constant/context_key.go

@@ -3,7 +3,8 @@ package constant
 type ContextKey string
 
 const (
-	ContextKeyPromptTokens ContextKey = "prompt_tokens"
+	ContextKeyTokenCountMeta ContextKey = "token_count_meta"
+	ContextKeyPromptTokens   ContextKey = "prompt_tokens"
 
 	ContextKeyOriginalModel    ContextKey = "original_model"
 	ContextKeyRequestStartTime ContextKey = "request_start_time"

+ 2 - 0
controller/relay.go

@@ -133,6 +133,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 		return
 	}
 
+	common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
+
 	preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
 	if newAPIError != nil {
 		return

+ 2 - 2
dto/claude.go

@@ -231,7 +231,7 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
 							data = common.Interface2String(media.Source.Data)
 						}
 						if data != "" {
-							fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data})
+							fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
 						}
 					}
 				}
@@ -263,7 +263,7 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
 						data = common.Interface2String(media.Source.Data)
 					}
 					if data != "" {
-						fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data})
+						fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
 					}
 				}
 			case "tool_use":

+ 8 - 8
dto/gemini.go

@@ -35,23 +35,23 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
 			if part.InlineData != nil && part.InlineData.Data != "" {
 				if strings.HasPrefix(part.InlineData.MimeType, "image/") {
 					files = append(files, &types.FileMeta{
-						FileType: types.FileTypeImage,
-						Data:     part.InlineData.Data,
+						FileType:   types.FileTypeImage,
+						OriginData: part.InlineData.Data,
 					})
 				} else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
 					files = append(files, &types.FileMeta{
-						FileType: types.FileTypeAudio,
-						Data:     part.InlineData.Data,
+						FileType:   types.FileTypeAudio,
+						OriginData: part.InlineData.Data,
 					})
 				} else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
 					files = append(files, &types.FileMeta{
-						FileType: types.FileTypeVideo,
-						Data:     part.InlineData.Data,
+						FileType:   types.FileTypeVideo,
+						OriginData: part.InlineData.Data,
 					})
 				} else {
 					files = append(files, &types.FileMeta{
-						FileType: types.FileTypeFile,
-						Data:     part.InlineData.Data,
+						FileType:   types.FileTypeFile,
+						OriginData: part.InlineData.Data,
 					})
 				}
 			}

+ 9 - 9
dto/openai_request.go

@@ -118,7 +118,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
 						meta := &types.FileMeta{
 							FileType: types.FileTypeImage,
 						}
-						meta.Data = imageUrl.Url
+						meta.OriginData = imageUrl.Url
 						meta.Detail = imageUrl.Detail
 						fileMeta = append(fileMeta, meta)
 					}
@@ -128,7 +128,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
 						meta := &types.FileMeta{
 							FileType: types.FileTypeAudio,
 						}
-						meta.Data = inputAudio.Data
+						meta.OriginData = inputAudio.Data
 						fileMeta = append(fileMeta, meta)
 					}
 				} else if m.Type == ContentTypeFile {
@@ -137,7 +137,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
 						meta := &types.FileMeta{
 							FileType: types.FileTypeFile,
 						}
-						meta.Data = file.FileData
+						meta.OriginData = file.FileData
 						fileMeta = append(fileMeta, meta)
 					}
 				} else if m.Type == ContentTypeVideoUrl {
@@ -146,7 +146,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
 						meta := &types.FileMeta{
 							FileType: types.FileTypeVideo,
 						}
-						meta.Data = videoUrl.Url
+						meta.OriginData = videoUrl.Url
 						fileMeta = append(fileMeta, meta)
 					}
 				} else {
@@ -784,14 +784,14 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
 		for _, input := range inputs {
 			if input.Type == "input_image" {
 				fileMeta = append(fileMeta, &types.FileMeta{
-					FileType: types.FileTypeImage,
-					Data:     input.ImageUrl,
-					Detail:   input.Detail,
+					FileType:   types.FileTypeImage,
+					OriginData: input.ImageUrl,
+					Detail:     input.Detail,
 				})
 			} else if input.Type == "input_file" {
 				fileMeta = append(fileMeta, &types.FileMeta{
-					FileType: types.FileTypeFile,
-					Data:     input.FileUrl,
+					FileType:   types.FileTypeFile,
+					OriginData: input.FileUrl,
 				})
 			} else {
 				texts = append(texts, input.Text)

+ 1 - 1
relay/channel/aws/adaptor.go

@@ -63,7 +63,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 
 	var claudeReq *dto.ClaudeRequest
 	var err error
-	claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
+	claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
relay/channel/claude/adaptor.go

@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if a.RequestMode == RequestModeCompletion {
 		return RequestOpenAI2ClaudeComplete(*request), nil
 	} else {
-		return RequestOpenAI2ClaudeMessage(*request)
+		return RequestOpenAI2ClaudeMessage(c, *request)
 	}
 }
 

+ 2 - 2
relay/channel/claude/relay-claude.go

@@ -71,7 +71,7 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
 	return &claudeRequest
 }
 
-func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
+func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
 	claudeTools := make([]any, 0, len(textRequest.Tools))
 
 	for _, tool := range textRequest.Tools {
@@ -355,7 +355,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 						// 判断是否是url
 						if strings.HasPrefix(imageUrl.Url, "http") {
 							// 是url,获取图片的类型和base64编码的数据
-							fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
+							fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Claude")
 							if err != nil {
 								return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
 							}

+ 1 - 1
relay/channel/gemini/adaptor.go

@@ -142,7 +142,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 		return nil, errors.New("request is nil")
 	}
 
-	geminiRequest, err := CovertGemini2OpenAI(*request, info)
+	geminiRequest, err := CovertGemini2OpenAI(c, *request, info)
 	if err != nil {
 		return nil, err
 	}

+ 2 - 2
relay/channel/gemini/relay-gemini.go

@@ -178,7 +178,7 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel
 }
 
 // Setting safety to the lowest possible values since Gemini is already powerless enough
-func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
+func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
 
 	geminiRequest := dto.GeminiChatRequest{
 		Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
@@ -390,7 +390,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 				// 判断是否是url
 				if strings.HasPrefix(part.GetImageMedia().Url, "http") {
 					// 是url,获取文件的类型和base64编码的数据
-					fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
+					fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini")
 					if err != nil {
 						return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
 					}

+ 2 - 2
relay/channel/ollama/adaptor.go

@@ -31,7 +31,7 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
 	openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
 		IncludeUsage: true,
 	}
-	return requestOpenAI2Ollama(openaiRequest.(*dto.GeneralOpenAIRequest))
+	return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest))
 }
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -69,7 +69,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return requestOpenAI2Ollama(request)
+	return requestOpenAI2Ollama(c, request)
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {

+ 2 - 2
relay/channel/ollama/relay-ollama.go

@@ -14,7 +14,7 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
+func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
 	messages := make([]dto.Message, 0, len(request.Messages))
 	for _, message := range request.Messages {
 		if !message.IsStringContent() {
@@ -24,7 +24,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er
 					imageUrl := mediaMessage.GetImageMedia()
 					// check if not base64
 					if strings.HasPrefix(imageUrl.Url, "http") {
-						fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
+						fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
 						if err != nil {
 							return nil, err
 						}

+ 2 - 2
relay/channel/vertex/adaptor.go

@@ -182,7 +182,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 		return nil, errors.New("request is nil")
 	}
 	if a.RequestMode == RequestModeClaude {
-		claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request)
+		claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
 		if err != nil {
 			return nil, err
 		}
@@ -191,7 +191,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 		info.UpstreamModelName = claudeReq.Model
 		return vertexClaudeReq, nil
 	} else if a.RequestMode == RequestModeGemini {
-		geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
+		geminiRequest, err := gemini.CovertGemini2OpenAI(c, *request, info)
 		if err != nil {
 			return nil, err
 		}

+ 3 - 3
service/cf_worker.go

@@ -42,16 +42,16 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
 	return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
 }
 
-func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
+func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
 	if setting.EnableWorker() {
-		common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
+		common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
 		req := &WorkerRequest{
 			URL: originUrl,
 			Key: setting.WorkerValidKey,
 		}
 		return DoWorkerRequest(req)
 	} else {
-		common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
+		common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
 		return http.Get(originUrl)
 	}
 }

+ 21 - 6
service/file_decoder.go

@@ -3,17 +3,29 @@ package service
 import (
 	"encoding/base64"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"io"
 	"one-api/common"
 	"one-api/constant"
-	"one-api/dto"
+	"one-api/logger"
+	"one-api/types"
 	"strings"
 )
 
-func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
+func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
+	contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url))
+
+	// Check if the file has already been downloaded in this request
+	if cachedData, exists := c.Get(contextKey); exists {
+		if common.DebugEnabled {
+			logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
+		}
+		return cachedData.(*types.LocalFileData), nil
+	}
+
 	var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
 
-	resp, err := DoDownloadRequest(url)
+	resp, err := DoDownloadRequest(url, reason...)
 	if err != nil {
 		return nil, err
 	}
@@ -77,12 +89,15 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
 			}
 		}
 	}
-
-	return &dto.LocalFileData{
+	data := &types.LocalFileData{
 		Base64Data: base64Data,
 		MimeType:   mimeType,
 		Size:       int64(len(fileBytes)),
-	}, nil
+	}
+	// Store the file data in the context to avoid re-downloading
+	c.Set(contextKey, data)
+
+	return data, nil
 }
 
 func GetMimeTypeByExtension(ext string) string {

+ 40 - 6
service/token_counter.go

@@ -154,16 +154,22 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
 	var err error
 	var format string
 	var b64str string
-	if strings.HasPrefix(fileMeta.Data, "http") {
-		config, format, err = DecodeUrlImageData(fileMeta.Data)
+
+	if fileMeta.ParsedData != nil {
+		config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data)
 	} else {
-		common.SysLog(fmt.Sprintf("decoding image"))
-		config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data)
+		if strings.HasPrefix(fileMeta.OriginData, "http") {
+			config, format, err = DecodeUrlImageData(fileMeta.OriginData)
+		} else {
+			common.SysLog(fmt.Sprintf("decoding image"))
+			config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData)
+		}
+		fileMeta.MimeType = format
 	}
+
 	if err != nil {
 		return 0, err
 	}
-	fileMeta.MimeType = format
 
 	if config.Width == 0 || config.Height == 0 {
 		// not an image
@@ -171,7 +177,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
 			// file type
 			return 3 * baseTokens, nil
 		}
-		return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.Data))
+		return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData))
 	}
 
 	width := config.Width
@@ -268,6 +274,34 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
 		tkm += 3
 	}
 
+	shouldFetchFiles := true
+
+	if info.RelayFormat == types.RelayFormatOpenAIRealtime || info.RelayFormat == types.RelayFormatGemini {
+		shouldFetchFiles = false
+	}
+
+	if shouldFetchFiles {
+		for _, file := range meta.Files {
+			if strings.HasPrefix(file.OriginData, "http") {
+				localFileData, err := GetFileBase64FromUrl(c, file.OriginData, "token_counter")
+				if err != nil {
+					return 0, fmt.Errorf("error getting file base64 from url: %v", err)
+				}
+				if strings.HasPrefix(localFileData.MimeType, "image/") {
+					file.FileType = types.FileTypeImage
+				} else if strings.HasPrefix(localFileData.MimeType, "video/") {
+					file.FileType = types.FileTypeVideo
+				} else if strings.HasPrefix(localFileData.MimeType, "audio/") {
+					file.FileType = types.FileTypeAudio
+				} else {
+					file.FileType = types.FileTypeFile
+				}
+				file.MimeType = localFileData.MimeType
+				file.ParsedData = localFileData
+			}
+		}
+	}
+
 	for _, file := range meta.Files {
 		switch file.FileType {
 		case types.FileTypeImage:

+ 1 - 1
dto/file_data.go → types/file_data.go

@@ -1,4 +1,4 @@
-package dto
+package types
 
 type LocalFileData struct {
 	MimeType   string

+ 4 - 3
types/request_meta.go

@@ -32,9 +32,10 @@ type TokenCountMeta struct {
 
 type FileMeta struct {
 	FileType
-	MimeType string
-	Data     string
-	Detail   string
+	MimeType   string
+	OriginData string // url or base64 data
+	Detail     string
+	ParsedData *LocalFileData
 }
 
 type RequestMeta struct {