relay-controller.go 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/rand/v2"
  9. "net/http"
  10. "slices"
  11. "strconv"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. "github.com/labring/aiproxy/core/common"
  15. "github.com/labring/aiproxy/core/common/config"
  16. "github.com/labring/aiproxy/core/common/consume"
  17. "github.com/labring/aiproxy/core/common/notify"
  18. "github.com/labring/aiproxy/core/common/reqlimit"
  19. "github.com/labring/aiproxy/core/common/trylock"
  20. "github.com/labring/aiproxy/core/middleware"
  21. "github.com/labring/aiproxy/core/model"
  22. "github.com/labring/aiproxy/core/monitor"
  23. "github.com/labring/aiproxy/core/relay/adaptor"
  24. "github.com/labring/aiproxy/core/relay/adaptor/openai"
  25. "github.com/labring/aiproxy/core/relay/adaptors"
  26. "github.com/labring/aiproxy/core/relay/controller"
  27. "github.com/labring/aiproxy/core/relay/meta"
  28. "github.com/labring/aiproxy/core/relay/mode"
  29. relaymodel "github.com/labring/aiproxy/core/relay/model"
  30. log "github.com/sirupsen/logrus"
  31. )
  32. // https://platform.openai.com/docs/api-reference/chat
  33. type (
  34. RelayHandler func(*gin.Context, *meta.Meta) *controller.HandleResult
  35. GetRequestUsage func(*gin.Context, *model.ModelConfig) (model.Usage, error)
  36. GetRequestPrice func(*gin.Context, *model.ModelConfig) (model.Price, error)
  37. )
  38. type RelayController struct {
  39. GetRequestUsage GetRequestUsage
  40. GetRequestPrice GetRequestPrice
  41. Handler RelayHandler
  42. }
  // ErrInvalidChannelTypeCode is the error code reported when a channel's type
// has no registered adaptor. Errors carrying this code are never retried and
// are treated as a missing-permission condition (see shouldRetry and
// channelHasPermission).
var (
	ErrInvalidChannelTypeCode = "invalid_channel_type"
)
  44. type warpAdaptor struct {
  45. adaptor.Adaptor
  46. }
  // Keys under which the channel+model request/token rates are cached on a
// meta.Meta, so getChannelModelRequestRate can reuse them instead of querying
// reqlimit again.
const (
	MetaChannelModelKeyRPM = "channel_model_rpm" // channel+model RPM
	MetaChannelModelKeyRPS = "channel_model_rps" // channel+model RPS
	MetaChannelModelKeyTPM = "channel_model_tpm" // channel+model TPM
	MetaChannelModelKeyTPS = "channel_model_tps" // channel+model TPS
)
  53. func getChannelModelRequestRate(c *gin.Context, meta *meta.Meta) model.RequestRate {
  54. rate := model.RequestRate{}
  55. if rpm, ok := meta.Get(MetaChannelModelKeyRPM); ok {
  56. rate.RPM, _ = rpm.(int64)
  57. rate.RPS = meta.GetInt64(MetaChannelModelKeyRPS)
  58. } else {
  59. rpm, rps := reqlimit.GetChannelModelRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
  60. rate.RPM = rpm
  61. rate.RPS = rps
  62. updateChannelModelRequestRate(c, meta, rpm, rps)
  63. }
  64. if tpm, ok := meta.Get(MetaChannelModelKeyTPM); ok {
  65. rate.TPM, _ = tpm.(int64)
  66. rate.TPS = meta.GetInt64(MetaChannelModelKeyTPS)
  67. } else {
  68. tpm, tps := reqlimit.GetChannelModelTokensRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
  69. rate.TPM = tpm
  70. rate.TPS = tps
  71. updateChannelModelTokensRequestRate(c, meta, tpm, tps)
  72. }
  73. return rate
  74. }
  75. func updateChannelModelRequestRate(c *gin.Context, meta *meta.Meta, rpm, rps int64) {
  76. meta.Set(MetaChannelModelKeyRPM, rpm)
  77. meta.Set(MetaChannelModelKeyRPS, rps)
  78. log := middleware.GetLogger(c)
  79. log.Data["ch_rpm"] = rpm
  80. log.Data["ch_rps"] = rps
  81. }
  82. func updateChannelModelTokensRequestRate(c *gin.Context, meta *meta.Meta, tpm, tps int64) {
  83. meta.Set(MetaChannelModelKeyTPM, tpm)
  84. meta.Set(MetaChannelModelKeyTPS, tps)
  85. log := middleware.GetLogger(c)
  86. log.Data["ch_tpm"] = tpm
  87. log.Data["ch_tps"] = tps
  88. }
  // DoRequest counts the outgoing request against the channel+model request
// limiter, caches the resulting rates on the meta, and then delegates to the
// wrapped adaptor.
func (w *warpAdaptor) DoRequest(meta *meta.Meta, c *gin.Context, req *http.Request) (*http.Response, error) {
	channelID := strconv.Itoa(meta.Channel.ID)
	inLimit, overLimit, perSecond := reqlimit.PushChannelModelRequest(context.Background(), channelID, meta.OriginModel)
	updateChannelModelRequestRate(c, meta, inLimit+overLimit, perSecond)

	return w.Adaptor.DoRequest(meta, c, req)
}
  // DoResponse delegates to the wrapped adaptor and, when the response reports a
// positive total token count, pushes those tokens into three limiters:
// channel+model, group+model (bounded by the model's TPM config), and
// group+model+token name. The resulting rates are cached on the meta / request
// context. The adaptor's usage and error are returned unchanged.
func (w *warpAdaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage, *relaymodel.ErrorWithStatusCode) {
	usage, relayErr := w.Adaptor.DoResponse(meta, c, resp)
	if usage == nil || usage.TotalTokens <= 0 {
		return usage, relayErr
	}

	tokens := int64(usage.TotalTokens)
	bg := context.Background()

	inLimit, overLimit, perSecond := reqlimit.PushChannelModelTokensRequest(
		bg, strconv.Itoa(meta.Channel.ID), meta.OriginModel, tokens,
	)
	updateChannelModelTokensRequestRate(c, meta, inLimit+overLimit, perSecond)

	inLimit, overLimit, perSecond = reqlimit.PushGroupModelTokensRequest(
		bg, meta.Group.ID, meta.OriginModel, meta.ModelConfig.TPM, tokens,
	)
	middleware.UpdateGroupModelTokensRequest(c, meta.Group, inLimit+overLimit, perSecond)

	inLimit, overLimit, perSecond = reqlimit.PushGroupModelTokennameTokensRequest(
		bg, meta.Group.ID, meta.OriginModel, meta.Token.Name, tokens,
	)
	middleware.UpdateGroupModelTokennameTokensRequest(c, inLimit+overLimit, perSecond)

	return usage, relayErr
}
  // relayHandler is the default RelayHandler. It tags the request logger with
// the meta's fields, resolves the adaptor registered for the channel type, and
// runs the request through it (wrapped in warpAdaptor for rate accounting).
// An unknown channel type yields a 500 error with ErrInvalidChannelTypeCode.
func relayHandler(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
	middleware.SetLogFieldsFromMeta(meta, middleware.GetLogger(c).Data)

	a, found := adaptors.GetAdaptor(meta.Channel.Type)
	if found {
		return controller.Handle(&warpAdaptor{Adaptor: a}, c, meta)
	}

	return &controller.HandleResult{
		Error: openai.ErrorWrapperWithMessage(
			fmt.Sprintf("invalid channel type: %d", meta.Channel.Type),
			ErrInvalidChannelTypeCode,
			http.StatusInternalServerError,
		),
	}
}
  145. func relayController(m mode.Mode) RelayController {
  146. c := RelayController{
  147. Handler: relayHandler,
  148. }
  149. switch m {
  150. case mode.ImagesGenerations:
  151. c.GetRequestPrice = controller.GetImagesRequestPrice
  152. c.GetRequestUsage = controller.GetImagesRequestUsage
  153. case mode.ImagesEdits:
  154. c.GetRequestPrice = controller.GetImagesEditsRequestPrice
  155. c.GetRequestUsage = controller.GetImagesEditsRequestUsage
  156. case mode.AudioSpeech:
  157. c.GetRequestPrice = controller.GetTTSRequestPrice
  158. c.GetRequestUsage = controller.GetTTSRequestUsage
  159. case mode.AudioTranslation, mode.AudioTranscription:
  160. c.GetRequestPrice = controller.GetSTTRequestPrice
  161. c.GetRequestUsage = controller.GetSTTRequestUsage
  162. case mode.ParsePdf:
  163. c.GetRequestPrice = controller.GetPdfRequestPrice
  164. c.GetRequestUsage = controller.GetPdfRequestUsage
  165. case mode.Rerank:
  166. c.GetRequestPrice = controller.GetRerankRequestPrice
  167. c.GetRequestUsage = controller.GetRerankRequestUsage
  168. case mode.Anthropic:
  169. c.GetRequestPrice = controller.GetAnthropicRequestPrice
  170. c.GetRequestUsage = controller.GetAnthropicRequestUsage
  171. case mode.ChatCompletions:
  172. c.GetRequestPrice = controller.GetChatRequestPrice
  173. c.GetRequestUsage = controller.GetChatRequestUsage
  174. case mode.Embeddings:
  175. c.GetRequestPrice = controller.GetEmbedRequestPrice
  176. c.GetRequestUsage = controller.GetEmbedRequestUsage
  177. case mode.Completions:
  178. c.GetRequestPrice = controller.GetCompletionsRequestPrice
  179. c.GetRequestUsage = controller.GetCompletionsRequestUsage
  180. }
  181. return c
  182. }
  // RelayHelper runs one relay attempt through handler and reports the outcome
// to the channel monitor.
//
// On success the attempt is recorded as a successful request and the result is
// returned with retry=false. On failure, the returned retry flag is
// shouldRetry's verdict. Only retryable failures are recorded against the
// channel: monitor.AddRequest gets !channelHasPermission as its last argument.
// Operators are notified when the monitor auto-bans the channel, when its
// error rate crosses the threshold, or when the channel lacks permission.
// Non-retryable failures (e.g. caused by the user's request) are not counted.
//
// Monitor errors are logged with the request-scoped logger so the log entry
// carries the request's fields, like the rest of this file does.
func RelayHelper(c *gin.Context, meta *meta.Meta, handler RelayHandler) (*controller.HandleResult, bool) {
	result := handler(c, meta)

	if result.Error == nil {
		if _, _, err := monitor.AddRequest(
			context.Background(),
			meta.OriginModel,
			int64(meta.Channel.ID),
			false,
			false,
		); err != nil {
			middleware.GetLogger(c).Errorf("add request failed: %+v", err)
		}
		return result, false
	}

	retry := shouldRetry(c, *result.Error)
	if !retry {
		return result, false
	}

	hasPermission := channelHasPermission(*result.Error)
	beyondThreshold, banExecution, err := monitor.AddRequest(
		context.Background(),
		meta.OriginModel,
		int64(meta.Channel.ID),
		true,
		!hasPermission,
	)
	if err != nil {
		middleware.GetLogger(c).Errorf("add request failed: %+v", err)
	}

	// Only the most severe condition is reported.
	switch {
	case banExecution:
		notifyChannelIssue(c, meta, "autoBanned", "Auto Banned", *result.Error)
	case beyondThreshold:
		notifyChannelIssue(c, meta, "beyondThreshold", "Error Rate Beyond Threshold", *result.Error)
	case !hasPermission:
		notifyChannelIssue(c, meta, "channelHasPermission", "No Permission", *result.Error)
	}

	return result, true
}
  221. func notifyChannelIssue(c *gin.Context, meta *meta.Meta, issueType string, titleSuffix string, err relaymodel.ErrorWithStatusCode) {
  222. var notifyFunc func(title string, message string)
  223. lockKey := fmt.Sprintf("%s:%d:%s", issueType, meta.Channel.ID, meta.OriginModel)
  224. switch issueType {
  225. case "beyondThreshold":
  226. notifyFunc = func(title string, message string) {
  227. notify.WarnThrottle(lockKey, time.Minute, title, message)
  228. }
  229. default:
  230. notifyFunc = func(title string, message string) {
  231. notify.ErrorThrottle(lockKey, time.Minute, title, message)
  232. }
  233. }
  234. message := fmt.Sprintf(
  235. "channel: %s (type: %d, type name: %s, id: %d)\nmodel: %s\nmode: %s\nstatus code: %d\ndetail: %s\nrequest id: %s",
  236. meta.Channel.Name,
  237. meta.Channel.Type,
  238. meta.Channel.Type.String(),
  239. meta.Channel.ID,
  240. meta.OriginModel,
  241. meta.Mode,
  242. err.StatusCode,
  243. err.JSONOrEmpty(),
  244. meta.RequestID,
  245. )
  246. if err.StatusCode == http.StatusTooManyRequests {
  247. if !trylock.Lock(lockKey, time.Minute) {
  248. return
  249. }
  250. switch issueType {
  251. case "beyondThreshold":
  252. notifyFunc = notify.Warn
  253. default:
  254. notifyFunc = notify.Error
  255. }
  256. rate := getChannelModelRequestRate(c, meta)
  257. message += fmt.Sprintf("\nrpm: %d\nrps: %d\ntpm: %d\ntps: %d", rate.RPM, rate.RPS, rate.TPM, rate.TPS)
  258. }
  259. notifyFunc(
  260. fmt.Sprintf("%s `%s` %s", meta.Channel.Name, meta.OriginModel, titleSuffix),
  261. message,
  262. )
  263. }
  // filterChannels returns, in input order, the channels that are enabled and
// whose IDs are not listed in ignoreChannel. The input slice is not modified.
func filterChannels(channels []*model.Channel, ignoreChannel ...int64) []*model.Channel {
	// len(channels) bounds the result, so pre-size to avoid regrowing on append.
	filtered := make([]*model.Channel, 0, len(channels))
	for _, channel := range channels {
		if channel.Status != model.ChannelStatusEnabled ||
			slices.Contains(ignoreChannel, int64(channel.ID)) {
			continue
		}
		filtered = append(filtered, channel)
	}
	return filtered
}
  277. var (
  278. ErrChannelsNotFound = errors.New("channels not found")
  279. ErrChannelsExhausted = errors.New("channels exhausted")
  280. )
  281. func GetRandomChannel(mc *model.ModelCaches, availableSet []string, modelName string, errorRates map[int64]float64, ignoreChannel ...int64) (*model.Channel, []*model.Channel, error) {
  282. channelMap := make(map[int]*model.Channel)
  283. for _, set := range availableSet {
  284. for _, channel := range mc.EnabledModel2ChannelsBySet[set][modelName] {
  285. channelMap[channel.ID] = channel
  286. }
  287. }
  288. migratedChannels := make([]*model.Channel, 0, len(channelMap))
  289. for _, channel := range channelMap {
  290. migratedChannels = append(migratedChannels, channel)
  291. }
  292. channel, err := getRandomChannel(migratedChannels, errorRates, ignoreChannel...)
  293. return channel, migratedChannels, err
  294. }
  295. func getPriority(channel *model.Channel, errorRate float64) int32 {
  296. priority := channel.GetPriority()
  297. if errorRate > 1 {
  298. errorRate = 1
  299. } else if errorRate < 0.1 {
  300. errorRate = 0.1
  301. }
  302. return int32(float64(priority) / errorRate)
  303. }
  //nolint:gosec
// getRandomChannel picks a channel from channels using weighted random
// selection, where each channel's weight is getPriority(channel, errorRate).
//
// It returns ErrChannelsNotFound for an empty input and ErrChannelsExhausted
// when filterChannels removes every candidate. A single remaining candidate is
// returned directly.
//
// Weights are summed in int64 so a large number of high-priority channels
// cannot overflow the total. rand.Int64N panics for n <= 0, so a non-positive
// total falls back to a uniform pick. This covers a zero total and also a
// negative one (e.g. if priorities can be negative — unverified here).
func getRandomChannel(channels []*model.Channel, errorRates map[int64]float64, ignoreChannel ...int64) (*model.Channel, error) {
	if len(channels) == 0 {
		return nil, ErrChannelsNotFound
	}

	channels = filterChannels(channels, ignoreChannel...)
	switch len(channels) {
	case 0:
		return nil, ErrChannelsExhausted
	case 1:
		return channels[0], nil
	}

	var totalWeight int64
	weights := make([]int64, len(channels))
	for i, ch := range channels {
		w := int64(getPriority(ch, errorRates[int64(ch.ID)]))
		weights[i] = w
		totalWeight += w
	}

	if totalWeight <= 0 {
		return channels[rand.IntN(len(channels))], nil
	}

	r := rand.Int64N(totalWeight)
	for i, ch := range channels {
		r -= weights[i]
		if r < 0 {
			return ch, nil
		}
	}

	// Only reachable with mixed-sign weights; fall back to a uniform pick.
	return channels[rand.IntN(len(channels))], nil
}
  // getChannelWithFallback picks a channel while skipping ignoreChannelIDs. If
// that leaves no candidate (ErrChannelsExhausted), it picks again with the
// ignore list dropped. Any other error is returned with the candidate list.
func getChannelWithFallback(cache *model.ModelCaches, availableSet []string, modelName string, errorRates map[int64]float64, ignoreChannelIDs ...int64) (*model.Channel, []*model.Channel, error) {
	picked, candidates, err := GetRandomChannel(cache, availableSet, modelName, errorRates, ignoreChannelIDs...)
	switch {
	case err == nil:
		return picked, candidates, nil
	case errors.Is(err, ErrChannelsExhausted):
		return GetRandomChannel(cache, availableSet, modelName, errorRates)
	default:
		return nil, candidates, err
	}
}
  346. func NewRelay(mode mode.Mode) func(c *gin.Context) {
  347. relayController := relayController(mode)
  348. return func(c *gin.Context) {
  349. relay(c, mode, relayController)
  350. }
  351. }
  352. func NewMetaByContext(c *gin.Context, channel *model.Channel, mode mode.Mode, opts ...meta.Option) *meta.Meta {
  353. return middleware.NewMetaByContext(c, channel, mode, opts...)
  354. }
  // relay is the request pipeline shared by every handler built with NewRelay.
//
// It selects an initial channel. When billing is enabled, it also resolves the
// request price and pre-checks the group balance against the estimated usage.
// It then runs the first attempt and either records the final result or hands
// over to retryLoop when the failure is retryable.
func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
	log := middleware.GetLogger(c)
	requestModel := middleware.GetRequestModel(c)
	mc := middleware.GetModelConfig(c)

	// Get initial channel
	initialChannel, err := getInitialChannel(c, requestModel, log)
	if err != nil || initialChannel == nil || initialChannel.channel == nil {
		// The client only sees a generic 503, so keep the real cause in the logs.
		if err != nil {
			log.Errorf("get %s initial channel failed: %+v", requestModel, err)
		}
		middleware.AbortLogWithMessage(c,
			http.StatusServiceUnavailable,
			"the upstream load is saturated, please try again later",
		)
		return
	}

	billingEnabled := config.GetBillingEnabled()

	price := model.Price{}
	if billingEnabled && relayController.GetRequestPrice != nil {
		price, err = relayController.GetRequestPrice(c, mc)
		if err != nil {
			middleware.AbortLogWithMessage(c,
				http.StatusInternalServerError,
				"get request price failed: "+err.Error(),
			)
			return
		}
	}

	meta := NewMetaByContext(c, initialChannel.channel, mode)

	if billingEnabled && relayController.GetRequestUsage != nil {
		requestUsage, err := relayController.GetRequestUsage(c, mc)
		if err != nil {
			middleware.AbortLogWithMessage(c,
				http.StatusInternalServerError,
				"get request usage failed: "+err.Error(),
			)
			return
		}

		// Reject up front when the estimated cost already exceeds the balance.
		gbc := middleware.GetGroupBalanceConsumerFromContext(c)
		if !gbc.CheckBalance(consume.CalculateAmount(http.StatusOK, requestUsage, price)) {
			middleware.AbortLogWithMessage(c,
				http.StatusForbidden,
				fmt.Sprintf("group (%s) balance not enough", gbc.Group),
				&middleware.ErrorField{
					Code: middleware.GroupBalanceNotEnough,
				},
			)
			return
		}

		meta.RequestUsage = requestUsage
	}

	// First attempt
	result, retry := RelayHelper(c, meta, relayController.Handler)

	// A positive per-model retry count overrides the global setting.
	retryTimes := int(config.GetRetryTimes())
	if mc.RetryTimes > 0 {
		retryTimes = int(mc.RetryTimes)
	}

	if handleRelayResult(c, result.Error, retry, retryTimes) {
		recordResult(
			c,
			meta,
			price,
			result,
			0,
			true,
			middleware.GetRequestUser(c),
			middleware.GetRequestMetadata(c),
		)
		return
	}

	// Setup retry state
	retryState := initRetryState(
		retryTimes,
		initialChannel,
		meta,
		result,
		price,
	)

	// Retry loop
	retryLoop(c, mode, retryState, relayController.Handler, log)
}
  // recordResult records the consumption for one attempt's result.
//
// The status code and error body come from result.Error (200 and "" on
// success). Request/response bodies are attached only for non-200 results,
// unless config asks to save details for every request. retryTimes is the
// number of retries behind this result. downstreamResult marks whether this is
// the result returned to the client. A positive amount is also added to the
// request logger's fields.
func recordResult(
	c *gin.Context,
	meta *meta.Meta,
	price model.Price,
	result *controller.HandleResult,
	retryTimes int,
	downstreamResult bool,
	user string,
	metadata map[string]string,
) {
	code, content := http.StatusOK, ""
	if relayErr := result.Error; relayErr != nil {
		code = relayErr.StatusCode
		content = relayErr.JSONOrEmpty()
	}

	firstByteAt := result.Detail.FirstByteAt

	var detail *model.RequestDetail
	if code != http.StatusOK || config.GetSaveAllLogDetail() {
		detail = &model.RequestDetail{
			RequestBody:  result.Detail.RequestBody,
			ResponseBody: result.Detail.ResponseBody,
		}
	}

	gbc := middleware.GetGroupBalanceConsumerFromContext(c)

	if amount := consume.CalculateAmount(code, result.Usage, price); amount > 0 {
		middleware.GetLogger(c).Data["amount"] = strconv.FormatFloat(amount, 'f', -1, 64)
	}

	consume.AsyncConsume(
		gbc.Consumer,
		code,
		firstByteAt,
		meta,
		result.Usage,
		price,
		content,
		c.ClientIP(),
		retryTimes,
		detail,
		downstreamResult,
		user,
		metadata,
		getChannelModelRequestRate(c, meta),
		middleware.GetGroupModelTokenRequestRate(c),
	)
}
  488. type retryState struct {
  489. retryTimes int
  490. lastHasPermissionChannel *model.Channel
  491. ignoreChannelIDs []int64
  492. errorRates map[int64]float64
  493. exhausted bool
  494. meta *meta.Meta
  495. price model.Price
  496. requestUsage model.Usage
  497. result *controller.HandleResult
  498. migratedChannels []*model.Channel
  499. }
  500. type initialChannel struct {
  501. channel *model.Channel
  502. designatedChannel bool
  503. ignoreChannelIDs []int64
  504. errorRates map[int64]float64
  505. migratedChannels []*model.Channel
  506. }
  507. func getInitialChannel(c *gin.Context, modelName string, log *log.Entry) (*initialChannel, error) {
  508. if channel := middleware.GetChannel(c); channel != nil {
  509. log.Data["designated_channel"] = "true"
  510. return &initialChannel{channel: channel, designatedChannel: true}, nil
  511. }
  512. mc := middleware.GetModelCaches(c)
  513. ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
  514. if err != nil {
  515. log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
  516. }
  517. log.Debugf("%s model banned channels: %+v", modelName, ids)
  518. errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
  519. if err != nil {
  520. log.Errorf("get channel model error rates failed: %+v", err)
  521. }
  522. group := middleware.GetGroup(c)
  523. availableSet := group.GetAvailableSets()
  524. channel, migratedChannels, err := getChannelWithFallback(mc, availableSet, modelName, errorRates, ids...)
  525. if err != nil {
  526. return nil, err
  527. }
  528. return &initialChannel{
  529. channel: channel,
  530. ignoreChannelIDs: ids,
  531. errorRates: errorRates,
  532. migratedChannels: migratedChannels,
  533. }, nil
  534. }
  535. func handleRelayResult(c *gin.Context, bizErr *relaymodel.ErrorWithStatusCode, retry bool, retryTimes int) (done bool) {
  536. if bizErr == nil {
  537. return true
  538. }
  539. if !retry ||
  540. retryTimes == 0 ||
  541. c.Request.Context().Err() != nil {
  542. bizErr.Error.Message = middleware.MessageWithRequestID(c, bizErr.Error.Message)
  543. c.JSON(bizErr.StatusCode, bizErr)
  544. return true
  545. }
  546. return false
  547. }
  548. func initRetryState(retryTimes int, channel *initialChannel, meta *meta.Meta, result *controller.HandleResult, price model.Price) *retryState {
  549. state := &retryState{
  550. retryTimes: retryTimes,
  551. ignoreChannelIDs: channel.ignoreChannelIDs,
  552. errorRates: channel.errorRates,
  553. meta: meta,
  554. result: result,
  555. price: price,
  556. requestUsage: meta.RequestUsage,
  557. migratedChannels: channel.migratedChannels,
  558. }
  559. if channel.designatedChannel {
  560. state.exhausted = true
  561. }
  562. if !channelHasPermission(*result.Error) {
  563. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(channel.channel.ID))
  564. } else {
  565. state.lastHasPermissionChannel = channel.channel
  566. }
  567. return state
  568. }
  // retryLoop re-runs a failed request on other channels until one attempt
// concludes it, the retry budget is spent, or no channel is left.
//
// Every attempt's consumption is recorded exactly once: superseded attempts
// with downstreamResult=false, the final one with true. If the final attempt
// failed, its error (tagged with the request ID) is written to the client.
func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayController RelayHandler, log *log.Entry) {
	// record consumes the attempt currently held in state; retries is the number
	// of retries behind it and final marks it as the downstream result.
	record := func(retries int, final bool) {
		recordResult(
			c,
			state.meta,
			state.price,
			state.result,
			retries,
			final,
			middleware.GetRequestUser(c),
			middleware.GetRequestMetadata(c),
		)
	}

	// A plain counter is used instead of `range state.retryTimes`, because
	// handleRetryResult may raise state.retryTimes while the loop runs.
	for attempt := 0; ; attempt++ {
		lastStatusCode := state.result.Error.StatusCode
		lastChannelID := state.meta.Channel.ID

		newChannel, err := getRetryChannel(state)
		if err == nil {
			err = prepareRetry(c)
		}
		if err != nil {
			if !errors.Is(err, ErrChannelsExhausted) {
				log.Errorf("prepare retry failed: %+v", err)
			}
			// No further attempt: the pending result becomes the final one.
			if state.meta != nil && state.result != nil {
				record(attempt, true)
			}
			break
		}

		// The pending result is superseded by this retry.
		if state.meta != nil && state.result != nil {
			record(attempt, false)
			state.meta = nil
			state.result = nil
		}

		log.Data["retry"] = strconv.Itoa(attempt + 1)
		log.Warnf("using channel %s (type: %d, id: %d) to retry (remain times %d)",
			newChannel.Name,
			newChannel.Type,
			newChannel.ID,
			state.retryTimes-attempt,
		)

		// Back off only when hitting the same channel again after 429/503.
		if shouldDelay(lastStatusCode, lastChannelID, newChannel.ID) {
			relayDelay()
		}

		state.meta = NewMetaByContext(
			c,
			newChannel,
			mode,
			meta.WithRequestUsage(state.requestUsage),
			meta.WithRetryAt(time.Now()),
		)

		var retry bool
		state.result, retry = RelayHelper(c, state.meta, relayController)

		if handleRetryResult(c, retry, newChannel, state) || attempt == state.retryTimes-1 {
			record(attempt+1, true)
			break
		}
	}

	if failure := state.result.Error; failure != nil {
		failure.Error.Message = middleware.MessageWithRequestID(c, failure.Error.Message)
		c.JSON(failure.StatusCode, failure)
	}
}
  // getRetryChannel chooses the channel for the next retry.
//
// Until selection is exhausted, a weighted random candidate not in the ignore
// list is drawn. When no candidate remains, the state switches to exhausted
// mode and falls back to lastHasPermissionChannel. Without such a fallback the
// selection error (ErrChannelsExhausted once exhausted) is returned.
func getRetryChannel(state *retryState) (*model.Channel, error) {
	if !state.exhausted {
		next, err := getRandomChannel(state.migratedChannels, state.errorRates, state.ignoreChannelIDs...)
		if err == nil {
			return next, nil
		}
		if !errors.Is(err, ErrChannelsExhausted) || state.lastHasPermissionChannel == nil {
			return nil, err
		}
		state.exhausted = true
		return state.lastHasPermissionChannel, nil
	}

	if state.lastHasPermissionChannel == nil {
		return nil, ErrChannelsExhausted
	}
	return state.lastHasPermissionChannel, nil
}
  // prepareRetry rewinds the request body so the next attempt can read it again.
// It re-reads the body through common.GetRequestBody and installs a fresh
// reader over those bytes as c.Request.Body.
func prepareRetry(c *gin.Context) error {
	body, err := common.GetRequestBody(c.Request)
	if err != nil {
		return fmt.Errorf("get request body failed in prepare retry: %w", err)
	}

	c.Request.Body = io.NopCloser(bytes.NewReader(body))
	return nil
}
  679. func handleRetryResult(ctx *gin.Context, retry bool, newChannel *model.Channel, state *retryState) (done bool) {
  680. if ctx.Request.Context().Err() != nil {
  681. return true
  682. }
  683. if !retry || state.result.Error == nil {
  684. return true
  685. }
  686. hasPermission := channelHasPermission(*state.result.Error)
  687. if state.exhausted {
  688. if !hasPermission {
  689. return true
  690. }
  691. } else {
  692. if !hasPermission {
  693. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(newChannel.ID))
  694. state.retryTimes++
  695. } else {
  696. state.lastHasPermissionChannel = newChannel
  697. }
  698. }
  699. return false
  700. }
  701. var channelNoRetryStatusCodesMap = map[int]struct{}{
  702. http.StatusBadRequest: {},
  703. http.StatusRequestEntityTooLarge: {},
  704. http.StatusUnprocessableEntity: {},
  705. http.StatusUnavailableForLegalReasons: {},
  706. }
  707. // 仅当是channel错误时,才需要记录,用户请求参数错误时,不需要记录
  708. func shouldRetry(_ *gin.Context, relayErr relaymodel.ErrorWithStatusCode) bool {
  709. if relayErr.Error.Code == ErrInvalidChannelTypeCode {
  710. return false
  711. }
  712. _, ok := channelNoRetryStatusCodesMap[relayErr.StatusCode]
  713. return !ok
  714. }
  715. var channelNoPermissionStatusCodesMap = map[int]struct{}{
  716. http.StatusUnauthorized: {},
  717. http.StatusPaymentRequired: {},
  718. http.StatusForbidden: {},
  719. http.StatusNotFound: {},
  720. }
  721. func channelHasPermission(relayErr relaymodel.ErrorWithStatusCode) bool {
  722. if relayErr.Error.Code == ErrInvalidChannelTypeCode {
  723. return false
  724. }
  725. _, ok := channelNoPermissionStatusCodesMap[relayErr.StatusCode]
  726. return !ok
  727. }
  // shouldDelay reports whether to back off before the next attempt. It does so
// only when the same channel is retried right after a 429 (Too Many Requests)
// or 503 (Service Unavailable) response.
func shouldDelay(statusCode int, lastChannelID, newChannelID int) bool {
	if lastChannelID == newChannelID {
		switch statusCode {
		case http.StatusTooManyRequests, http.StatusServiceUnavailable:
			return true
		}
	}
	return false
}
  // relayDelay blocks for a random duration in [1s, 2s) before a retry.
func relayDelay() {
	//nolint:gosec
	jitter := time.Duration(rand.Float64() * float64(time.Second))
	time.Sleep(time.Second + jitter)
}
  // RelayNotImplemented responds with 501 and an "api_not_implemented" error
// for endpoints that are routed but not supported.
func RelayNotImplemented(c *gin.Context) {
	payload := gin.H{
		"error": &relaymodel.Error{
			Message: "API not implemented",
			Type:    middleware.ErrorTypeAIPROXY,
			Code:    "api_not_implemented",
		},
	}
	c.JSON(http.StatusNotImplemented, payload)
}