file_decoder.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. package service
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "fmt"
  6. "image"
  7. _ "image/gif"
  8. _ "image/jpeg"
  9. _ "image/png"
  10. "io"
  11. "net/http"
  12. "one-api/common"
  13. "one-api/constant"
  14. "one-api/logger"
  15. "one-api/types"
  16. "strings"
  17. "github.com/gin-gonic/gin"
  18. )
  19. // GetFileTypeFromUrl 获取文件类型,返回 mime type, 例如 image/jpeg, image/png, image/gif, image/bmp, image/tiff, application/pdf
  20. // 如果获取失败,返回 application/octet-stream
  21. func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, error) {
  22. response, err := DoDownloadRequest(url, []string{"get_mime_type", strings.Join(reason, ", ")}...)
  23. if err != nil {
  24. common.SysLog(fmt.Sprintf("fail to get file type from url: %s, error: %s", url, err.Error()))
  25. return "", err
  26. }
  27. defer response.Body.Close()
  28. if response.StatusCode != 200 {
  29. logger.LogError(c, fmt.Sprintf("failed to download file from %s, status code: %d", url, response.StatusCode))
  30. return "", fmt.Errorf("failed to download file, status code: %d", response.StatusCode)
  31. }
  32. if headerType := strings.TrimSpace(response.Header.Get("Content-Type")); headerType != "" {
  33. if i := strings.Index(headerType, ";"); i != -1 {
  34. headerType = headerType[:i]
  35. }
  36. if headerType != "application/octet-stream" {
  37. return headerType, nil
  38. }
  39. }
  40. if cd := response.Header.Get("Content-Disposition"); cd != "" {
  41. parts := strings.Split(cd, ";")
  42. for _, part := range parts {
  43. part = strings.TrimSpace(part)
  44. if strings.HasPrefix(strings.ToLower(part), "filename=") {
  45. name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
  46. if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
  47. name = name[1 : len(name)-1]
  48. }
  49. if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
  50. ext := strings.ToLower(name[dot+1:])
  51. if ext != "" {
  52. mt := GetMimeTypeByExtension(ext)
  53. if mt != "application/octet-stream" {
  54. return mt, nil
  55. }
  56. }
  57. }
  58. break
  59. }
  60. }
  61. }
  62. cleanedURL := url
  63. if q := strings.Index(cleanedURL, "?"); q != -1 {
  64. cleanedURL = cleanedURL[:q]
  65. }
  66. if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) {
  67. last := cleanedURL[slash+1:]
  68. if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
  69. ext := strings.ToLower(last[dot+1:])
  70. if ext != "" {
  71. mt := GetMimeTypeByExtension(ext)
  72. if mt != "application/octet-stream" {
  73. return mt, nil
  74. }
  75. }
  76. }
  77. }
  78. var readData []byte
  79. limits := []int{512, 8 * 1024, 24 * 1024, 64 * 1024}
  80. for _, limit := range limits {
  81. logger.LogDebug(c, fmt.Sprintf("Trying to read %d bytes to determine file type", limit))
  82. if len(readData) < limit {
  83. need := limit - len(readData)
  84. tmp := make([]byte, need)
  85. n, _ := io.ReadFull(response.Body, tmp)
  86. if n > 0 {
  87. readData = append(readData, tmp[:n]...)
  88. }
  89. }
  90. if len(readData) == 0 {
  91. continue
  92. }
  93. sniffed := http.DetectContentType(readData)
  94. if sniffed != "" && sniffed != "application/octet-stream" {
  95. return sniffed, nil
  96. }
  97. if _, format, err := image.DecodeConfig(bytes.NewReader(readData)); err == nil {
  98. switch strings.ToLower(format) {
  99. case "jpeg", "jpg":
  100. return "image/jpeg", nil
  101. case "png":
  102. return "image/png", nil
  103. case "gif":
  104. return "image/gif", nil
  105. case "bmp":
  106. return "image/bmp", nil
  107. case "tiff":
  108. return "image/tiff", nil
  109. default:
  110. if format != "" {
  111. return "image/" + strings.ToLower(format), nil
  112. }
  113. }
  114. }
  115. }
  116. // Fallback
  117. return "application/octet-stream", nil
  118. }
  119. func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
  120. contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url))
  121. // Check if the file has already been downloaded in this request
  122. if cachedData, exists := c.Get(contextKey); exists {
  123. if common.DebugEnabled {
  124. logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
  125. }
  126. return cachedData.(*types.LocalFileData), nil
  127. }
  128. var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
  129. resp, err := DoDownloadRequest(url, reason...)
  130. if err != nil {
  131. return nil, err
  132. }
  133. defer resp.Body.Close()
  134. // Always use LimitReader to prevent oversized downloads
  135. fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
  136. if err != nil {
  137. return nil, err
  138. }
  139. // Check actual size after reading
  140. if len(fileBytes) > maxFileSize {
  141. return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
  142. }
  143. // Convert to base64
  144. base64Data := base64.StdEncoding.EncodeToString(fileBytes)
  145. mimeType := resp.Header.Get("Content-Type")
  146. if len(strings.Split(mimeType, ";")) > 1 {
  147. // If Content-Type has parameters, take the first part
  148. mimeType = strings.Split(mimeType, ";")[0]
  149. }
  150. if mimeType == "application/octet-stream" {
  151. logger.LogDebug(c, fmt.Sprintf("MIME type is application/octet-stream for URL: %s", url))
  152. // try to guess the MIME type from the url last segment
  153. urlParts := strings.Split(url, "/")
  154. if len(urlParts) > 0 {
  155. lastSegment := urlParts[len(urlParts)-1]
  156. if strings.Contains(lastSegment, ".") {
  157. // Extract the file extension
  158. filename := strings.Split(lastSegment, ".")
  159. if len(filename) > 1 {
  160. ext := strings.ToLower(filename[len(filename)-1])
  161. // Guess MIME type based on file extension
  162. mimeType = GetMimeTypeByExtension(ext)
  163. }
  164. }
  165. } else {
  166. // try to guess the MIME type from the file extension
  167. fileName := resp.Header.Get("Content-Disposition")
  168. if fileName != "" {
  169. // Extract the filename from the Content-Disposition header
  170. parts := strings.Split(fileName, ";")
  171. for _, part := range parts {
  172. if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
  173. fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
  174. // Remove quotes if present
  175. if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
  176. fileName = fileName[1 : len(fileName)-1]
  177. }
  178. // Guess MIME type based on file extension
  179. if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" {
  180. mimeType = GetMimeTypeByExtension(ext)
  181. }
  182. break
  183. }
  184. }
  185. }
  186. }
  187. }
  188. data := &types.LocalFileData{
  189. Base64Data: base64Data,
  190. MimeType: mimeType,
  191. Size: int64(len(fileBytes)),
  192. }
  193. // Store the file data in the context to avoid re-downloading
  194. c.Set(contextKey, data)
  195. return data, nil
  196. }
  // extensionMimeTypes is the lookup table behind GetMimeTypeByExtension. Keys are
// lower-case extensions without the leading dot. All text-like formats
// (Markdown, CSV, JSON, XML, HTML) map to text/plain, and some audio/video entries
// use non-registered forms such as audio/mp3 and video/mov.
var extensionMimeTypes = map[string]string{
	// Text files
	"txt":      "text/plain",
	"md":       "text/plain",
	"markdown": "text/plain",
	"csv":      "text/plain",
	"json":     "text/plain",
	"xml":      "text/plain",
	"html":     "text/plain",
	"htm":      "text/plain",
	// Image files
	"jpg":  "image/jpeg",
	"jpeg": "image/jpeg",
	"png":  "image/png",
	"gif":  "image/gif",
	// Audio files
	"mp3":  "audio/mp3",
	"wav":  "audio/wav",
	"mpeg": "audio/mpeg",
	// Video files
	"mp4":    "video/mp4",
	"wmv":    "video/wmv",
	"flv":    "video/flv",
	"mov":    "video/mov",
	"mpg":    "video/mpg",
	"avi":    "video/avi",
	"mpegps": "video/mpegps",
	// Document files
	"pdf": "application/pdf",
}

// GetMimeTypeByExtension returns the MIME type for a file extension given
// without its leading dot (so "pdf", not ".pdf"), compared case-insensitively.
// Unknown or empty extensions yield "application/octet-stream".
func GetMimeTypeByExtension(ext string) string {
	mimeType, known := extensionMimeTypes[strings.ToLower(ext)]
	if !known {
		return "application/octet-stream" // Default for unknown types
	}
	return mimeType
}