| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471 |
- package service
- import (
- "bytes"
- "encoding/base64"
- "fmt"
- "image"
- _ "image/gif"
- _ "image/jpeg"
- _ "image/png"
- "io"
- "net/http"
- "strings"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/logger"
- "github.com/QuantumNous/new-api/types"
- "github.com/gin-gonic/gin"
- "golang.org/x/image/webp"
- )
- // FileService 统一的文件处理服务
- // 提供文件下载、解码、缓存等功能的统一入口
- // getContextCacheKey 生成 context 缓存的 key
- func getContextCacheKey(url string) string {
- return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url))
- }
- // LoadFileSource 加载文件源数据
- // 这是统一的入口,会自动处理缓存和不同的来源类型
- func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) (*types.CachedFileData, error) {
- if source == nil {
- return nil, fmt.Errorf("file source is nil")
- }
- if common.DebugEnabled {
- logger.LogDebug(c, fmt.Sprintf("LoadFileSource starting for: %s", source.GetIdentifier()))
- }
- // 1. 快速检查内部缓存
- if source.HasCache() {
- // 即使命中内部缓存,也要确保注册到清理列表(如果尚未注册)
- if c != nil {
- registerSourceForCleanup(c, source)
- }
- return source.GetCache(), nil
- }
- // 2. 加锁保护加载过程
- source.Mu().Lock()
- defer source.Mu().Unlock()
- // 3. 双重检查
- if source.HasCache() {
- if c != nil {
- registerSourceForCleanup(c, source)
- }
- return source.GetCache(), nil
- }
- // 4. 如果是 URL,检查 Context 缓存
- var contextKey string
- if source.IsURL() && c != nil {
- contextKey = getContextCacheKey(source.URL)
- if cachedData, exists := c.Get(contextKey); exists {
- data := cachedData.(*types.CachedFileData)
- source.SetCache(data)
- registerSourceForCleanup(c, source)
- return data, nil
- }
- }
- // 5. 执行加载逻辑
- var cachedData *types.CachedFileData
- var err error
- if source.IsURL() {
- cachedData, err = loadFromURL(c, source.URL, reason...)
- } else {
- cachedData, err = loadFromBase64(source.Base64Data, source.MimeType)
- }
- if err != nil {
- return nil, err
- }
- // 6. 设置缓存
- source.SetCache(cachedData)
- if contextKey != "" && c != nil {
- c.Set(contextKey, cachedData)
- }
- // 7. 注册到 context 以便请求结束时自动清理
- if c != nil {
- registerSourceForCleanup(c, source)
- }
- return cachedData, nil
- }
- // registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理
- func registerSourceForCleanup(c *gin.Context, source *types.FileSource) {
- if source.IsRegistered() {
- return
- }
- key := string(constant.ContextKeyFileSourcesToCleanup)
- var sources []*types.FileSource
- if existing, exists := c.Get(key); exists {
- sources = existing.([]*types.FileSource)
- }
- sources = append(sources, source)
- c.Set(key, sources)
- source.SetRegistered(true)
- }
- // CleanupFileSources 清理请求中所有注册的 FileSource
- // 应在请求结束时调用(通常由中间件自动调用)
- func CleanupFileSources(c *gin.Context) {
- key := string(constant.ContextKeyFileSourcesToCleanup)
- if sources, exists := c.Get(key); exists {
- for _, source := range sources.([]*types.FileSource) {
- if cache := source.GetCache(); cache != nil {
- cache.Close()
- }
- }
- c.Set(key, nil) // 清除引用
- }
- }
- // loadFromURL 从 URL 加载文件
- func loadFromURL(c *gin.Context, url string, reason ...string) (*types.CachedFileData, error) {
- // 下载文件
- var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
- if common.DebugEnabled {
- logger.LogDebug(c, "loadFromURL: initiating download")
- }
- resp, err := DoDownloadRequest(url, reason...)
- if err != nil {
- return nil, fmt.Errorf("failed to download file from %s: %w", url, err)
- }
- defer resp.Body.Close()
- if resp.StatusCode != 200 {
- return nil, fmt.Errorf("failed to download file, status code: %d", resp.StatusCode)
- }
- // 读取文件内容(限制大小)
- if common.DebugEnabled {
- logger.LogDebug(c, "loadFromURL: reading response body")
- }
- fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
- if err != nil {
- return nil, fmt.Errorf("failed to read file content: %w", err)
- }
- if len(fileBytes) > maxFileSize {
- return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
- }
- // 转换为 base64
- base64Data := base64.StdEncoding.EncodeToString(fileBytes)
- // 智能获取 MIME 类型
- mimeType := smartDetectMimeType(resp, url, fileBytes)
- // 判断是否使用磁盘缓存
- base64Size := int64(len(base64Data))
- var cachedData *types.CachedFileData
- if shouldUseDiskCache(base64Size) {
- // 使用磁盘缓存
- diskPath, err := writeToDiskCache(base64Data)
- if err != nil {
- // 磁盘缓存失败,回退到内存
- logger.LogWarn(c, fmt.Sprintf("Failed to write to disk cache, falling back to memory: %v", err))
- cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
- } else {
- cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(fileBytes)))
- cachedData.DiskSize = base64Size
- cachedData.OnClose = func(size int64) {
- common.DecrementDiskFiles(size)
- }
- common.IncrementDiskFiles(base64Size)
- if common.DebugEnabled {
- logger.LogDebug(c, fmt.Sprintf("File cached to disk: %s, size: %d bytes", diskPath, base64Size))
- }
- }
- } else {
- // 使用内存缓存
- cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
- }
- // 如果是图片,尝试获取图片配置
- if strings.HasPrefix(mimeType, "image/") {
- if common.DebugEnabled {
- logger.LogDebug(c, "loadFromURL: decoding image config")
- }
- config, format, err := decodeImageConfig(fileBytes)
- if err == nil {
- cachedData.ImageConfig = &config
- cachedData.ImageFormat = format
- // 如果通过图片解码获取了更准确的格式,更新 MIME 类型
- if mimeType == "application/octet-stream" || mimeType == "" {
- cachedData.MimeType = "image/" + format
- }
- }
- }
- return cachedData, nil
- }
- // shouldUseDiskCache 判断是否应该使用磁盘缓存
- func shouldUseDiskCache(dataSize int64) bool {
- return common.ShouldUseDiskCache(dataSize)
- }
- // writeToDiskCache 将数据写入磁盘缓存
- func writeToDiskCache(base64Data string) (string, error) {
- return common.WriteDiskCacheFileString(common.DiskCacheTypeFile, base64Data)
- }
- // smartDetectMimeType 智能检测 MIME 类型
- func smartDetectMimeType(resp *http.Response, url string, fileBytes []byte) string {
- // 1. 尝试从 Content-Type header 获取
- mimeType := resp.Header.Get("Content-Type")
- if idx := strings.Index(mimeType, ";"); idx != -1 {
- mimeType = strings.TrimSpace(mimeType[:idx])
- }
- if mimeType != "" && mimeType != "application/octet-stream" {
- return mimeType
- }
- // 2. 尝试从 Content-Disposition header 的 filename 获取
- if cd := resp.Header.Get("Content-Disposition"); cd != "" {
- parts := strings.Split(cd, ";")
- for _, part := range parts {
- part = strings.TrimSpace(part)
- if strings.HasPrefix(strings.ToLower(part), "filename=") {
- name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
- // 移除引号
- if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
- name = name[1 : len(name)-1]
- }
- if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
- ext := strings.ToLower(name[dot+1:])
- if ext != "" {
- mt := GetMimeTypeByExtension(ext)
- if mt != "application/octet-stream" {
- return mt
- }
- }
- }
- break
- }
- }
- }
- // 3. 尝试从 URL 路径获取扩展名
- mt := guessMimeTypeFromURL(url)
- if mt != "application/octet-stream" {
- return mt
- }
- // 4. 使用 http.DetectContentType 内容嗅探
- if len(fileBytes) > 0 {
- sniffed := http.DetectContentType(fileBytes)
- if sniffed != "" && sniffed != "application/octet-stream" {
- // 去除可能的 charset 参数
- if idx := strings.Index(sniffed, ";"); idx != -1 {
- sniffed = strings.TrimSpace(sniffed[:idx])
- }
- return sniffed
- }
- }
- // 5. 尝试作为图片解码获取格式
- if len(fileBytes) > 0 {
- if _, format, err := decodeImageConfig(fileBytes); err == nil && format != "" {
- return "image/" + strings.ToLower(format)
- }
- }
- // 最终回退
- return "application/octet-stream"
- }
- // loadFromBase64 从 base64 字符串加载文件
- func loadFromBase64(base64String string, providedMimeType string) (*types.CachedFileData, error) {
- var mimeType string
- var cleanBase64 string
- // 处理 data: 前缀
- if strings.HasPrefix(base64String, "data:") {
- idx := strings.Index(base64String, ",")
- if idx != -1 {
- header := base64String[:idx]
- cleanBase64 = base64String[idx+1:]
- if strings.Contains(header, ":") && strings.Contains(header, ";") {
- mimeStart := strings.Index(header, ":") + 1
- mimeEnd := strings.Index(header, ";")
- if mimeStart < mimeEnd {
- mimeType = header[mimeStart:mimeEnd]
- }
- }
- } else {
- cleanBase64 = base64String
- }
- } else {
- cleanBase64 = base64String
- }
- if providedMimeType != "" {
- mimeType = providedMimeType
- }
- decodedData, err := base64.StdEncoding.DecodeString(cleanBase64)
- if err != nil {
- return nil, fmt.Errorf("failed to decode base64 data: %w", err)
- }
- base64Size := int64(len(cleanBase64))
- var cachedData *types.CachedFileData
- if shouldUseDiskCache(base64Size) {
- diskPath, err := writeToDiskCache(cleanBase64)
- if err != nil {
- cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData)))
- } else {
- cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(decodedData)))
- cachedData.DiskSize = base64Size
- cachedData.OnClose = func(size int64) {
- common.DecrementDiskFiles(size)
- }
- common.IncrementDiskFiles(base64Size)
- }
- } else {
- cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData)))
- }
- if mimeType == "" || strings.HasPrefix(mimeType, "image/") {
- config, format, err := decodeImageConfig(decodedData)
- if err == nil {
- cachedData.ImageConfig = &config
- cachedData.ImageFormat = format
- if mimeType == "" {
- cachedData.MimeType = "image/" + format
- }
- }
- }
- return cachedData, nil
- }
- // GetImageConfig 获取图片配置
- func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, string, error) {
- cachedData, err := LoadFileSource(c, source, "get_image_config")
- if err != nil {
- return image.Config{}, "", err
- }
- if cachedData.ImageConfig != nil {
- return *cachedData.ImageConfig, cachedData.ImageFormat, nil
- }
- base64Str, err := cachedData.GetBase64Data()
- if err != nil {
- return image.Config{}, "", fmt.Errorf("failed to get base64 data: %w", err)
- }
- decodedData, err := base64.StdEncoding.DecodeString(base64Str)
- if err != nil {
- return image.Config{}, "", fmt.Errorf("failed to decode base64 for image config: %w", err)
- }
- config, format, err := decodeImageConfig(decodedData)
- if err != nil {
- return image.Config{}, "", err
- }
- cachedData.ImageConfig = &config
- cachedData.ImageFormat = format
- return config, format, nil
- }
- // GetBase64Data 获取 base64 编码的数据
- func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (string, string, error) {
- cachedData, err := LoadFileSource(c, source, reason...)
- if err != nil {
- return "", "", err
- }
- base64Str, err := cachedData.GetBase64Data()
- if err != nil {
- return "", "", fmt.Errorf("failed to get base64 data: %w", err)
- }
- return base64Str, cachedData.MimeType, nil
- }
- // GetMimeType 获取文件的 MIME 类型
- func GetMimeType(c *gin.Context, source *types.FileSource) (string, error) {
- if source.HasCache() {
- return source.GetCache().MimeType, nil
- }
- if source.IsURL() {
- mimeType, err := GetFileTypeFromUrl(c, source.URL, "get_mime_type")
- if err == nil && mimeType != "" && mimeType != "application/octet-stream" {
- return mimeType, nil
- }
- }
- cachedData, err := LoadFileSource(c, source, "get_mime_type")
- if err != nil {
- return "", err
- }
- return cachedData.MimeType, nil
- }
- // DetectFileType 检测文件类型
- func DetectFileType(mimeType string) types.FileType {
- if strings.HasPrefix(mimeType, "image/") {
- return types.FileTypeImage
- }
- if strings.HasPrefix(mimeType, "audio/") {
- return types.FileTypeAudio
- }
- if strings.HasPrefix(mimeType, "video/") {
- return types.FileTypeVideo
- }
- return types.FileTypeFile
- }
- // decodeImageConfig 从字节数据解码图片配置
- func decodeImageConfig(data []byte) (image.Config, string, error) {
- reader := bytes.NewReader(data)
- config, format, err := image.DecodeConfig(reader)
- if err == nil {
- return config, format, nil
- }
- reader.Seek(0, io.SeekStart)
- config, err = webp.DecodeConfig(reader)
- if err == nil {
- return config, "webp", nil
- }
- return image.Config{}, "", fmt.Errorf("failed to decode image config: unsupported format")
- }
- // guessMimeTypeFromURL 从 URL 猜测 MIME 类型
- func guessMimeTypeFromURL(url string) string {
- cleanedURL := url
- if q := strings.Index(cleanedURL, "?"); q != -1 {
- cleanedURL = cleanedURL[:q]
- }
- if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) {
- last := cleanedURL[slash+1:]
- if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
- ext := strings.ToLower(last[dot+1:])
- return GetMimeTypeByExtension(ext)
- }
- }
- return "application/octet-stream"
- }
|