file_decoder.go 5.2 KB


  1. package service
  2. import (
  3. "bytes"
  4. "fmt"
  5. "image"
  6. _ "image/gif"
  7. _ "image/jpeg"
  8. _ "image/png"
  9. "io"
  10. "net/http"
  11. "strings"
  12. "github.com/QuantumNous/new-api/common"
  13. "github.com/QuantumNous/new-api/logger"
  14. "github.com/QuantumNous/new-api/types"
  15. "github.com/gin-gonic/gin"
  16. )
  17. // GetFileTypeFromUrl 获取文件类型,返回 mime type, 例如 image/jpeg, image/png, image/gif, image/bmp, image/tiff, application/pdf
  18. // 如果获取失败,返回 application/octet-stream
  19. func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, error) {
  20. response, err := DoDownloadRequest(url, []string{"get_mime_type", strings.Join(reason, ", ")}...)
  21. if err != nil {
  22. common.SysLog(fmt.Sprintf("fail to get file type from url: %s, error: %s", url, err.Error()))
  23. return "", err
  24. }
  25. defer response.Body.Close()
  26. if response.StatusCode != 200 {
  27. logger.LogError(c, fmt.Sprintf("failed to download file from %s, status code: %d", url, response.StatusCode))
  28. return "", fmt.Errorf("failed to download file, status code: %d", response.StatusCode)
  29. }
  30. if headerType := strings.TrimSpace(response.Header.Get("Content-Type")); headerType != "" {
  31. if i := strings.Index(headerType, ";"); i != -1 {
  32. headerType = headerType[:i]
  33. }
  34. if headerType != "application/octet-stream" {
  35. return headerType, nil
  36. }
  37. }
  38. if cd := response.Header.Get("Content-Disposition"); cd != "" {
  39. parts := strings.Split(cd, ";")
  40. for _, part := range parts {
  41. part = strings.TrimSpace(part)
  42. if strings.HasPrefix(strings.ToLower(part), "filename=") {
  43. name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
  44. if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
  45. name = name[1 : len(name)-1]
  46. }
  47. if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
  48. ext := strings.ToLower(name[dot+1:])
  49. if ext != "" {
  50. mt := GetMimeTypeByExtension(ext)
  51. if mt != "application/octet-stream" {
  52. return mt, nil
  53. }
  54. }
  55. }
  56. break
  57. }
  58. }
  59. }
  60. cleanedURL := url
  61. if q := strings.Index(cleanedURL, "?"); q != -1 {
  62. cleanedURL = cleanedURL[:q]
  63. }
  64. if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) {
  65. last := cleanedURL[slash+1:]
  66. if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
  67. ext := strings.ToLower(last[dot+1:])
  68. if ext != "" {
  69. mt := GetMimeTypeByExtension(ext)
  70. if mt != "application/octet-stream" {
  71. return mt, nil
  72. }
  73. }
  74. }
  75. }
  76. var readData []byte
  77. limits := []int{512, 8 * 1024, 24 * 1024, 64 * 1024}
  78. for _, limit := range limits {
  79. logger.LogDebug(c, fmt.Sprintf("Trying to read %d bytes to determine file type", limit))
  80. if len(readData) < limit {
  81. need := limit - len(readData)
  82. tmp := make([]byte, need)
  83. n, _ := io.ReadFull(response.Body, tmp)
  84. if n > 0 {
  85. readData = append(readData, tmp[:n]...)
  86. }
  87. }
  88. if len(readData) == 0 {
  89. continue
  90. }
  91. sniffed := http.DetectContentType(readData)
  92. if sniffed != "" && sniffed != "application/octet-stream" {
  93. return sniffed, nil
  94. }
  95. if _, format, err := image.DecodeConfig(bytes.NewReader(readData)); err == nil {
  96. switch strings.ToLower(format) {
  97. case "jpeg", "jpg":
  98. return "image/jpeg", nil
  99. case "png":
  100. return "image/png", nil
  101. case "gif":
  102. return "image/gif", nil
  103. case "bmp":
  104. return "image/bmp", nil
  105. case "tiff":
  106. return "image/tiff", nil
  107. default:
  108. if format != "" {
  109. return "image/" + strings.ToLower(format), nil
  110. }
  111. }
  112. }
  113. }
  114. // Fallback
  115. return "application/octet-stream", nil
  116. }
  117. // GetFileBase64FromUrl 从 URL 获取文件的 base64 编码数据
  118. // Deprecated: 请使用 GetBase64Data 配合 types.NewURLFileSource 替代
  119. // 此函数保留用于向后兼容,内部已重构为调用统一的文件服务
  120. func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
  121. source := types.NewURLFileSource(url)
  122. cachedData, err := LoadFileSource(c, source, reason...)
  123. if err != nil {
  124. return nil, err
  125. }
  126. // 转换为旧的 LocalFileData 格式以保持兼容
  127. base64Data, err := cachedData.GetBase64Data()
  128. if err != nil {
  129. return nil, err
  130. }
  131. return &types.LocalFileData{
  132. Base64Data: base64Data,
  133. MimeType: cachedData.MimeType,
  134. Size: cachedData.Size,
  135. Url: url,
  136. }, nil
  137. }
  // GetMimeTypeByExtension maps a file extension to a MIME type. The lookup is
// case-insensitive and accepts the extension with or without a single leading
// dot ("png" and ".PNG" both give "image/png"). Unknown or empty extensions
// yield "application/octet-stream".
//
// All listed text formats (including csv, json, xml and html) map to
// "text/plain", and some audio/video entries use non-IANA names such as
// "audio/mp3" and "video/mov".
// NOTE(review): presumably chosen to match a downstream API; confirm before changing.
func GetMimeTypeByExtension(ext string) string {
	// Normalise: compare case-insensitively and tolerate the ".ext" form
	// returned by filepath.Ext.
	ext = strings.TrimPrefix(strings.ToLower(ext), ".")
	switch ext {
	// Text files
	case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm":
		return "text/plain"
	// Image files
	case "jpg", "jpeg", "jfif":
		return "image/jpeg"
	case "png":
		return "image/png"
	case "gif":
		return "image/gif"
	case "webp":
		return "image/webp"
	case "bmp":
		return "image/bmp"
	case "tif", "tiff":
		return "image/tiff"
	// Audio files
	case "mp3":
		return "audio/mp3"
	case "wav":
		return "audio/wav"
	case "mpeg":
		return "audio/mpeg"
	// Video files
	case "mp4":
		return "video/mp4"
	case "wmv":
		return "video/wmv"
	case "flv":
		return "video/flv"
	case "mov":
		return "video/mov"
	case "mpg":
		return "video/mpg"
	case "avi":
		return "video/avi"
	case "mpegps":
		return "video/mpegps"
	// Document files
	case "pdf":
		return "application/pdf"
	default:
		return "application/octet-stream" // Default for unknown types
	}
}