video_proxy.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. package controller
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/url"
  8. "time"
  9. "github.com/QuantumNous/new-api/constant"
  10. "github.com/QuantumNous/new-api/logger"
  11. "github.com/QuantumNous/new-api/model"
  12. "github.com/QuantumNous/new-api/service"
  13. "github.com/gin-gonic/gin"
  14. )
  15. func VideoProxy(c *gin.Context) {
  16. taskID := c.Param("task_id")
  17. if taskID == "" {
  18. c.JSON(http.StatusBadRequest, gin.H{
  19. "error": gin.H{
  20. "message": "task_id is required",
  21. "type": "invalid_request_error",
  22. },
  23. })
  24. return
  25. }
  26. task, exists, err := model.GetByOnlyTaskId(taskID)
  27. if err != nil {
  28. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
  29. c.JSON(http.StatusInternalServerError, gin.H{
  30. "error": gin.H{
  31. "message": "Failed to query task",
  32. "type": "server_error",
  33. },
  34. })
  35. return
  36. }
  37. if !exists || task == nil {
  38. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err))
  39. c.JSON(http.StatusNotFound, gin.H{
  40. "error": gin.H{
  41. "message": "Task not found",
  42. "type": "invalid_request_error",
  43. },
  44. })
  45. return
  46. }
  47. if task.Status != model.TaskStatusSuccess {
  48. c.JSON(http.StatusBadRequest, gin.H{
  49. "error": gin.H{
  50. "message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status),
  51. "type": "invalid_request_error",
  52. },
  53. })
  54. return
  55. }
  56. channel, err := model.CacheGetChannel(task.ChannelId)
  57. if err != nil {
  58. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID))
  59. c.JSON(http.StatusInternalServerError, gin.H{
  60. "error": gin.H{
  61. "message": "Failed to retrieve channel information",
  62. "type": "server_error",
  63. },
  64. })
  65. return
  66. }
  67. baseURL := channel.GetBaseURL()
  68. if baseURL == "" {
  69. baseURL = "https://api.openai.com"
  70. }
  71. var videoURL string
  72. proxy := channel.GetSetting().Proxy
  73. client, err := service.GetHttpClientWithProxy(proxy)
  74. if err != nil {
  75. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
  76. c.JSON(http.StatusInternalServerError, gin.H{
  77. "error": gin.H{
  78. "message": "Failed to create proxy client",
  79. "type": "server_error",
  80. },
  81. })
  82. return
  83. }
  84. ctx, cancel := context.WithTimeout(c.Request.Context(), 60*time.Second)
  85. defer cancel()
  86. req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
  87. if err != nil {
  88. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
  89. c.JSON(http.StatusInternalServerError, gin.H{
  90. "error": gin.H{
  91. "message": "Failed to create proxy request",
  92. "type": "server_error",
  93. },
  94. })
  95. return
  96. }
  97. switch channel.Type {
  98. case constant.ChannelTypeGemini:
  99. apiKey := task.PrivateData.Key
  100. if apiKey == "" {
  101. logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
  102. c.JSON(http.StatusInternalServerError, gin.H{
  103. "error": gin.H{
  104. "message": "API key not stored for task",
  105. "type": "server_error",
  106. },
  107. })
  108. return
  109. }
  110. videoURL, err = getGeminiVideoURL(channel, task, apiKey)
  111. if err != nil {
  112. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error()))
  113. c.JSON(http.StatusBadGateway, gin.H{
  114. "error": gin.H{
  115. "message": "Failed to resolve Gemini video URL",
  116. "type": "server_error",
  117. },
  118. })
  119. return
  120. }
  121. req.Header.Set("x-goog-api-key", apiKey)
  122. case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
  123. videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
  124. req.Header.Set("Authorization", "Bearer "+channel.Key)
  125. default:
  126. // Video URL is directly in task.FailReason
  127. videoURL = task.FailReason
  128. }
  129. req.URL, err = url.Parse(videoURL)
  130. if err != nil {
  131. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
  132. c.JSON(http.StatusInternalServerError, gin.H{
  133. "error": gin.H{
  134. "message": "Failed to create proxy request",
  135. "type": "server_error",
  136. },
  137. })
  138. return
  139. }
  140. resp, err := client.Do(req)
  141. if err != nil {
  142. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
  143. c.JSON(http.StatusBadGateway, gin.H{
  144. "error": gin.H{
  145. "message": "Failed to fetch video content",
  146. "type": "server_error",
  147. },
  148. })
  149. return
  150. }
  151. defer resp.Body.Close()
  152. if resp.StatusCode != http.StatusOK {
  153. logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
  154. c.JSON(http.StatusBadGateway, gin.H{
  155. "error": gin.H{
  156. "message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode),
  157. "type": "server_error",
  158. },
  159. })
  160. return
  161. }
  162. for key, values := range resp.Header {
  163. for _, value := range values {
  164. c.Writer.Header().Add(key, value)
  165. }
  166. }
  167. c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours
  168. c.Writer.WriteHeader(resp.StatusCode)
  169. _, err = io.Copy(c.Writer, resp.Body)
  170. if err != nil {
  171. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
  172. }
  173. }