Jelajahi Sumber

refactor(file-source): unify file source creation and enhance caching mechanisms

CaIon 5 hari lalu
induk
melakukan
03758a4a85

+ 24 - 30
dto/claude.go

@@ -98,6 +98,20 @@ func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
 	return mediaContent
 }
 
+func (m *ClaudeMediaMessage) ToFileSource() types.FileSource {
+	if m.Source == nil {
+		return nil
+	}
+	data := m.Source.Url
+	if data == "" {
+		data = common.Interface2String(m.Source.Data)
+	}
+	if data == "" {
+		return nil
+	}
+	return types.NewFileSourceFromData(data, m.Source.MediaType)
+}
+
 type ClaudeMessageSource struct {
 	Type      string `json:"type"`
 	MediaType string `json:"media_type,omitempty"`
@@ -223,14 +237,6 @@ type OutputConfigForEffort struct {
 	Effort string `json:"effort,omitempty"`
 }
 
-// createClaudeFileSource 根据数据内容创建正确类型的 FileSource
-func createClaudeFileSource(data string) *types.FileSource {
-	if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
-		return types.NewURLFileSource(data)
-	}
-	return types.NewBase64FileSource(data, "")
-}
-
 func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
 	maxTokens := 0
 	if c.MaxTokens != nil {
@@ -258,17 +264,11 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
 				case "text":
 					texts = append(texts, media.GetText())
 				case "image":
-					if media.Source != nil {
-						data := media.Source.Url
-						if data == "" {
-							data = common.Interface2String(media.Source.Data)
-						}
-						if data != "" {
-							fileMeta = append(fileMeta, &types.FileMeta{
-								FileType: types.FileTypeImage,
-								Source:   createClaudeFileSource(data),
-							})
-						}
+					if source := media.ToFileSource(); source != nil {
+						fileMeta = append(fileMeta, &types.FileMeta{
+							FileType: types.FileTypeImage,
+							Source:   source,
+						})
 					}
 				}
 			}
@@ -293,17 +293,11 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
 			case "text":
 				texts = append(texts, media.GetText())
 			case "image":
-				if media.Source != nil {
-					data := media.Source.Url
-					if data == "" {
-						data = common.Interface2String(media.Source.Data)
-					}
-					if data != "" {
-						fileMeta = append(fileMeta, &types.FileMeta{
-							FileType: types.FileTypeImage,
-							Source:   createClaudeFileSource(data),
-						})
-					}
+				if source := media.ToFileSource(); source != nil {
+					fileMeta = append(fileMeta, &types.FileMeta{
+						FileType: types.FileTypeImage,
+						Source:   source,
+					})
 				}
 			case "tool_use":
 				if media.Name != "" {

+ 8 - 11
dto/gemini.go

@@ -64,14 +64,6 @@ type LatLng struct {
 	Longitude *float64 `json:"longitude,omitempty"`
 }
 
-// createGeminiFileSource 根据数据内容创建正确类型的 FileSource
-func createGeminiFileSource(data string, mimeType string) *types.FileSource {
-	if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
-		return types.NewURLFileSource(data)
-	}
-	return types.NewBase64FileSource(data, mimeType)
-}
-
 func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
 	var files []*types.FileMeta = make([]*types.FileMeta, 0)
 
@@ -87,9 +79,8 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
 			if part.Text != "" {
 				inputTexts = append(inputTexts, part.Text)
 			}
-			if part.InlineData != nil && part.InlineData.Data != "" {
+			if source := part.InlineData.ToFileSource(); source != nil {
 				mimeType := part.InlineData.MimeType
-				source := createGeminiFileSource(part.InlineData.Data, mimeType)
 				var fileType types.FileType
 				if strings.HasPrefix(mimeType, "image/") {
 					fileType = types.FileTypeImage
@@ -103,7 +94,6 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
 				files = append(files, &types.FileMeta{
 					FileType: fileType,
 					Source:   source,
-					MimeType: mimeType,
 				})
 			}
 		}
@@ -215,6 +205,13 @@ type GeminiInlineData struct {
 	Data     string `json:"data"`
 }
 
+func (d *GeminiInlineData) ToFileSource() types.FileSource {
+	if d == nil || d.Data == "" {
+		return nil
+	}
+	return types.NewFileSourceFromData(d.Data, d.MimeType)
+}
+
 // UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
 func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
 	type Alias GeminiInlineData // Use type alias to avoid recursion

+ 53 - 47
dto/openai_request.go

@@ -108,14 +108,6 @@ type GeneralOpenAIRequest struct {
 	ReasoningSplit json.RawMessage `json:"reasoning_split,omitempty"`
 }
 
-// createFileSource 根据数据内容创建正确类型的 FileSource
-func createFileSource(data string) *types.FileSource {
-	if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
-		return types.NewURLFileSource(data)
-	}
-	return types.NewBase64FileSource(data, "")
-}
-
 func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
 	var tokenCountMeta types.TokenCountMeta
 	var texts = make([]string, 0)
@@ -159,44 +151,24 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
 			}
 			arrayContent := message.ParseContent()
 			for _, m := range arrayContent {
-				if m.Type == ContentTypeImageURL {
-					imageUrl := m.GetImageMedia()
-					if imageUrl != nil && imageUrl.Url != "" {
-						source := createFileSource(imageUrl.Url)
-						fileMeta = append(fileMeta, &types.FileMeta{
-							FileType: types.FileTypeImage,
-							Source:   source,
-							Detail:   imageUrl.Detail,
-						})
-					}
-				} else if m.Type == ContentTypeInputAudio {
-					inputAudio := m.GetInputAudio()
-					if inputAudio != nil && inputAudio.Data != "" {
-						source := createFileSource(inputAudio.Data)
-						fileMeta = append(fileMeta, &types.FileMeta{
-							FileType: types.FileTypeAudio,
-							Source:   source,
-						})
-					}
-				} else if m.Type == ContentTypeFile {
-					file := m.GetFile()
-					if file != nil && file.FileData != "" {
-						source := createFileSource(file.FileData)
-						fileMeta = append(fileMeta, &types.FileMeta{
-							FileType: types.FileTypeFile,
-							Source:   source,
-						})
-					}
-				} else if m.Type == ContentTypeVideoUrl {
-					videoUrl := m.GetVideoUrl()
-					if videoUrl != nil && videoUrl.Url != "" {
-						source := createFileSource(videoUrl.Url)
-						fileMeta = append(fileMeta, &types.FileMeta{
-							FileType: types.FileTypeVideo,
-							Source:   source,
-						})
+				source := m.ToFileSource()
+				if source != nil {
+					meta := &types.FileMeta{Source: source}
+					switch m.Type {
+					case ContentTypeImageURL:
+						meta.FileType = types.FileTypeImage
+						if img := m.GetImageMedia(); img != nil {
+							meta.Detail = img.Detail
+						}
+					case ContentTypeInputAudio:
+						meta.FileType = types.FileTypeAudio
+					case ContentTypeFile:
+						meta.FileType = types.FileTypeFile
+					case ContentTypeVideoUrl:
+						meta.FileType = types.FileTypeVideo
 					}
-				} else {
+					fileMeta = append(fileMeta, meta)
+				} else if m.Type == ContentTypeText {
 					texts = append(texts, m.Text)
 				}
 			}
@@ -391,6 +363,40 @@ func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
 	return nil
 }
 
+func (m *MediaContent) ToFileSource() types.FileSource {
+	switch m.Type {
+	case ContentTypeImageURL:
+		img := m.GetImageMedia()
+		if img == nil || img.Url == "" {
+			return nil
+		}
+		return types.NewFileSourceFromData(img.Url, img.MimeType)
+	case ContentTypeInputAudio:
+		audio := m.GetInputAudio()
+		if audio == nil || audio.Data == "" {
+			return nil
+		}
+		mimeType := ""
+		if audio.Format != "" {
+			mimeType = "audio/" + audio.Format
+		}
+		return types.NewFileSourceFromData(audio.Data, mimeType)
+	case ContentTypeFile:
+		file := m.GetFile()
+		if file == nil || file.FileData == "" {
+			return nil
+		}
+		return types.NewFileSourceFromData(file.FileData, "")
+	case ContentTypeVideoUrl:
+		video := m.GetVideoUrl()
+		if video == nil || video.Url == "" {
+			return nil
+		}
+		return types.NewFileSourceFromData(video.Url, "")
+	}
+	return nil
+}
+
 type MessageImageUrl struct {
 	Url      string `json:"url"`
 	Detail   string `json:"detail,omitempty"`
@@ -865,7 +871,7 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
 				if input.ImageUrl != "" {
 					fileMeta = append(fileMeta, &types.FileMeta{
 						FileType: types.FileTypeImage,
-						Source:   createFileSource(input.ImageUrl),
+						Source:   types.NewFileSourceFromData(input.ImageUrl, ""),
 						Detail:   input.Detail,
 					})
 				}
@@ -873,7 +879,7 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
 				if input.FileUrl != "" {
 					fileMeta = append(fileMeta, &types.FileMeta{
 						FileType: types.FileTypeFile,
-						Source:   createFileSource(input.FileUrl),
+						Source:   types.NewFileSourceFromData(input.FileUrl, ""),
 					})
 				}
 			} else {

+ 15 - 83
relay/channel/claude/relay-claude.go

@@ -1,12 +1,10 @@
 package claude
 
 import (
-	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
-	"path/filepath"
 	"strings"
 
 	"github.com/QuantumNous/new-api/common"
@@ -46,61 +44,6 @@ func maybeMarkClaudeRefusal(c *gin.Context, stopReason string) {
 	}
 }
 
-func createClaudeFileSource(file *dto.MessageFile) *types.FileSource {
-	if file == nil || file.FileData == "" {
-		return nil
-	}
-	if strings.HasPrefix(file.FileData, "http://") || strings.HasPrefix(file.FileData, "https://") {
-		return types.NewURLFileSource(file.FileData)
-	}
-	mimeType := ""
-	if ext := strings.TrimPrefix(strings.ToLower(filepath.Ext(file.FileName)), "."); ext != "" {
-		if detected := service.GetMimeTypeByExtension(ext); detected != "application/octet-stream" {
-			mimeType = detected
-		}
-	}
-	return types.NewBase64FileSource(file.FileData, mimeType)
-}
-
-func buildClaudeFileMessage(c *gin.Context, file *dto.MessageFile) (*dto.ClaudeMediaMessage, error) {
-	source := createClaudeFileSource(file)
-	if source == nil {
-		return nil, nil
-	}
-	base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting document for Claude")
-	if err != nil {
-		return nil, fmt.Errorf("get file data failed: %w", err)
-	}
-	switch strings.ToLower(mimeType) {
-	case "application/pdf":
-		return &dto.ClaudeMediaMessage{
-			Type: "document",
-			Source: &dto.ClaudeMessageSource{
-				Type:      "base64",
-				MediaType: mimeType,
-				Data:      base64Data,
-			},
-		}, nil
-	case "text/plain":
-		decodedData, err := base64.StdEncoding.DecodeString(base64Data)
-		if err != nil {
-			return nil, fmt.Errorf("decode text file data failed: %w", err)
-		}
-		return &dto.ClaudeMediaMessage{
-			Type: "text",
-			Text: common.GetPointer(string(decodedData)),
-		}, nil
-	default:
-		msg := fmt.Sprintf("claude: skip unsupported file content, filename=%q, mime=%q", file.FileName, mimeType)
-		if c != nil {
-			logger.LogInfo(c, msg)
-		} else {
-			common.SysLog(msg)
-		}
-		return nil, nil
-	}
-}
-
 func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
 	claudeTools := make([]any, 0, len(textRequest.Tools))
 
@@ -142,7 +85,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
 
 			// 解析 UserLocation JSON
 			var userLocationMap map[string]interface{}
-			if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
+			if err := common.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
 				// 检查是否有 approximate 字段
 				if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok {
 					if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" {
@@ -406,44 +349,33 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
 							Type: "text",
 							Text: common.GetPointer[string](mediaMessage.Text),
 						})
-					case dto.ContentTypeImageURL:
+					default:
+						source := mediaMessage.ToFileSource()
+						if source == nil {
+							continue
+						}
+						base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude")
+						if err != nil {
+							return nil, fmt.Errorf("get file data failed: %s", err.Error())
+						}
 						claudeMediaMessage := dto.ClaudeMediaMessage{
-							Type: "image",
 							Source: &dto.ClaudeMessageSource{
 								Type: "base64",
 							},
 						}
-						imageUrl := mediaMessage.GetImageMedia()
-						if imageUrl == nil {
-							continue
-						}
-						// 使用统一的文件服务获取图片数据
-						var source *types.FileSource
-						if strings.HasPrefix(imageUrl.Url, "http") {
-							source = types.NewURLFileSource(imageUrl.Url)
+						if strings.HasPrefix(mimeType, "application/pdf") {
+							claudeMediaMessage.Type = "document"
 						} else {
-							source = types.NewBase64FileSource(imageUrl.Url, "")
-						}
-						base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude")
-						if err != nil {
-							return nil, fmt.Errorf("get file data failed: %s", err.Error())
+							claudeMediaMessage.Type = "image"
 						}
+
 						claudeMediaMessage.Source.MediaType = mimeType
 						claudeMediaMessage.Source.Data = base64Data
 						claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
-					// FIXME
-					//case dto.ContentTypeFile:
-					//	claudeFileMessage, err := buildClaudeFileMessage(c, mediaMessage.GetFile())
-					//	if err != nil {
-					//		return nil, err
-					//	}
-					//	if claudeFileMessage != nil {
-					//		claudeMediaMessages = append(claudeMediaMessages, *claudeFileMessage)
-					//	}
-					default:
 						continue
 					}
 				}
+
 				if message.ToolCalls != nil {
 					for _, toolCall := range message.ParseToolCalls() {
 						inputObj := make(map[string]any)

+ 4 - 38
relay/channel/gemini/relay-gemini.go

@@ -585,14 +585,10 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
 						Text: part.Text,
 					})
 				}
-			} else if part.Type == dto.ContentTypeImageURL {
-				// 使用统一的文件服务获取图片数据
-				var source *types.FileSource
-				imageUrl := part.GetImageMedia().Url
-				if strings.HasPrefix(imageUrl, "http") {
-					source = types.NewURLFileSource(imageUrl)
-				} else {
-					source = types.NewBase64FileSource(imageUrl, "")
+			} else {
+				source := part.ToFileSource()
+				if source == nil {
+					continue
 				}
 				base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Gemini")
 				if err != nil {
@@ -604,36 +600,6 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
 					return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", mimeType, source.GetIdentifier(), getSupportedMimeTypesList())
 				}
 
-				parts = append(parts, dto.GeminiPart{
-					InlineData: &dto.GeminiInlineData{
-						MimeType: mimeType,
-						Data:     base64Data,
-					},
-				})
-			} else if part.Type == dto.ContentTypeFile {
-				if part.GetFile().FileId != "" {
-					return nil, fmt.Errorf("only base64 file is supported in gemini")
-				}
-				fileSource := types.NewBase64FileSource(part.GetFile().FileData, "")
-				base64Data, mimeType, err := service.GetBase64Data(c, fileSource, "formatting file for Gemini")
-				if err != nil {
-					return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
-				}
-				parts = append(parts, dto.GeminiPart{
-					InlineData: &dto.GeminiInlineData{
-						MimeType: mimeType,
-						Data:     base64Data,
-					},
-				})
-			} else if part.Type == dto.ContentTypeInputAudio {
-				if part.GetInputAudio().Data == "" {
-					return nil, fmt.Errorf("only base64 audio is supported in gemini")
-				}
-				audioSource := types.NewBase64FileSource(part.GetInputAudio().Data, "audio/"+part.GetInputAudio().Format)
-				base64Data, mimeType, err := service.GetBase64Data(c, audioSource, "formatting audio for Gemini")
-				if err != nil {
-					return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
-				}
 				parts = append(parts, dto.GeminiPart{
 					InlineData: &dto.GeminiInlineData{
 						MimeType: mimeType,

+ 2 - 9
relay/channel/ollama/relay-ollama.go

@@ -98,15 +98,8 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam
 			parts := m.ParseContent()
 			for _, part := range parts {
 				if part.Type == dto.ContentTypeImageURL {
-					img := part.GetImageMedia()
-					if img != nil && img.Url != "" {
-						// 使用统一的文件服务获取图片数据
-						var source *types.FileSource
-						if strings.HasPrefix(img.Url, "http") {
-							source = types.NewURLFileSource(img.Url)
-						} else {
-							source = types.NewBase64FileSource(img.Url, "")
-						}
+					source := part.ToFileSource()
+					if source != nil {
 						base64Data, _, err := service.GetBase64Data(c, source, "fetch image for ollama chat")
 						if err != nil {
 							return nil, err

+ 53 - 32
service/file_service.go

@@ -25,14 +25,26 @@ import (
 // FileService 统一的文件处理服务
 // 提供文件下载、解码、缓存等功能的统一入口
 
-// getContextCacheKey 生成 context 缓存的 key
+// getContextCacheKey 生成 URL context 缓存的 key
 func getContextCacheKey(url string) string {
 	return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url))
 }
 
+// getBase64ContextCacheKey 生成 base64 context 缓存的 key
+// 使用 length + MIME + 前 128 字符作为输入,避免对整个 base64 数据做 hash
+func getBase64ContextCacheKey(data string, mimeType string) string {
+	keyMaterial := fmt.Sprintf("%d:%s:", len(data), mimeType)
+	if len(data) > 128 {
+		keyMaterial += data[:128]
+	} else {
+		keyMaterial += data
+	}
+	return fmt.Sprintf("b64_cache_%s", common.GenerateHMAC(keyMaterial))
+}
+
 // LoadFileSource 加载文件源数据
 // 这是统一的入口,会自动处理缓存和不同的来源类型
-func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) (*types.CachedFileData, error) {
+func LoadFileSource(c *gin.Context, source types.FileSource, reason ...string) (*types.CachedFileData, error) {
 	if source == nil {
 		return nil, fmt.Errorf("file source is nil")
 	}
@@ -43,7 +55,6 @@ func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string)
 
 	// 1. 快速检查内部缓存
 	if source.HasCache() {
-		// 即使命中内部缓存,也要确保注册到清理列表(如果尚未注册)
 		if c != nil {
 			registerSourceForCleanup(c, source)
 		}
@@ -62,39 +73,49 @@ func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string)
 		return source.GetCache(), nil
 	}
 
-	// 4. 如果是 URL,检查 Context 缓存
-	var contextKey string
-	if source.IsURL() && c != nil {
-		contextKey = getContextCacheKey(source.URL)
-		if cachedData, exists := c.Get(contextKey); exists {
-			data := cachedData.(*types.CachedFileData)
-			source.SetCache(data)
-			registerSourceForCleanup(c, source)
-			return data, nil
-		}
-	}
-
-	// 5. 执行加载逻辑
+	// 4. 根据来源类型加载(含 URL context 缓存查找)
 	var cachedData *types.CachedFileData
+	var contextKey string
 	var err error
 
-	if source.IsURL() {
-		cachedData, err = loadFromURL(c, source.URL, reason...)
-	} else {
-		cachedData, err = loadFromBase64(source.Base64Data, source.MimeType)
+	switch s := source.(type) {
+	case *types.URLSource:
+		if c != nil {
+			contextKey = getContextCacheKey(s.URL)
+			if cached, exists := c.Get(contextKey); exists {
+				data := cached.(*types.CachedFileData)
+				source.SetCache(data)
+				registerSourceForCleanup(c, source)
+				return data, nil
+			}
+		}
+		cachedData, err = loadFromURL(c, s.URL, reason...)
+	case *types.Base64Source:
+		if c != nil {
+			contextKey = getBase64ContextCacheKey(s.Base64Data, s.MimeType)
+			if cached, exists := c.Get(contextKey); exists {
+				data := cached.(*types.CachedFileData)
+				source.SetCache(data)
+				registerSourceForCleanup(c, source)
+				return data, nil
+			}
+		}
+		cachedData, err = loadFromBase64(s.Base64Data, s.MimeType)
+	default:
+		return nil, fmt.Errorf("unsupported file source type: %T", source)
 	}
 
 	if err != nil {
 		return nil, err
 	}
 
-	// 6. 设置缓存
+	// 5. 设置缓存
 	source.SetCache(cachedData)
 	if contextKey != "" && c != nil {
 		c.Set(contextKey, cachedData)
 	}
 
-	// 7. 注册到 context 以便请求结束时自动清理
+	// 6. 注册到 context 以便请求结束时自动清理
 	if c != nil {
 		registerSourceForCleanup(c, source)
 	}
@@ -103,15 +124,15 @@ func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string)
 }
 
 // registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理
-func registerSourceForCleanup(c *gin.Context, source *types.FileSource) {
+func registerSourceForCleanup(c *gin.Context, source types.FileSource) {
 	if source.IsRegistered() {
 		return
 	}
 
 	key := string(constant.ContextKeyFileSourcesToCleanup)
-	var sources []*types.FileSource
+	var sources []types.FileSource
 	if existing, exists := c.Get(key); exists {
-		sources = existing.([]*types.FileSource)
+		sources = existing.([]types.FileSource)
 	}
 	sources = append(sources, source)
 	c.Set(key, sources)
@@ -123,12 +144,12 @@ func registerSourceForCleanup(c *gin.Context, source *types.FileSource) {
 func CleanupFileSources(c *gin.Context) {
 	key := string(constant.ContextKeyFileSourcesToCleanup)
 	if sources, exists := c.Get(key); exists {
-		for _, source := range sources.([]*types.FileSource) {
+		for _, source := range sources.([]types.FileSource) {
 			if cache := source.GetCache(); cache != nil {
 				cache.Close()
 			}
 		}
-		c.Set(key, nil) // 清除引用
+		c.Set(key, nil)
 	}
 }
 
@@ -363,7 +384,7 @@ func loadFromBase64(base64String string, providedMimeType string) (*types.Cached
 }
 
 // GetImageConfig 获取图片配置
-func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, string, error) {
+func GetImageConfig(c *gin.Context, source types.FileSource) (image.Config, string, error) {
 	cachedData, err := LoadFileSource(c, source, "get_image_config")
 	if err != nil {
 		return image.Config{}, "", err
@@ -394,7 +415,7 @@ func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, str
 }
 
 // GetBase64Data 获取 base64 编码的数据
-func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (string, string, error) {
+func GetBase64Data(c *gin.Context, source types.FileSource, reason ...string) (string, string, error) {
 	cachedData, err := LoadFileSource(c, source, reason...)
 	if err != nil {
 		return "", "", err
@@ -407,13 +428,13 @@ func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (
 }
 
 // GetMimeType 获取文件的 MIME 类型
-func GetMimeType(c *gin.Context, source *types.FileSource) (string, error) {
+func GetMimeType(c *gin.Context, source types.FileSource) (string, error) {
 	if source.HasCache() {
 		return source.GetCache().MimeType, nil
 	}
 
-	if source.IsURL() {
-		mimeType, err := GetFileTypeFromUrl(c, source.URL, "get_mime_type")
+	if urlSource, ok := source.(*types.URLSource); ok {
+		mimeType, err := GetFileTypeFromUrl(c, urlSource.URL, "get_mime_type")
 		if err == nil && mimeType != "" && mimeType != "application/octet-stream" {
 			return mimeType, nil
 		}

+ 0 - 3
service/token_counter.go

@@ -100,8 +100,6 @@ func getImageToken(c *gin.Context, fileMeta *types.FileMeta, model string, strea
 	if err != nil {
 		return 0, err
 	}
-	fileMeta.MimeType = format
-
 	if config.Width == 0 || config.Height == 0 {
 		// not an image, but might be a valid file
 		if format != "" {
@@ -268,7 +266,6 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela
 				}
 				continue
 			}
-			file.MimeType = cachedData.MimeType
 			file.FileType = DetectFileType(cachedData.MimeType)
 		}
 	}

+ 127 - 126
types/file_source.go

@@ -4,39 +4,144 @@ import (
 	"fmt"
 	"image"
 	"os"
+	"strings"
 	"sync"
 )
 
-// FileSourceType 文件来源类型
-type FileSourceType string
+// FileSource 统一的文件来源抽象接口
+// 支持 URL 和 base64 两种来源,提供懒加载和缓存机制
+type FileSource interface {
+	IsURL() bool
+	GetIdentifier() string
+	GetRawData() string
+	ClearRawData()
 
-const (
-	FileSourceTypeURL    FileSourceType = "url"    // URL 来源
-	FileSourceTypeBase64 FileSourceType = "base64" // Base64 内联数据
-)
+	SetCache(data *CachedFileData)
+	GetCache() *CachedFileData
+	HasCache() bool
+	ClearCache()
 
-// FileSource 统一的文件来源抽象
-// 支持 URL 和 base64 两种来源,提供懒加载和缓存机制
-type FileSource struct {
-	Type       FileSourceType `json:"type"`                  // 来源类型
-	URL        string         `json:"url,omitempty"`         // URL(当 Type 为 url 时)
-	Base64Data string         `json:"base64_data,omitempty"` // Base64 数据(当 Type 为 base64 时)
-	MimeType   string         `json:"mime_type,omitempty"`   // MIME 类型(可选,会自动检测)
+	IsRegistered() bool
+	SetRegistered(registered bool)
+	Mu() *sync.Mutex
+}
 
-	// 内部缓存(不导出,不序列化)
+// baseFileSource 共享的缓存/锁/清理注册状态
+type baseFileSource struct {
 	cachedData  *CachedFileData
 	cacheLoaded bool
-	registered  bool       // 是否已注册到清理列表
-	mu          sync.Mutex // 保护加载过程
+	registered  bool
+	mu          sync.Mutex
+}
+
+func (b *baseFileSource) SetCache(data *CachedFileData) {
+	b.cachedData = data
+	b.cacheLoaded = true
+}
+
+func (b *baseFileSource) GetCache() *CachedFileData {
+	return b.cachedData
+}
+
+func (b *baseFileSource) HasCache() bool {
+	return b.cacheLoaded && b.cachedData != nil
+}
+
+func (b *baseFileSource) ClearCache() {
+	if b.cachedData != nil {
+		b.cachedData.Close()
+	}
+	b.cachedData = nil
+	b.cacheLoaded = false
+}
+
+func (b *baseFileSource) IsRegistered() bool {
+	return b.registered
+}
+
+func (b *baseFileSource) SetRegistered(registered bool) {
+	b.registered = registered
 }
 
-// Mu 获取内部锁
-func (f *FileSource) Mu() *sync.Mutex {
-	return &f.mu
+func (b *baseFileSource) Mu() *sync.Mutex {
+	return &b.mu
 }
 
-// CachedFileData 缓存的文件数据
-// 支持内存缓存和磁盘缓存两种模式
+// ---------------------------------------------------------------------------
+// URLSource — URL 来源的 FileSource 实现
+// ---------------------------------------------------------------------------
+
+type URLSource struct {
+	baseFileSource
+	URL string
+}
+
+func (u *URLSource) IsURL() bool { return true }
+
+func (u *URLSource) GetIdentifier() string {
+	if len(u.URL) > 100 {
+		return u.URL[:100] + "..."
+	}
+	return u.URL
+}
+
+func (u *URLSource) GetRawData() string { return u.URL }
+
+func (u *URLSource) ClearRawData() {}
+
+// ---------------------------------------------------------------------------
+// Base64Source — Base64 内联数据来源的 FileSource 实现
+// ---------------------------------------------------------------------------
+
+type Base64Source struct {
+	baseFileSource
+	Base64Data string
+	MimeType   string
+}
+
+func (b *Base64Source) IsURL() bool { return false }
+
+func (b *Base64Source) GetIdentifier() string {
+	if len(b.Base64Data) > 50 {
+		return "base64:" + b.Base64Data[:50] + "..."
+	}
+	return "base64:" + b.Base64Data
+}
+
+func (b *Base64Source) GetRawData() string { return b.Base64Data }
+
+func (b *Base64Source) ClearRawData() {
+	if len(b.Base64Data) > 1024 {
+		b.Base64Data = ""
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Constructors
+// ---------------------------------------------------------------------------
+
+func NewURLFileSource(url string) *URLSource {
+	return &URLSource{URL: url}
+}
+
+func NewBase64FileSource(base64Data string, mimeType string) *Base64Source {
+	return &Base64Source{
+		Base64Data: base64Data,
+		MimeType:   mimeType,
+	}
+}
+
+func NewFileSourceFromData(data string, mimeType string) FileSource {
+	if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
+		return NewURLFileSource(data)
+	}
+	return NewBase64FileSource(data, mimeType)
+}
+
+// ---------------------------------------------------------------------------
+// CachedFileData — 缓存的文件数据(支持内存和磁盘两种模式)
+// ---------------------------------------------------------------------------
+
 type CachedFileData struct {
 	base64Data  string        // 内存中的 base64 数据(小文件)
 	MimeType    string        // MIME 类型
@@ -45,18 +150,15 @@ type CachedFileData struct {
 	ImageConfig *image.Config // 图片配置(如果是图片)
 	ImageFormat string        // 图片格式(如果是图片)
 
-	// 磁盘缓存相关
 	diskPath        string     // 磁盘缓存文件路径(大文件)
 	isDisk          bool       // 是否使用磁盘缓存
 	diskMu          sync.Mutex // 磁盘操作锁(保护磁盘文件的读取和删除)
 	diskClosed      bool       // 是否已关闭/清理
 	statDecremented bool       // 是否已扣减统计
 
-	// 统计回调,避免循环依赖
 	OnClose func(size int64)
 }
 
-// NewMemoryCachedData 创建内存缓存的数据
 func NewMemoryCachedData(base64Data string, mimeType string, size int64) *CachedFileData {
 	return &CachedFileData{
 		base64Data: base64Data,
@@ -66,7 +168,6 @@ func NewMemoryCachedData(base64Data string, mimeType string, size int64) *Cached
 	}
 }
 
-// NewDiskCachedData 创建磁盘缓存的数据
 func NewDiskCachedData(diskPath string, mimeType string, size int64) *CachedFileData {
 	return &CachedFileData{
 		diskPath: diskPath,
@@ -76,7 +177,6 @@ func NewDiskCachedData(diskPath string, mimeType string, size int64) *CachedFile
 	}
 }
 
-// GetBase64Data 获取 base64 数据(自动处理内存/磁盘)
 func (c *CachedFileData) GetBase64Data() (string, error) {
 	if !c.isDisk {
 		return c.base64Data, nil
@@ -89,7 +189,6 @@ func (c *CachedFileData) GetBase64Data() (string, error) {
 		return "", fmt.Errorf("disk cache already closed")
 	}
 
-	// 从磁盘读取
 	data, err := os.ReadFile(c.diskPath)
 	if err != nil {
 		return "", fmt.Errorf("failed to read from disk cache: %w", err)
@@ -97,22 +196,19 @@ func (c *CachedFileData) GetBase64Data() (string, error) {
 	return string(data), nil
 }
 
-// SetBase64Data 设置 base64 数据(仅用于内存模式)
 func (c *CachedFileData) SetBase64Data(data string) {
 	if !c.isDisk {
 		c.base64Data = data
 	}
 }
 
-// IsDisk 是否使用磁盘缓存
 func (c *CachedFileData) IsDisk() bool {
 	return c.isDisk
 }
 
-// Close 关闭并清理资源
 func (c *CachedFileData) Close() error {
 	if !c.isDisk {
-		c.base64Data = "" // 释放内存
+		c.base64Data = ""
 		return nil
 	}
 
@@ -126,7 +222,6 @@ func (c *CachedFileData) Close() error {
 	c.diskClosed = true
 	if c.diskPath != "" {
 		err := os.Remove(c.diskPath)
-		// 只有在删除成功且未扣减过统计时,才执行回调
 		if err == nil && !c.statDecremented && c.OnClose != nil {
 			c.OnClose(c.DiskSize)
 			c.statDecremented = true
@@ -135,97 +230,3 @@ func (c *CachedFileData) Close() error {
 	}
 	return nil
 }
-
-// NewURLFileSource 创建 URL 来源的 FileSource
-func NewURLFileSource(url string) *FileSource {
-	return &FileSource{
-		Type: FileSourceTypeURL,
-		URL:  url,
-	}
-}
-
-// NewBase64FileSource 创建 base64 来源的 FileSource
-func NewBase64FileSource(base64Data string, mimeType string) *FileSource {
-	return &FileSource{
-		Type:       FileSourceTypeBase64,
-		Base64Data: base64Data,
-		MimeType:   mimeType,
-	}
-}
-
-// IsURL 判断是否是 URL 来源
-func (f *FileSource) IsURL() bool {
-	return f.Type == FileSourceTypeURL
-}
-
-// IsBase64 判断是否是 base64 来源
-func (f *FileSource) IsBase64() bool {
-	return f.Type == FileSourceTypeBase64
-}
-
-// GetIdentifier 获取文件标识符(用于日志和错误追踪)
-func (f *FileSource) GetIdentifier() string {
-	if f.IsURL() {
-		if len(f.URL) > 100 {
-			return f.URL[:100] + "..."
-		}
-		return f.URL
-	}
-	if len(f.Base64Data) > 50 {
-		return "base64:" + f.Base64Data[:50] + "..."
-	}
-	return "base64:" + f.Base64Data
-}
-
-// GetRawData 获取原始数据(URL 或完整的 base64 字符串)
-func (f *FileSource) GetRawData() string {
-	if f.IsURL() {
-		return f.URL
-	}
-	return f.Base64Data
-}
-
-// SetCache 设置缓存数据
-func (f *FileSource) SetCache(data *CachedFileData) {
-	f.cachedData = data
-	f.cacheLoaded = true
-}
-
-// IsRegistered 是否已注册到清理列表
-func (f *FileSource) IsRegistered() bool {
-	return f.registered
-}
-
-// SetRegistered 设置注册状态
-func (f *FileSource) SetRegistered(registered bool) {
-	f.registered = registered
-}
-
-// GetCache 获取缓存数据
-func (f *FileSource) GetCache() *CachedFileData {
-	return f.cachedData
-}
-
-// HasCache 是否有缓存
-func (f *FileSource) HasCache() bool {
-	return f.cacheLoaded && f.cachedData != nil
-}
-
-// ClearCache 清除缓存,释放内存和磁盘文件
-func (f *FileSource) ClearCache() {
-	// 如果有缓存数据,先关闭它(会清理磁盘文件)
-	if f.cachedData != nil {
-		f.cachedData.Close()
-	}
-	f.cachedData = nil
-	f.cacheLoaded = false
-}
-
-// ClearRawData 清除原始数据,只保留必要的元信息
-// 用于在处理完成后释放大文件的内存
-func (f *FileSource) ClearRawData() {
-	// 保留 URL(通常很短),只清除大的 base64 数据
-	if f.IsBase64() && len(f.Base64Data) > 1024 {
-		f.Base64Data = ""
-	}
-}

+ 4 - 5
types/request_meta.go

@@ -32,13 +32,12 @@ type TokenCountMeta struct {
 
 type FileMeta struct {
 	FileType
-	MimeType string
-	Source   *FileSource // 统一的文件来源(URL 或 base64)
-	Detail   string      // 图片细节级别(low/high/auto)
+	Source FileSource // 统一的文件来源(URL 或 base64)
+	Detail string     // 图片细节级别(low/high/auto)
 }
 
 // NewFileMeta 创建新的 FileMeta
-func NewFileMeta(fileType FileType, source *FileSource) *FileMeta {
+func NewFileMeta(fileType FileType, source FileSource) *FileMeta {
 	return &FileMeta{
 		FileType: fileType,
 		Source:   source,
@@ -46,7 +45,7 @@ func NewFileMeta(fileType FileType, source *FileSource) *FileMeta {
 }
 
 // NewImageFileMeta 创建图片类型的 FileMeta
-func NewImageFileMeta(source *FileSource, detail string) *FileMeta {
+func NewImageFileMeta(source FileSource, detail string) *FileMeta {
 	return &FileMeta{
 		FileType: FileTypeImage,
 		Source:   source,