relay.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. package controller
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "github.com/gorilla/websocket"
  8. "io"
  9. "log"
  10. "net/http"
  11. "one-api/common"
  12. "one-api/dto"
  13. "one-api/middleware"
  14. "one-api/model"
  15. "one-api/relay"
  16. "one-api/relay/constant"
  17. relayconstant "one-api/relay/constant"
  18. "one-api/service"
  19. "strings"
  20. )
  21. func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
  22. var err *dto.OpenAIErrorWithStatusCode
  23. switch relayMode {
  24. case relayconstant.RelayModeImagesGenerations:
  25. err = relay.ImageHelper(c, relayMode)
  26. case relayconstant.RelayModeAudioSpeech:
  27. fallthrough
  28. case relayconstant.RelayModeAudioTranslation:
  29. fallthrough
  30. case relayconstant.RelayModeAudioTranscription:
  31. err = relay.AudioHelper(c)
  32. case relayconstant.RelayModeRerank:
  33. err = relay.RerankHelper(c, relayMode)
  34. default:
  35. err = relay.TextHelper(c)
  36. }
  37. return err
  38. }
  39. func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
  40. var err *dto.OpenAIErrorWithStatusCode
  41. switch relayMode {
  42. default:
  43. err = relay.TextHelper(c)
  44. }
  45. return err
  46. }
  47. func Relay(c *gin.Context) {
  48. relayMode := constant.Path2RelayMode(c.Request.URL.Path)
  49. requestId := c.GetString(common.RequestIdKey)
  50. group := c.GetString("group")
  51. originalModel := c.GetString("original_model")
  52. var openaiErr *dto.OpenAIErrorWithStatusCode
  53. for i := 0; i <= common.RetryTimes; i++ {
  54. channel, err := getChannel(c, group, originalModel, i)
  55. if err != nil {
  56. common.LogError(c, err.Error())
  57. openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
  58. break
  59. }
  60. openaiErr = relayRequest(c, relayMode, channel)
  61. if openaiErr == nil {
  62. return // 成功处理请求,直接返回
  63. }
  64. go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
  65. if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
  66. break
  67. }
  68. }
  69. useChannel := c.GetStringSlice("use_channel")
  70. if len(useChannel) > 1 {
  71. retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
  72. common.LogInfo(c, retryLogStr)
  73. }
  74. if openaiErr != nil {
  75. if openaiErr.StatusCode == http.StatusTooManyRequests {
  76. openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
  77. }
  78. openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
  79. c.JSON(openaiErr.StatusCode, gin.H{
  80. "error": openaiErr.Error,
  81. })
  82. }
  83. }
  84. var upgrader = websocket.Upgrader{
  85. Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
  86. CheckOrigin: func(r *http.Request) bool {
  87. return true // 允许跨域
  88. },
  89. }
  90. func WssRelay(c *gin.Context) {
  91. // 将 HTTP 连接升级为 WebSocket 连接
  92. ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
  93. defer ws.Close()
  94. if err != nil {
  95. openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
  96. service.WssError(c, ws, openaiErr.Error)
  97. return
  98. }
  99. relayMode := constant.Path2RelayMode(c.Request.URL.Path)
  100. requestId := c.GetString(common.RequestIdKey)
  101. group := c.GetString("group")
  102. //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
  103. originalModel := c.GetString("original_model")
  104. var openaiErr *dto.OpenAIErrorWithStatusCode
  105. for i := 0; i <= common.RetryTimes; i++ {
  106. channel, err := getChannel(c, group, originalModel, i)
  107. if err != nil {
  108. common.LogError(c, err.Error())
  109. openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
  110. break
  111. }
  112. openaiErr = wssRequest(c, ws, relayMode, channel)
  113. if openaiErr == nil {
  114. return // 成功处理请求,直接返回
  115. }
  116. go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
  117. if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
  118. break
  119. }
  120. }
  121. useChannel := c.GetStringSlice("use_channel")
  122. if len(useChannel) > 1 {
  123. retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
  124. common.LogInfo(c, retryLogStr)
  125. }
  126. if openaiErr != nil {
  127. if openaiErr.StatusCode == http.StatusTooManyRequests {
  128. openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
  129. }
  130. openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
  131. service.WssError(c, ws, openaiErr.Error)
  132. }
  133. }
  134. func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
  135. addUsedChannel(c, channel.Id)
  136. requestBody, _ := common.GetRequestBody(c)
  137. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  138. return relayHandler(c, relayMode)
  139. }
  140. func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
  141. addUsedChannel(c, channel.Id)
  142. requestBody, _ := common.GetRequestBody(c)
  143. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  144. return relay.WssHelper(c, ws)
  145. }
  146. func addUsedChannel(c *gin.Context, channelId int) {
  147. useChannel := c.GetStringSlice("use_channel")
  148. useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
  149. c.Set("use_channel", useChannel)
  150. }
  151. func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
  152. if retryCount == 0 {
  153. autoBan := c.GetBool("auto_ban")
  154. autoBanInt := 1
  155. if !autoBan {
  156. autoBanInt = 0
  157. }
  158. return &model.Channel{
  159. Id: c.GetInt("channel_id"),
  160. Type: c.GetInt("channel_type"),
  161. Name: c.GetString("channel_name"),
  162. AutoBan: &autoBanInt,
  163. }, nil
  164. }
  165. channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
  166. if err != nil {
  167. return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
  168. }
  169. middleware.SetupContextForSelectedChannel(c, channel, originalModel)
  170. return channel, nil
  171. }
  172. func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
  173. if openaiErr == nil {
  174. return false
  175. }
  176. if openaiErr.LocalError {
  177. return false
  178. }
  179. if retryTimes <= 0 {
  180. return false
  181. }
  182. if _, ok := c.Get("specific_channel_id"); ok {
  183. return false
  184. }
  185. if openaiErr.StatusCode == http.StatusTooManyRequests {
  186. return true
  187. }
  188. if openaiErr.StatusCode == 307 {
  189. return true
  190. }
  191. if openaiErr.StatusCode/100 == 5 {
  192. // 超时不重试
  193. if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
  194. return false
  195. }
  196. return true
  197. }
  198. if openaiErr.StatusCode == http.StatusBadRequest {
  199. channelType := c.GetInt("channel_type")
  200. if channelType == common.ChannelTypeAnthropic {
  201. return true
  202. }
  203. return false
  204. }
  205. if openaiErr.StatusCode == 408 {
  206. // azure处理超时不重试
  207. return false
  208. }
  209. if openaiErr.StatusCode/100 == 2 {
  210. return false
  211. }
  212. return true
  213. }
  214. func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
  215. // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
  216. // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
  217. common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
  218. if service.ShouldDisableChannel(channelType, err) && autoBan {
  219. service.DisableChannel(channelId, channelName, err.Error.Message)
  220. }
  221. }
  222. func RelayMidjourney(c *gin.Context) {
  223. relayMode := c.GetInt("relay_mode")
  224. var err *dto.MidjourneyResponse
  225. switch relayMode {
  226. case relayconstant.RelayModeMidjourneyNotify:
  227. err = relay.RelayMidjourneyNotify(c)
  228. case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
  229. err = relay.RelayMidjourneyTask(c, relayMode)
  230. case relayconstant.RelayModeMidjourneyTaskImageSeed:
  231. err = relay.RelayMidjourneyTaskImageSeed(c)
  232. case relayconstant.RelayModeSwapFace:
  233. err = relay.RelaySwapFace(c)
  234. default:
  235. err = relay.RelayMidjourneySubmit(c, relayMode)
  236. }
  237. //err = relayMidjourneySubmit(c, relayMode)
  238. log.Println(err)
  239. if err != nil {
  240. statusCode := http.StatusBadRequest
  241. if err.Code == 30 {
  242. err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
  243. statusCode = http.StatusTooManyRequests
  244. }
  245. c.JSON(statusCode, gin.H{
  246. "description": fmt.Sprintf("%s %s", err.Description, err.Result),
  247. "type": "upstream_error",
  248. "code": err.Code,
  249. })
  250. channelId := c.GetInt("channel_id")
  251. common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
  252. }
  253. }
  254. func RelayNotImplemented(c *gin.Context) {
  255. err := dto.OpenAIError{
  256. Message: "API not implemented",
  257. Type: "new_api_error",
  258. Param: "",
  259. Code: "api_not_implemented",
  260. }
  261. c.JSON(http.StatusNotImplemented, gin.H{
  262. "error": err,
  263. })
  264. }
  265. func RelayNotFound(c *gin.Context) {
  266. err := dto.OpenAIError{
  267. Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
  268. Type: "invalid_request_error",
  269. Param: "",
  270. Code: "",
  271. }
  272. c.JSON(http.StatusNotFound, gin.H{
  273. "error": err,
  274. })
  275. }
  276. func RelayTask(c *gin.Context) {
  277. retryTimes := common.RetryTimes
  278. channelId := c.GetInt("channel_id")
  279. relayMode := c.GetInt("relay_mode")
  280. group := c.GetString("group")
  281. originalModel := c.GetString("original_model")
  282. c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
  283. taskErr := taskRelayHandler(c, relayMode)
  284. if taskErr == nil {
  285. retryTimes = 0
  286. }
  287. for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
  288. channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
  289. if err != nil {
  290. common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
  291. break
  292. }
  293. channelId = channel.Id
  294. useChannel := c.GetStringSlice("use_channel")
  295. useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
  296. c.Set("use_channel", useChannel)
  297. common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
  298. middleware.SetupContextForSelectedChannel(c, channel, originalModel)
  299. requestBody, err := common.GetRequestBody(c)
  300. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  301. taskErr = taskRelayHandler(c, relayMode)
  302. }
  303. useChannel := c.GetStringSlice("use_channel")
  304. if len(useChannel) > 1 {
  305. retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
  306. common.LogInfo(c, retryLogStr)
  307. }
  308. if taskErr != nil {
  309. if taskErr.StatusCode == http.StatusTooManyRequests {
  310. taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
  311. }
  312. c.JSON(taskErr.StatusCode, taskErr)
  313. }
  314. }
  315. func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
  316. var err *dto.TaskError
  317. switch relayMode {
  318. case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
  319. err = relay.RelayTaskFetch(c, relayMode)
  320. default:
  321. err = relay.RelayTaskSubmit(c, relayMode)
  322. }
  323. return err
  324. }
  325. func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
  326. if taskErr == nil {
  327. return false
  328. }
  329. if retryTimes <= 0 {
  330. return false
  331. }
  332. if _, ok := c.Get("specific_channel_id"); ok {
  333. return false
  334. }
  335. if taskErr.StatusCode == http.StatusTooManyRequests {
  336. return true
  337. }
  338. if taskErr.StatusCode == 307 {
  339. return true
  340. }
  341. if taskErr.StatusCode/100 == 5 {
  342. // 超时不重试
  343. if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
  344. return false
  345. }
  346. return true
  347. }
  348. if taskErr.StatusCode == http.StatusBadRequest {
  349. return false
  350. }
  351. if taskErr.StatusCode == 408 {
  352. // azure处理超时不重试
  353. return false
  354. }
  355. if taskErr.LocalError {
  356. return false
  357. }
  358. if taskErr.StatusCode/100 == 2 {
  359. return false
  360. }
  361. return true
  362. }