| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203 |
- package service
- import (
- "bytes"
- "fmt"
- "image"
- _ "image/gif"
- _ "image/jpeg"
- _ "image/png"
- "io"
- "net/http"
- "strings"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/logger"
- "github.com/QuantumNous/new-api/types"
- "github.com/gin-gonic/gin"
- )
- // GetFileTypeFromUrl 获取文件类型,返回 mime type, 例如 image/jpeg, image/png, image/gif, image/bmp, image/tiff, application/pdf
- // 如果获取失败,返回 application/octet-stream
- func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, error) {
- response, err := DoDownloadRequest(url, []string{"get_mime_type", strings.Join(reason, ", ")}...)
- if err != nil {
- common.SysLog(fmt.Sprintf("fail to get file type from url: %s, error: %s", url, err.Error()))
- return "", err
- }
- defer response.Body.Close()
- if response.StatusCode != 200 {
- logger.LogError(c, fmt.Sprintf("failed to download file from %s, status code: %d", url, response.StatusCode))
- return "", fmt.Errorf("failed to download file, status code: %d", response.StatusCode)
- }
- if headerType := strings.TrimSpace(response.Header.Get("Content-Type")); headerType != "" {
- if i := strings.Index(headerType, ";"); i != -1 {
- headerType = headerType[:i]
- }
- if headerType != "application/octet-stream" {
- return headerType, nil
- }
- }
- if cd := response.Header.Get("Content-Disposition"); cd != "" {
- parts := strings.Split(cd, ";")
- for _, part := range parts {
- part = strings.TrimSpace(part)
- if strings.HasPrefix(strings.ToLower(part), "filename=") {
- name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
- if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
- name = name[1 : len(name)-1]
- }
- if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
- ext := strings.ToLower(name[dot+1:])
- if ext != "" {
- mt := GetMimeTypeByExtension(ext)
- if mt != "application/octet-stream" {
- return mt, nil
- }
- }
- }
- break
- }
- }
- }
- cleanedURL := url
- if q := strings.Index(cleanedURL, "?"); q != -1 {
- cleanedURL = cleanedURL[:q]
- }
- if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) {
- last := cleanedURL[slash+1:]
- if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
- ext := strings.ToLower(last[dot+1:])
- if ext != "" {
- mt := GetMimeTypeByExtension(ext)
- if mt != "application/octet-stream" {
- return mt, nil
- }
- }
- }
- }
- var readData []byte
- limits := []int{512, 8 * 1024, 24 * 1024, 64 * 1024}
- for _, limit := range limits {
- logger.LogDebug(c, fmt.Sprintf("Trying to read %d bytes to determine file type", limit))
- if len(readData) < limit {
- need := limit - len(readData)
- tmp := make([]byte, need)
- n, _ := io.ReadFull(response.Body, tmp)
- if n > 0 {
- readData = append(readData, tmp[:n]...)
- }
- }
- if len(readData) == 0 {
- continue
- }
- sniffed := http.DetectContentType(readData)
- if sniffed != "" && sniffed != "application/octet-stream" {
- return sniffed, nil
- }
- if _, format, err := image.DecodeConfig(bytes.NewReader(readData)); err == nil {
- switch strings.ToLower(format) {
- case "jpeg", "jpg":
- return "image/jpeg", nil
- case "png":
- return "image/png", nil
- case "gif":
- return "image/gif", nil
- case "bmp":
- return "image/bmp", nil
- case "tiff":
- return "image/tiff", nil
- default:
- if format != "" {
- return "image/" + strings.ToLower(format), nil
- }
- }
- }
- }
- // Fallback
- return "application/octet-stream", nil
- }
- // GetFileBase64FromUrl 从 URL 获取文件的 base64 编码数据
- // Deprecated: 请使用 GetBase64Data 配合 types.NewURLFileSource 替代
- // 此函数保留用于向后兼容,内部已重构为调用统一的文件服务
- func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
- source := types.NewURLFileSource(url)
- cachedData, err := LoadFileSource(c, source, reason...)
- if err != nil {
- return nil, err
- }
- // 转换为旧的 LocalFileData 格式以保持兼容
- base64Data, err := cachedData.GetBase64Data()
- if err != nil {
- return nil, err
- }
- return &types.LocalFileData{
- Base64Data: base64Data,
- MimeType: cachedData.MimeType,
- Size: cachedData.Size,
- Url: url,
- }, nil
- }
// GetMimeTypeByExtension maps a file extension, given without the leading dot,
// to a MIME type. Matching is case-insensitive. Extensions that are not listed
// yield "application/octet-stream".
func GetMimeTypeByExtension(ext string) string {
	switch kind := strings.ToLower(ext); kind {
	case "pdf":
		return "application/pdf"
	case "png", "gif":
		return "image/" + kind
	case "jpg", "jpeg", "jfif":
		return "image/jpeg"
	case "mp3", "wav", "mpeg":
		// Audio subtypes simply repeat the extension.
		return "audio/" + kind
	case "mp4", "wmv", "flv", "mov", "mpg", "avi", "mpegps":
		// Video subtypes simply repeat the extension.
		return "video/" + kind
	case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm":
		// Every text-like format is reported as plain text.
		return "text/plain"
	}
	// Unknown extension.
	return "application/octet-stream"
}
|