// video_proxy.go
  1. package controller
  import (
	"fmt"
	"io"
	"net/http"
	"one-api/logger"
	"one-api/model"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
)
  11. func VideoProxy(c *gin.Context) {
  12. taskID := c.Param("task_id")
  13. if taskID == "" {
  14. c.JSON(http.StatusBadRequest, gin.H{
  15. "error": gin.H{
  16. "message": "task_id is required",
  17. "type": "invalid_request_error",
  18. },
  19. })
  20. return
  21. }
  22. task, exists, err := model.GetByOnlyTaskId(taskID)
  23. if err != nil {
  24. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
  25. c.JSON(http.StatusInternalServerError, gin.H{
  26. "error": gin.H{
  27. "message": "Failed to query task",
  28. "type": "server_error",
  29. },
  30. })
  31. return
  32. }
  33. if !exists || task == nil {
  34. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %s", taskID, err.Error()))
  35. c.JSON(http.StatusNotFound, gin.H{
  36. "error": gin.H{
  37. "message": "Task not found",
  38. "type": "invalid_request_error",
  39. },
  40. })
  41. return
  42. }
  43. if task.Status != model.TaskStatusSuccess {
  44. c.JSON(http.StatusBadRequest, gin.H{
  45. "error": gin.H{
  46. "message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status),
  47. "type": "invalid_request_error",
  48. },
  49. })
  50. return
  51. }
  52. channel, err := model.CacheGetChannel(task.ChannelId)
  53. if err != nil {
  54. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel %d: %s", task.ChannelId, err.Error()))
  55. c.JSON(http.StatusInternalServerError, gin.H{
  56. "error": gin.H{
  57. "message": "Failed to retrieve channel information",
  58. "type": "server_error",
  59. },
  60. })
  61. return
  62. }
  63. baseURL := channel.GetBaseURL()
  64. if baseURL == "" {
  65. baseURL = "https://api.openai.com"
  66. }
  67. videoURL := fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
  68. client := &http.Client{
  69. Timeout: 60 * time.Second,
  70. }
  71. req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, videoURL, nil)
  72. if err != nil {
  73. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request for %s: %s", videoURL, err.Error()))
  74. c.JSON(http.StatusInternalServerError, gin.H{
  75. "error": gin.H{
  76. "message": "Failed to create proxy request",
  77. "type": "server_error",
  78. },
  79. })
  80. return
  81. }
  82. req.Header.Set("Authorization", "Bearer "+channel.Key)
  83. resp, err := client.Do(req)
  84. if err != nil {
  85. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
  86. c.JSON(http.StatusBadGateway, gin.H{
  87. "error": gin.H{
  88. "message": "Failed to fetch video content",
  89. "type": "server_error",
  90. },
  91. })
  92. return
  93. }
  94. defer resp.Body.Close()
  95. if resp.StatusCode != http.StatusOK {
  96. logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
  97. c.JSON(http.StatusBadGateway, gin.H{
  98. "error": gin.H{
  99. "message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode),
  100. "type": "server_error",
  101. },
  102. })
  103. return
  104. }
  105. for key, values := range resp.Header {
  106. for _, value := range values {
  107. c.Writer.Header().Add(key, value)
  108. }
  109. }
  110. c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours
  111. c.Writer.WriteHeader(resp.StatusCode)
  112. _, err = io.Copy(c.Writer, resp.Body)
  113. if err != nil {
  114. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
  115. }
  116. }