relay.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. package controller
import (
	"fmt"
	"log"
	"net/http"
	"net/url"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"
	"one-api/common"
	"one-api/dto"
	"one-api/relay"
	"one-api/relay/constant"
	relayconstant "one-api/relay/constant"
	"one-api/service"
)
  16. func Relay(c *gin.Context) {
  17. relayMode := constant.Path2RelayMode(c.Request.URL.Path)
  18. var err *dto.OpenAIErrorWithStatusCode
  19. switch relayMode {
  20. case relayconstant.RelayModeImagesGenerations:
  21. err = relay.RelayImageHelper(c, relayMode)
  22. case relayconstant.RelayModeAudioSpeech:
  23. fallthrough
  24. case relayconstant.RelayModeAudioTranslation:
  25. fallthrough
  26. case relayconstant.RelayModeAudioTranscription:
  27. err = relay.AudioHelper(c, relayMode)
  28. default:
  29. err = relay.TextHelper(c)
  30. }
  31. if err != nil {
  32. requestId := c.GetString(common.RequestIdKey)
  33. retryTimesStr := c.Query("retry")
  34. retryTimes, _ := strconv.Atoi(retryTimesStr)
  35. if retryTimesStr == "" {
  36. retryTimes = common.RetryTimes
  37. }
  38. if retryTimes > 0 {
  39. c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Error.Message))
  40. } else {
  41. if err.StatusCode == http.StatusTooManyRequests {
  42. //err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
  43. }
  44. err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
  45. c.JSON(err.StatusCode, gin.H{
  46. "error": err.Error,
  47. })
  48. }
  49. channelId := c.GetInt("channel_id")
  50. autoBan := c.GetBool("auto_ban")
  51. common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
  52. // https://platform.openai.com/docs/guides/error-codes/api-errors
  53. if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
  54. channelId := c.GetInt("channel_id")
  55. channelName := c.GetString("channel_name")
  56. service.DisableChannel(channelId, channelName, err.Error.Message)
  57. }
  58. }
  59. }
  60. func RelayMidjourney(c *gin.Context) {
  61. relayMode := relayconstant.RelayModeUnknown
  62. if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
  63. relayMode = relayconstant.RelayModeMidjourneyImagine
  64. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
  65. relayMode = relayconstant.RelayModeMidjourneyBlend
  66. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
  67. relayMode = relayconstant.RelayModeMidjourneyDescribe
  68. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
  69. relayMode = relayconstant.RelayModeMidjourneyNotify
  70. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
  71. relayMode = relayconstant.RelayModeMidjourneyChange
  72. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
  73. relayMode = relayconstant.RelayModeMidjourneyChange
  74. } else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
  75. relayMode = relayconstant.RelayModeMidjourneyTaskFetch
  76. } else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
  77. relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition
  78. }
  79. var err *dto.MidjourneyResponse
  80. switch relayMode {
  81. case relayconstant.RelayModeMidjourneyNotify:
  82. err = relay.RelayMidjourneyNotify(c)
  83. case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
  84. err = relay.RelayMidjourneyTask(c, relayMode)
  85. default:
  86. err = relay.RelayMidjourneySubmit(c, relayMode)
  87. }
  88. //err = relayMidjourneySubmit(c, relayMode)
  89. log.Println(err)
  90. if err != nil {
  91. retryTimesStr := c.Query("retry")
  92. retryTimes, _ := strconv.Atoi(retryTimesStr)
  93. if retryTimesStr == "" {
  94. retryTimes = common.RetryTimes
  95. }
  96. if retryTimes > 0 {
  97. c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
  98. } else {
  99. if err.Code == 30 {
  100. err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
  101. }
  102. c.JSON(429, gin.H{
  103. "error": fmt.Sprintf("%s %s", err.Description, err.Result),
  104. "type": "upstream_error",
  105. })
  106. }
  107. channelId := c.GetInt("channel_id")
  108. common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
  109. //if shouldDisableChannel(&err.Error) {
  110. // channelId := c.GetInt("channel_id")
  111. // channelName := c.GetString("channel_name")
  112. // disableChannel(channelId, channelName, err.Result)
  113. //};''''''''''''''''''''''''''''''''
  114. }
  115. }
  116. func RelayNotImplemented(c *gin.Context) {
  117. err := dto.OpenAIError{
  118. Message: "API not implemented",
  119. Type: "new_api_error",
  120. Param: "",
  121. Code: "api_not_implemented",
  122. }
  123. c.JSON(http.StatusNotImplemented, gin.H{
  124. "error": err,
  125. })
  126. }
  127. func RelayNotFound(c *gin.Context) {
  128. err := dto.OpenAIError{
  129. Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
  130. Type: "invalid_request_error",
  131. Param: "",
  132. Code: "",
  133. }
  134. c.JSON(http.StatusNotFound, gin.H{
  135. "error": err,
  136. })
  137. }