video_proxy.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. package controller
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "net/url"
  7. "time"
  8. "github.com/QuantumNous/new-api/constant"
  9. "github.com/QuantumNous/new-api/logger"
  10. "github.com/QuantumNous/new-api/model"
  11. "github.com/gin-gonic/gin"
  12. )
  13. func VideoProxy(c *gin.Context) {
  14. taskID := c.Param("task_id")
  15. if taskID == "" {
  16. c.JSON(http.StatusBadRequest, gin.H{
  17. "error": gin.H{
  18. "message": "task_id is required",
  19. "type": "invalid_request_error",
  20. },
  21. })
  22. return
  23. }
  24. task, exists, err := model.GetByOnlyTaskId(taskID)
  25. if err != nil {
  26. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
  27. c.JSON(http.StatusInternalServerError, gin.H{
  28. "error": gin.H{
  29. "message": "Failed to query task",
  30. "type": "server_error",
  31. },
  32. })
  33. return
  34. }
  35. if !exists || task == nil {
  36. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err))
  37. c.JSON(http.StatusNotFound, gin.H{
  38. "error": gin.H{
  39. "message": "Task not found",
  40. "type": "invalid_request_error",
  41. },
  42. })
  43. return
  44. }
  45. if task.Status != model.TaskStatusSuccess {
  46. c.JSON(http.StatusBadRequest, gin.H{
  47. "error": gin.H{
  48. "message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status),
  49. "type": "invalid_request_error",
  50. },
  51. })
  52. return
  53. }
  54. channel, err := model.CacheGetChannel(task.ChannelId)
  55. if err != nil {
  56. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID))
  57. c.JSON(http.StatusInternalServerError, gin.H{
  58. "error": gin.H{
  59. "message": "Failed to retrieve channel information",
  60. "type": "server_error",
  61. },
  62. })
  63. return
  64. }
  65. baseURL := channel.GetBaseURL()
  66. if baseURL == "" {
  67. baseURL = "https://api.openai.com"
  68. }
  69. var videoURL string
  70. client := &http.Client{
  71. Timeout: 60 * time.Second,
  72. }
  73. req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, "", nil)
  74. if err != nil {
  75. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
  76. c.JSON(http.StatusInternalServerError, gin.H{
  77. "error": gin.H{
  78. "message": "Failed to create proxy request",
  79. "type": "server_error",
  80. },
  81. })
  82. return
  83. }
  84. switch channel.Type {
  85. case constant.ChannelTypeGemini:
  86. apiKey := task.PrivateData.Key
  87. if apiKey == "" {
  88. logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
  89. c.JSON(http.StatusInternalServerError, gin.H{
  90. "error": gin.H{
  91. "message": "API key not stored for task",
  92. "type": "server_error",
  93. },
  94. })
  95. return
  96. }
  97. videoURL, err = getGeminiVideoURL(channel, task, apiKey)
  98. if err != nil {
  99. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error()))
  100. c.JSON(http.StatusBadGateway, gin.H{
  101. "error": gin.H{
  102. "message": "Failed to resolve Gemini video URL",
  103. "type": "server_error",
  104. },
  105. })
  106. return
  107. }
  108. req.Header.Set("x-goog-api-key", apiKey)
  109. case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
  110. videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
  111. req.Header.Set("Authorization", "Bearer "+channel.Key)
  112. default:
  113. // Video URL is directly in task.FailReason
  114. videoURL = task.FailReason
  115. }
  116. req.URL, err = url.Parse(videoURL)
  117. if err != nil {
  118. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
  119. c.JSON(http.StatusInternalServerError, gin.H{
  120. "error": gin.H{
  121. "message": "Failed to create proxy request",
  122. "type": "server_error",
  123. },
  124. })
  125. return
  126. }
  127. resp, err := client.Do(req)
  128. if err != nil {
  129. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
  130. c.JSON(http.StatusBadGateway, gin.H{
  131. "error": gin.H{
  132. "message": "Failed to fetch video content",
  133. "type": "server_error",
  134. },
  135. })
  136. return
  137. }
  138. defer resp.Body.Close()
  139. if resp.StatusCode != http.StatusOK {
  140. logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
  141. c.JSON(http.StatusBadGateway, gin.H{
  142. "error": gin.H{
  143. "message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode),
  144. "type": "server_error",
  145. },
  146. })
  147. return
  148. }
  149. for key, values := range resp.Header {
  150. for _, value := range values {
  151. c.Writer.Header().Add(key, value)
  152. }
  153. }
  154. c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours
  155. c.Writer.WriteHeader(resp.StatusCode)
  156. _, err = io.Copy(c.Writer, resp.Body)
  157. if err != nil {
  158. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
  159. }
  160. }