video_proxy.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. package controller
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/url"
  9. "strings"
  10. "time"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/QuantumNous/new-api/logger"
  14. "github.com/QuantumNous/new-api/model"
  15. "github.com/QuantumNous/new-api/service"
  16. "github.com/QuantumNous/new-api/setting/system_setting"
  17. "github.com/gin-gonic/gin"
  18. )
  19. // videoProxyError returns a standardized OpenAI-style error response.
  20. func videoProxyError(c *gin.Context, status int, errType, message string) {
  21. c.JSON(status, gin.H{
  22. "error": gin.H{
  23. "message": message,
  24. "type": errType,
  25. },
  26. })
  27. }
  28. func VideoProxy(c *gin.Context) {
  29. taskID := c.Param("task_id")
  30. if taskID == "" {
  31. videoProxyError(c, http.StatusBadRequest, "invalid_request_error", "task_id is required")
  32. return
  33. }
  34. userID := c.GetInt("id")
  35. task, exists, err := model.GetByTaskId(userID, taskID)
  36. if err != nil {
  37. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
  38. videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to query task")
  39. return
  40. }
  41. if !exists || task == nil {
  42. videoProxyError(c, http.StatusNotFound, "invalid_request_error", "Task not found")
  43. return
  44. }
  45. if task.Status != model.TaskStatusSuccess {
  46. videoProxyError(c, http.StatusBadRequest, "invalid_request_error",
  47. fmt.Sprintf("Task is not completed yet, current status: %s", task.Status))
  48. return
  49. }
  50. channel, err := model.CacheGetChannel(task.ChannelId)
  51. if err != nil {
  52. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel for task %s: %s", taskID, err.Error()))
  53. videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to retrieve channel information")
  54. return
  55. }
  56. baseURL := channel.GetBaseURL()
  57. if baseURL == "" {
  58. baseURL = "https://api.openai.com"
  59. }
  60. var videoURL string
  61. proxy := channel.GetSetting().Proxy
  62. client, err := service.GetHttpClientWithProxy(proxy)
  63. if err != nil {
  64. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
  65. videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy client")
  66. return
  67. }
  68. ctx, cancel := context.WithTimeout(c.Request.Context(), 60*time.Second)
  69. defer cancel()
  70. req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
  71. if err != nil {
  72. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
  73. videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
  74. return
  75. }
  76. switch channel.Type {
  77. case constant.ChannelTypeGemini:
  78. apiKey := task.PrivateData.Key
  79. if apiKey == "" {
  80. logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
  81. videoProxyError(c, http.StatusInternalServerError, "server_error", "API key not stored for task")
  82. return
  83. }
  84. videoURL, err = getGeminiVideoURL(channel, task, apiKey)
  85. if err != nil {
  86. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error()))
  87. videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Gemini video URL")
  88. return
  89. }
  90. req.Header.Set("x-goog-api-key", apiKey)
  91. case constant.ChannelTypeVertexAi:
  92. videoURL, err = getVertexVideoURL(channel, task)
  93. if err != nil {
  94. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Vertex video URL for task %s: %s", taskID, err.Error()))
  95. videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Vertex video URL")
  96. return
  97. }
  98. case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
  99. videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID())
  100. req.Header.Set("Authorization", "Bearer "+channel.Key)
  101. default:
  102. // Video URL is stored in PrivateData.ResultURL (fallback to FailReason for old data)
  103. videoURL = task.GetResultURL()
  104. }
  105. videoURL = strings.TrimSpace(videoURL)
  106. if videoURL == "" {
  107. logger.LogError(c.Request.Context(), fmt.Sprintf("Video URL is empty for task %s", taskID))
  108. videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
  109. return
  110. }
  111. if strings.HasPrefix(videoURL, "data:") {
  112. if err := writeVideoDataURL(c, videoURL); err != nil {
  113. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to decode video data URL for task %s: %s", taskID, err.Error()))
  114. videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
  115. }
  116. return
  117. }
  118. fetchSetting := system_setting.GetFetchSetting()
  119. if err := common.ValidateURLWithFetchSetting(videoURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
  120. logger.LogError(c.Request.Context(), fmt.Sprintf("Video URL blocked for task %s: %v", taskID, err))
  121. videoProxyError(c, http.StatusForbidden, "server_error", fmt.Sprintf("request blocked: %v", err))
  122. return
  123. }
  124. req.URL, err = url.Parse(videoURL)
  125. if err != nil {
  126. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
  127. videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
  128. return
  129. }
  130. resp, err := client.Do(req)
  131. if err != nil {
  132. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
  133. videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
  134. return
  135. }
  136. defer resp.Body.Close()
  137. if resp.StatusCode != http.StatusOK {
  138. logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
  139. videoProxyError(c, http.StatusBadGateway, "server_error",
  140. fmt.Sprintf("Upstream service returned status %d", resp.StatusCode))
  141. return
  142. }
  143. for key, values := range resp.Header {
  144. for _, value := range values {
  145. c.Writer.Header().Add(key, value)
  146. }
  147. }
  148. c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
  149. c.Writer.WriteHeader(resp.StatusCode)
  150. if _, err = io.Copy(c.Writer, resp.Body); err != nil {
  151. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
  152. }
  153. }
  154. func writeVideoDataURL(c *gin.Context, dataURL string) error {
  155. parts := strings.SplitN(dataURL, ",", 2)
  156. if len(parts) != 2 {
  157. return fmt.Errorf("invalid data url")
  158. }
  159. header := parts[0]
  160. payload := parts[1]
  161. if !strings.HasPrefix(header, "data:") || !strings.Contains(header, ";base64") {
  162. return fmt.Errorf("unsupported data url")
  163. }
  164. mimeType := strings.TrimPrefix(header, "data:")
  165. mimeType = strings.TrimSuffix(mimeType, ";base64")
  166. if mimeType == "" {
  167. mimeType = "video/mp4"
  168. }
  169. videoBytes, err := base64.StdEncoding.DecodeString(payload)
  170. if err != nil {
  171. videoBytes, err = base64.RawStdEncoding.DecodeString(payload)
  172. if err != nil {
  173. return err
  174. }
  175. }
  176. c.Writer.Header().Set("Content-Type", mimeType)
  177. c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
  178. c.Writer.WriteHeader(http.StatusOK)
  179. _, err = c.Writer.Write(videoBytes)
  180. return err
  181. }