瀏覽代碼

feat: implement file type detection from URL with enhanced MIME type handling

CaIon 4 月之前
父節點
當前提交
5d6fac69c4
共有 2 個文件被更改,包括 127 次插入14 次删除
  1. 116 4
      service/file_decoder.go
  2. 11 10
      service/token_counter.go

+ 116 - 4
service/file_decoder.go

@@ -1,17 +1,131 @@
 package service
 
 import (
+	"bytes"
 	"encoding/base64"
 	"fmt"
-	"github.com/gin-gonic/gin"
+	"image"
 	"io"
+	"net/http"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/logger"
 	"one-api/types"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
+// GetFileTypeFromUrl downloads url and reports the MIME type of the file it
+// serves, for example image/jpeg, image/png, image/gif or application/pdf.
+//
+// Detection sources are tried in order and the first conclusive one wins:
+//  1. the Content-Type response header, with parameters such as charset removed;
+//  2. the extension of the filename in the Content-Disposition header;
+//  3. the extension of the last path segment of url;
+//  4. the leading bytes of the response body.
+//
+// A source is inconclusive when it yields nothing or application/octet-stream.
+// When every source is inconclusive the function returns
+// application/octet-stream with a nil error. A failed download or a non-200
+// status returns an empty string and a non-nil error. reason is forwarded
+// unchanged to DoDownloadRequest.
+func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, error) {
+	response, err := DoDownloadRequest(url, reason...)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("fail to get file type from url: %s, error: %s", url, err.Error()))
+		return "", err
+	}
+	defer response.Body.Close()
+
+	if response.StatusCode != http.StatusOK {
+		logger.LogError(c, fmt.Sprintf("failed to download file from %s, status code: %d", url, response.StatusCode))
+		return "", fmt.Errorf("failed to download file, status code: %d", response.StatusCode)
+	}
+
+	// Media types are case-insensitive and callers match lowercase prefixes
+	// such as "image/", so the header value is normalised before use.
+	headerType := strings.ToLower(stripMediaTypeParams(response.Header.Get("Content-Type")))
+	if headerType != "" && headerType != "application/octet-stream" {
+		return headerType, nil
+	}
+
+	if mt := mimeTypeFromContentDisposition(response.Header.Get("Content-Disposition")); mt != "" {
+		return mt, nil
+	}
+	if mt := mimeTypeFromURLPath(url); mt != "" {
+		return mt, nil
+	}
+	if mt := sniffMimeType(c, response.Body); mt != "" {
+		return mt, nil
+	}
+
+	// Fallback: nothing identified the content.
+	return "application/octet-stream", nil
+}
+
+// stripMediaTypeParams drops everything from the first ";" onwards and trims
+// surrounding whitespace, turning "text/plain; charset=utf-8" into "text/plain".
+func stripMediaTypeParams(mediaType string) string {
+	if i := strings.Index(mediaType, ";"); i != -1 {
+		mediaType = mediaType[:i]
+	}
+	return strings.TrimSpace(mediaType)
+}
+
+// mimeTypeFromFileName maps the extension of name to a MIME type through
+// GetMimeTypeByExtension. It returns "" when name has no extension or the
+// lookup yields application/octet-stream.
+func mimeTypeFromFileName(name string) string {
+	dot := strings.LastIndex(name, ".")
+	if dot == -1 || dot+1 >= len(name) {
+		return ""
+	}
+	mt := GetMimeTypeByExtension(strings.ToLower(name[dot+1:]))
+	if mt == "application/octet-stream" {
+		return ""
+	}
+	return mt
+}
+
+// mimeTypeFromContentDisposition derives a MIME type from the first
+// "filename=" parameter of a Content-Disposition header value. It returns ""
+// when there is no such parameter or its extension is not recognised.
+func mimeTypeFromContentDisposition(cd string) string {
+	const key = "filename="
+	for _, part := range strings.Split(cd, ";") {
+		part = strings.TrimSpace(part)
+		// Parameter names are case-insensitive, so the prefix is both matched
+		// and removed without regard to case.
+		if len(part) < len(key) || !strings.EqualFold(part[:len(key)], key) {
+			continue
+		}
+		name := strings.TrimSpace(part[len(key):])
+		if len(name) >= 2 && name[0] == '"' && name[len(name)-1] == '"' {
+			name = name[1 : len(name)-1]
+		}
+		// Only the first filename parameter is considered.
+		return mimeTypeFromFileName(name)
+	}
+	return ""
+}
+
+// mimeTypeFromURLPath derives a MIME type from the extension of the last path
+// segment of rawURL, ignoring any query string or fragment. It returns "" when
+// no recognised extension is present.
+func mimeTypeFromURLPath(rawURL string) string {
+	if i := strings.IndexAny(rawURL, "?#"); i != -1 {
+		rawURL = rawURL[:i]
+	}
+	slash := strings.LastIndex(rawURL, "/")
+	if slash == -1 {
+		return ""
+	}
+	return mimeTypeFromFileName(rawURL[slash+1:])
+}
+
+// sniffMimeType inspects the leading bytes of body, reading progressively
+// larger prefixes (512 B, 8 KiB, 24 KiB, 64 KiB) until a type is recognised or
+// the body ends. It returns "" when the content cannot be identified.
+func sniffMimeType(c *gin.Context, body io.Reader) string {
+	var data []byte
+	for i, limit := range []int{512, 8 * 1024, 24 * 1024, 64 * 1024} {
+		logger.LogDebug(c, fmt.Sprintf("Trying to read %d bytes to determine file type", limit))
+
+		// Every earlier pass filled its limit completely (otherwise the loop
+		// would have stopped), so limit-len(data) is always positive here.
+		chunk := make([]byte, limit-len(data))
+		n, readErr := io.ReadFull(body, chunk)
+		data = append(data, chunk[:n]...)
+		if len(data) == 0 {
+			return ""
+		}
+
+		// http.DetectContentType looks at no more than the first 512 bytes,
+		// so running it again on a longer prefix cannot change the answer.
+		if i == 0 {
+			sniffed := stripMediaTypeParams(http.DetectContentType(data))
+			if sniffed != "" && sniffed != "application/octet-stream" {
+				return sniffed
+			}
+		}
+
+		// NOTE(review): image.DecodeConfig only recognises formats whose
+		// decoders have been registered somewhere in the program; confirm the
+		// needed image/* packages are imported elsewhere.
+		if _, format, err := image.DecodeConfig(bytes.NewReader(data)); err == nil && format != "" {
+			format = strings.ToLower(format)
+			if format == "jpg" {
+				format = "jpeg"
+			}
+			return "image/" + format
+		}
+
+		// A short read means the body is exhausted (or failed); further
+		// passes would only re-examine the same bytes.
+		if readErr != nil {
+			return ""
+		}
+	}
+	return ""
+}
+
 func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
 	contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url))
 
@@ -50,9 +164,7 @@ func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.
 		mimeType = strings.Split(mimeType, ";")[0]
 	}
 	if mimeType == "application/octet-stream" {
-		if common.DebugEnabled {
-			println("MIME type is application/octet-stream, trying to guess from URL or filename")
-		}
+		logger.LogDebug(c, fmt.Sprintf("MIME type is application/octet-stream for URL: %s", url))
 		// try to guess the MIME type from the url last segment
 		urlParts := strings.Split(url, "/")
 		if len(urlParts) > 0 {

+ 11 - 10
service/token_counter.go

@@ -283,21 +283,20 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
 	if shouldFetchFiles {
 		for _, file := range meta.Files {
 			if strings.HasPrefix(file.OriginData, "http") {
-				localFileData, err := GetFileBase64FromUrl(c, file.OriginData, "token_counter")
+				mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
 				if err != nil {
 					return 0, fmt.Errorf("error getting file base64 from url: %v", err)
 				}
-				if strings.HasPrefix(localFileData.MimeType, "image/") {
+				if strings.HasPrefix(mineType, "image/") {
 					file.FileType = types.FileTypeImage
-				} else if strings.HasPrefix(localFileData.MimeType, "video/") {
+				} else if strings.HasPrefix(mineType, "video/") {
 					file.FileType = types.FileTypeVideo
-				} else if strings.HasPrefix(localFileData.MimeType, "audio/") {
+				} else if strings.HasPrefix(mineType, "audio/") {
 					file.FileType = types.FileTypeAudio
 				} else {
 					file.FileType = types.FileTypeFile
 				}
-				file.MimeType = localFileData.MimeType
-				file.ParsedData = localFileData
+				file.MimeType = mineType
 			}
 		}
 	}
@@ -306,7 +305,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
 		switch file.FileType {
 		case types.FileTypeImage:
 			if info.RelayFormat == types.RelayFormatGemini {
-				tkm += 240
+				tkm += 256
 			} else {
 				token, err := getImageToken(file, model, info.IsStream)
 				if err != nil {
@@ -315,11 +314,13 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
 				tkm += token
 			}
 		case types.FileTypeAudio:
-			tkm += 100
+			tkm += 256
 		case types.FileTypeVideo:
-			tkm += 5000
+			tkm += 4096 * 2
 		case types.FileTypeFile:
-			tkm += 5000
+			tkm += 4096
+		default:
+			tkm += 4096 // Default case for unknown file types
 		}
 	}