relay-controller.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/rand/v2"
  9. "net/http"
  10. "slices"
  11. "strconv"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. "github.com/labring/aiproxy/core/common"
  15. "github.com/labring/aiproxy/core/common/config"
  16. "github.com/labring/aiproxy/core/common/consume"
  17. "github.com/labring/aiproxy/core/common/notify"
  18. "github.com/labring/aiproxy/core/common/trylock"
  19. "github.com/labring/aiproxy/core/middleware"
  20. "github.com/labring/aiproxy/core/model"
  21. "github.com/labring/aiproxy/core/monitor"
  22. "github.com/labring/aiproxy/core/relay/controller"
  23. "github.com/labring/aiproxy/core/relay/meta"
  24. "github.com/labring/aiproxy/core/relay/mode"
  25. relaymodel "github.com/labring/aiproxy/core/relay/model"
  26. log "github.com/sirupsen/logrus"
  27. )
// https://platform.openai.com/docs/api-reference/chat
type (
	// RelayHandler performs one relay attempt for the given metadata on the
	// request held in the gin context and returns the attempt's outcome.
	RelayHandler func(*meta.Meta, *gin.Context) *controller.HandleResult
	// GetRequestUsage estimates a request's usage before it is relayed; relay
	// uses the estimate for the pre-flight group balance check.
	GetRequestUsage func(*gin.Context, *model.ModelConfig) (model.Usage, error)
	// GetRequestPrice resolves the price that applies to a request; relay uses
	// it for the balance check and for consumption accounting.
	GetRequestPrice func(*gin.Context, *model.ModelConfig) (model.Price, error)
)
// RelayController bundles the per-mode hooks used by relay.
type RelayController struct {
	// GetRequestUsage may be nil; relay then skips the pre-flight balance check.
	GetRequestUsage GetRequestUsage
	// GetRequestPrice may be nil; relay then uses a zero model.Price.
	GetRequestPrice GetRequestPrice
	// Handler performs each relay attempt (first try and retries).
	Handler RelayHandler
}
  39. func relayHandler(meta *meta.Meta, c *gin.Context) *controller.HandleResult {
  40. log := middleware.GetLogger(c)
  41. middleware.SetLogFieldsFromMeta(meta, log.Data)
  42. return controller.Handle(meta, c)
  43. }
  44. func relayController(m mode.Mode) RelayController {
  45. c := RelayController{
  46. Handler: relayHandler,
  47. }
  48. switch m {
  49. case mode.ImagesGenerations:
  50. c.GetRequestPrice = controller.GetImagesRequestPrice
  51. c.GetRequestUsage = controller.GetImagesRequestUsage
  52. case mode.ImagesEdits:
  53. c.GetRequestPrice = controller.GetImagesEditsRequestPrice
  54. c.GetRequestUsage = controller.GetImagesEditsRequestUsage
  55. case mode.AudioSpeech:
  56. c.GetRequestPrice = controller.GetTTSRequestPrice
  57. c.GetRequestUsage = controller.GetTTSRequestUsage
  58. case mode.AudioTranslation, mode.AudioTranscription:
  59. c.GetRequestPrice = controller.GetSTTRequestPrice
  60. c.GetRequestUsage = controller.GetSTTRequestUsage
  61. case mode.ParsePdf:
  62. c.GetRequestPrice = controller.GetPdfRequestPrice
  63. c.GetRequestUsage = controller.GetPdfRequestUsage
  64. case mode.Rerank:
  65. c.GetRequestPrice = controller.GetRerankRequestPrice
  66. c.GetRequestUsage = controller.GetRerankRequestUsage
  67. case mode.Anthropic:
  68. c.GetRequestPrice = controller.GetAnthropicRequestPrice
  69. c.GetRequestUsage = controller.GetAnthropicRequestUsage
  70. case mode.ChatCompletions:
  71. c.GetRequestPrice = controller.GetChatRequestPrice
  72. c.GetRequestUsage = controller.GetChatRequestUsage
  73. case mode.Embeddings:
  74. c.GetRequestPrice = controller.GetEmbedRequestPrice
  75. c.GetRequestUsage = controller.GetEmbedRequestUsage
  76. case mode.Completions:
  77. c.GetRequestPrice = controller.GetCompletionsRequestPrice
  78. c.GetRequestUsage = controller.GetCompletionsRequestUsage
  79. }
  80. return c
  81. }
  82. func RelayHelper(meta *meta.Meta, c *gin.Context, handel RelayHandler) (*controller.HandleResult, bool) {
  83. result := handel(meta, c)
  84. if result.Error == nil {
  85. if _, _, err := monitor.AddRequest(
  86. context.Background(),
  87. meta.OriginModel,
  88. int64(meta.Channel.ID),
  89. false,
  90. false,
  91. ); err != nil {
  92. log.Errorf("add request failed: %+v", err)
  93. }
  94. return result, false
  95. }
  96. shouldRetry := shouldRetry(c, *result.Error)
  97. if shouldRetry {
  98. hasPermission := channelHasPermission(*result.Error)
  99. beyondThreshold, banExecution, err := monitor.AddRequest(
  100. context.Background(),
  101. meta.OriginModel,
  102. int64(meta.Channel.ID),
  103. true,
  104. !hasPermission,
  105. )
  106. if err != nil {
  107. log.Errorf("add request failed: %+v", err)
  108. }
  109. switch {
  110. case banExecution:
  111. notifyChannelIssue(meta, "autoBanned", "Auto Banned", *result.Error)
  112. case beyondThreshold:
  113. notifyChannelIssue(meta, "beyondThreshold", "Error Rate Beyond Threshold", *result.Error)
  114. case !hasPermission:
  115. notifyChannelIssue(meta, "channelHasPermission", "No Permission", *result.Error)
  116. }
  117. }
  118. return result, shouldRetry
  119. }
  120. func notifyChannelIssue(meta *meta.Meta, issueType string, titleSuffix string, err relaymodel.ErrorWithStatusCode) {
  121. var notifyFunc func(title string, message string)
  122. lockKey := fmt.Sprintf("%s:%d:%s", issueType, meta.Channel.ID, meta.OriginModel)
  123. switch issueType {
  124. case "beyondThreshold":
  125. notifyFunc = func(title string, message string) {
  126. notify.WarnThrottle(lockKey, time.Minute, title, message)
  127. }
  128. default:
  129. notifyFunc = func(title string, message string) {
  130. notify.ErrorThrottle(lockKey, time.Minute, title, message)
  131. }
  132. }
  133. message := fmt.Sprintf(
  134. "channel: %s (type: %d, type name: %s, id: %d)\nmodel: %s\nmode: %s\nstatus code: %d\ndetail: %s\nrequest id: %s",
  135. meta.Channel.Name,
  136. meta.Channel.Type,
  137. meta.Channel.Type.String(),
  138. meta.Channel.ID,
  139. meta.OriginModel,
  140. meta.Mode,
  141. err.StatusCode,
  142. err.JSONOrEmpty(),
  143. meta.RequestID,
  144. )
  145. if err.StatusCode == http.StatusTooManyRequests {
  146. if !trylock.Lock(lockKey, time.Minute) {
  147. return
  148. }
  149. switch issueType {
  150. case "beyondThreshold":
  151. notifyFunc = notify.Warn
  152. default:
  153. notifyFunc = notify.Error
  154. }
  155. now := time.Now()
  156. group := "*"
  157. rpm, rpmErr := model.GetRPM(group, now, "", meta.OriginModel, meta.Channel.ID)
  158. tpm, tpmErr := model.GetTPM(group, now, "", meta.OriginModel, meta.Channel.ID)
  159. if rpmErr != nil {
  160. message += fmt.Sprintf("\nrpm: %v", rpmErr)
  161. } else {
  162. message += fmt.Sprintf("\nrpm: %d", rpm)
  163. }
  164. if tpmErr != nil {
  165. message += fmt.Sprintf("\ntpm: %v", tpmErr)
  166. } else {
  167. message += fmt.Sprintf("\ntpm: %d", tpm)
  168. }
  169. }
  170. notifyFunc(
  171. fmt.Sprintf("%s `%s` %s", meta.Channel.Name, meta.OriginModel, titleSuffix),
  172. message,
  173. )
  174. }
  175. func filterChannels(channels []*model.Channel, ignoreChannel ...int64) []*model.Channel {
  176. filtered := make([]*model.Channel, 0)
  177. for _, channel := range channels {
  178. if channel.Status != model.ChannelStatusEnabled {
  179. continue
  180. }
  181. if slices.Contains(ignoreChannel, int64(channel.ID)) {
  182. continue
  183. }
  184. filtered = append(filtered, channel)
  185. }
  186. return filtered
  187. }
var (
	// ErrChannelsNotFound is returned by channel selection when the candidate
	// channel list is empty.
	ErrChannelsNotFound = errors.New("channels not found")
	// ErrChannelsExhausted is returned when candidates exist but every one is
	// disabled or excluded by the ignore list.
	ErrChannelsExhausted = errors.New("channels exhausted")
)
  192. func GetRandomChannel(mc *model.ModelCaches, availableSet []string, modelName string, errorRates map[int64]float64, ignoreChannel ...int64) (*model.Channel, []*model.Channel, error) {
  193. channelMap := make(map[int]*model.Channel)
  194. for _, set := range availableSet {
  195. for _, channel := range mc.EnabledModel2ChannelsBySet[set][modelName] {
  196. channelMap[channel.ID] = channel
  197. }
  198. }
  199. migratedChannels := make([]*model.Channel, 0, len(channelMap))
  200. for _, channel := range channelMap {
  201. migratedChannels = append(migratedChannels, channel)
  202. }
  203. channel, err := getRandomChannel(migratedChannels, errorRates, ignoreChannel...)
  204. return channel, migratedChannels, err
  205. }
  206. func getPriority(channel *model.Channel, errorRate float64) int32 {
  207. priority := channel.GetPriority()
  208. if errorRate > 1 {
  209. errorRate = 1
  210. } else if errorRate < 0.1 {
  211. errorRate = 0.1
  212. }
  213. return int32(float64(priority) / errorRate)
  214. }
  215. //nolint:gosec
  216. func getRandomChannel(channels []*model.Channel, errorRates map[int64]float64, ignoreChannel ...int64) (*model.Channel, error) {
  217. if len(channels) == 0 {
  218. return nil, ErrChannelsNotFound
  219. }
  220. channels = filterChannels(channels, ignoreChannel...)
  221. if len(channels) == 0 {
  222. return nil, ErrChannelsExhausted
  223. }
  224. if len(channels) == 1 {
  225. return channels[0], nil
  226. }
  227. var totalWeight int32
  228. cachedPrioritys := make([]int32, len(channels))
  229. for i, ch := range channels {
  230. priority := getPriority(ch, errorRates[int64(ch.ID)])
  231. totalWeight += priority
  232. cachedPrioritys[i] = priority
  233. }
  234. if totalWeight == 0 {
  235. return channels[rand.IntN(len(channels))], nil
  236. }
  237. r := rand.Int32N(totalWeight)
  238. for i, ch := range channels {
  239. r -= cachedPrioritys[i]
  240. if r < 0 {
  241. return ch, nil
  242. }
  243. }
  244. return channels[rand.IntN(len(channels))], nil
  245. }
  246. func getChannelWithFallback(cache *model.ModelCaches, availableSet []string, modelName string, errorRates map[int64]float64, ignoreChannelIDs ...int64) (*model.Channel, []*model.Channel, error) {
  247. channel, migratedChannels, err := GetRandomChannel(cache, availableSet, modelName, errorRates, ignoreChannelIDs...)
  248. if err == nil {
  249. return channel, migratedChannels, nil
  250. }
  251. if !errors.Is(err, ErrChannelsExhausted) {
  252. return nil, migratedChannels, err
  253. }
  254. channel, migratedChannels, err = GetRandomChannel(cache, availableSet, modelName, errorRates)
  255. return channel, migratedChannels, err
  256. }
  257. func NewRelay(mode mode.Mode) func(c *gin.Context) {
  258. relayController := relayController(mode)
  259. return func(c *gin.Context) {
  260. relay(c, mode, relayController)
  261. }
  262. }
  263. func NewMetaByContext(c *gin.Context, channel *model.Channel, mode mode.Mode, opts ...meta.Option) *meta.Meta {
  264. return middleware.NewMetaByContext(c, channel, mode, opts...)
  265. }
  266. func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
  267. log := middleware.GetLogger(c)
  268. requestModel := middleware.GetRequestModel(c)
  269. mc := middleware.GetModelConfig(c)
  270. // Get initial channel
  271. initialChannel, err := getInitialChannel(c, requestModel, log)
  272. if err != nil || initialChannel == nil || initialChannel.channel == nil {
  273. middleware.AbortLogWithMessage(c,
  274. http.StatusServiceUnavailable,
  275. "the upstream load is saturated, please try again later",
  276. )
  277. return
  278. }
  279. billingEnabled := config.GetBillingEnabled()
  280. price := model.Price{}
  281. if billingEnabled && relayController.GetRequestPrice != nil {
  282. price, err = relayController.GetRequestPrice(c, mc)
  283. if err != nil {
  284. middleware.AbortLogWithMessage(c,
  285. http.StatusInternalServerError,
  286. "get request price failed: "+err.Error(),
  287. )
  288. return
  289. }
  290. }
  291. meta := NewMetaByContext(c, initialChannel.channel, mode)
  292. if billingEnabled && relayController.GetRequestUsage != nil {
  293. requestUsage, err := relayController.GetRequestUsage(c, mc)
  294. if err != nil {
  295. middleware.AbortLogWithMessage(c,
  296. http.StatusInternalServerError,
  297. "get request usage failed: "+err.Error(),
  298. )
  299. return
  300. }
  301. gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  302. if !gbc.CheckBalance(consume.CalculateAmount(http.StatusOK, requestUsage, price)) {
  303. middleware.AbortLogWithMessage(c,
  304. http.StatusForbidden,
  305. fmt.Sprintf("group (%s) balance not enough", gbc.Group),
  306. &middleware.ErrorField{
  307. Code: middleware.GroupBalanceNotEnough,
  308. },
  309. )
  310. return
  311. }
  312. meta.RequestUsage = requestUsage
  313. }
  314. // First attempt
  315. result, retry := RelayHelper(meta, c, relayController.Handler)
  316. retryTimes := int(config.GetRetryTimes())
  317. if mc.RetryTimes > 0 {
  318. retryTimes = int(mc.RetryTimes)
  319. }
  320. if handleRelayResult(c, result.Error, retry, retryTimes) {
  321. recordResult(c, meta, price, result, 0, true)
  322. return
  323. }
  324. // Setup retry state
  325. retryState := initRetryState(
  326. retryTimes,
  327. initialChannel,
  328. meta,
  329. result,
  330. price,
  331. )
  332. // Retry loop
  333. retryLoop(c, mode, retryState, relayController.Handler, log)
  334. }
  335. // recordResult records the consumption for the final result
  336. func recordResult(c *gin.Context, meta *meta.Meta, price model.Price, result *controller.HandleResult, retryTimes int, downstreamResult bool) {
  337. code := http.StatusOK
  338. content := ""
  339. if result.Error != nil {
  340. code = result.Error.StatusCode
  341. content = result.Error.JSONOrEmpty()
  342. }
  343. var detail *model.RequestDetail
  344. firstByteAt := result.Detail.FirstByteAt
  345. if code == http.StatusOK && !config.GetSaveAllLogDetail() {
  346. detail = nil
  347. } else {
  348. detail = &model.RequestDetail{
  349. RequestBody: result.Detail.RequestBody,
  350. ResponseBody: result.Detail.ResponseBody,
  351. }
  352. }
  353. gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  354. amount := consume.CalculateAmount(
  355. code,
  356. result.Usage,
  357. price,
  358. )
  359. if amount > 0 {
  360. log := middleware.GetLogger(c)
  361. log.Data["amount"] = strconv.FormatFloat(amount, 'f', -1, 64)
  362. }
  363. consume.AsyncConsume(
  364. gbc.Consumer,
  365. code,
  366. firstByteAt,
  367. meta,
  368. result.Usage,
  369. price,
  370. content,
  371. c.ClientIP(),
  372. retryTimes,
  373. detail,
  374. downstreamResult,
  375. )
  376. }
// retryState is the mutable bookkeeping shared by retryLoop and its helpers
// across retry attempts.
type retryState struct {
	// retryTimes is the retry budget. handleRetryResult increments it when a
	// channel fails with a permission error, so that failure is not counted.
	retryTimes int
	// lastHasPermissionChannel is the most recent channel whose failure was
	// not a permission error. It is reused once selection is exhausted.
	lastHasPermissionChannel *model.Channel
	// ignoreChannelIDs are excluded from random selection: banned channels
	// plus channels that failed with a permission error.
	ignoreChannelIDs []int64
	// errorRates weights random selection (see getPriority).
	errorRates map[int64]float64
	// exhausted is set once random selection can yield no new channel, or
	// from the start for a designated channel. From then on only
	// lastHasPermissionChannel is retried.
	exhausted bool
	// meta describes the latest attempt. retryLoop clears it (and result)
	// after recording a superseded attempt.
	meta         *meta.Meta
	price        model.Price
	requestUsage model.Usage
	// result is the latest attempt's outcome.
	result *controller.HandleResult
	// migratedChannels is the candidate pool for random selection.
	migratedChannels []*model.Channel
}
// initialChannel is the outcome of getInitialChannel. It carries the channel
// for the first attempt plus the selection inputs reused by retries.
type initialChannel struct {
	channel *model.Channel
	// designatedChannel is true when the request was pinned to a channel. In
	// that case the other fields are left empty.
	designatedChannel bool
	// ignoreChannelIDs are the model's auto-banned channel IDs.
	ignoreChannelIDs []int64
	errorRates       map[int64]float64
	// migratedChannels is the full candidate pool the channel was drawn from.
	migratedChannels []*model.Channel
}
  396. func getInitialChannel(c *gin.Context, modelName string, log *log.Entry) (*initialChannel, error) {
  397. if channel := middleware.GetChannel(c); channel != nil {
  398. log.Data["designated_channel"] = "true"
  399. return &initialChannel{channel: channel, designatedChannel: true}, nil
  400. }
  401. mc := middleware.GetModelCaches(c)
  402. ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
  403. if err != nil {
  404. log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
  405. }
  406. log.Debugf("%s model banned channels: %+v", modelName, ids)
  407. errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
  408. if err != nil {
  409. log.Errorf("get channel model error rates failed: %+v", err)
  410. }
  411. group := middleware.GetGroup(c)
  412. availableSet := group.GetAvailableSets()
  413. channel, migratedChannels, err := getChannelWithFallback(mc, availableSet, modelName, errorRates, ids...)
  414. if err != nil {
  415. return nil, err
  416. }
  417. return &initialChannel{
  418. channel: channel,
  419. ignoreChannelIDs: ids,
  420. errorRates: errorRates,
  421. migratedChannels: migratedChannels,
  422. }, nil
  423. }
  424. func handleRelayResult(c *gin.Context, bizErr *relaymodel.ErrorWithStatusCode, retry bool, retryTimes int) (done bool) {
  425. if bizErr == nil {
  426. return true
  427. }
  428. if !retry ||
  429. retryTimes == 0 ||
  430. c.Request.Context().Err() != nil {
  431. bizErr.Error.Message = middleware.MessageWithRequestID(c, bizErr.Error.Message)
  432. c.JSON(bizErr.StatusCode, bizErr)
  433. return true
  434. }
  435. return false
  436. }
  437. func initRetryState(retryTimes int, channel *initialChannel, meta *meta.Meta, result *controller.HandleResult, price model.Price) *retryState {
  438. state := &retryState{
  439. retryTimes: retryTimes,
  440. ignoreChannelIDs: channel.ignoreChannelIDs,
  441. errorRates: channel.errorRates,
  442. meta: meta,
  443. result: result,
  444. price: price,
  445. requestUsage: meta.RequestUsage,
  446. migratedChannels: channel.migratedChannels,
  447. }
  448. if channel.designatedChannel {
  449. state.exhausted = true
  450. }
  451. if !channelHasPermission(*result.Error) {
  452. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(channel.channel.ID))
  453. } else {
  454. state.lastHasPermissionChannel = channel.channel
  455. }
  456. return state
  457. }
// retryLoop re-attempts a failed request. It stops when any of these happens:
//   - an attempt succeeds or fails with a non-retryable error;
//   - the client's request context is done;
//   - no channel is left to try;
//   - the retry budget in state.retryTimes is spent.
//
// Each superseded attempt is recorded with recordResult(..., false) and the
// final one with recordResult(..., true). If the final result is an error,
// its message is decorated with the request ID and it is written to the
// client as JSON.
//
// On entry state.meta and state.result must describe the failed first attempt,
// and state.result.Error must be non-nil.
func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayController RelayHandler, log *log.Entry) {
	// Do not use `for i := range state.retryTimes`: range evaluates its bound
	// once, but handleRetryResult may grow state.retryTimes during the loop.
	i := 0
	for {
		// Remember the failed attempt so shouldDelay can tell whether we are
		// about to hit the same channel again.
		lastStatusCode := state.result.Error.StatusCode
		lastChannelID := state.meta.Channel.ID
		newChannel, err := getRetryChannel(state)
		if err == nil {
			// Restore the request body for the next attempt.
			err = prepareRetry(c)
		}
		if err != nil {
			// The log text says "prepare retry" but also covers
			// channel-selection errors other than exhaustion.
			if !errors.Is(err, ErrChannelsExhausted) {
				log.Errorf("prepare retry failed: %+v", err)
			}
			// No further attempt: record the previous attempt as the final one.
			if state.meta != nil && state.result != nil {
				recordResult(c, state.meta, state.price, state.result, i, true)
			}
			break
		}
		// Another attempt follows: record the previous one as superseded
		// (downstreamResult=false) and clear it so it is not recorded twice.
		if state.meta != nil && state.result != nil {
			recordResult(c, state.meta, state.price, state.result, i, false)
			state.meta = nil
			state.result = nil
		}
		log.Data["retry"] = strconv.Itoa(i + 1)
		log.Warnf("using channel %s (type: %d, id: %d) to retry (remain times %d)",
			newChannel.Name,
			newChannel.Type,
			newChannel.ID,
			state.retryTimes-i,
		)
		// Back off only when retrying the same channel after a 429/503.
		if shouldDelay(lastStatusCode, lastChannelID, newChannel.ID) {
			relayDelay()
		}
		state.meta = NewMetaByContext(
			c,
			newChannel,
			mode,
			meta.WithRequestUsage(state.requestUsage),
			meta.WithRetryAt(time.Now()),
		)
		var retry bool
		state.result, retry = RelayHelper(state.meta, c, relayController)
		done := handleRetryResult(c, retry, newChannel, state)
		// Stop on a final outcome or when this was the last budgeted retry.
		if done || i == state.retryTimes-1 {
			recordResult(c, state.meta, state.price, state.result, i+1, true)
			break
		}
		i++
	}
	// Only errors are written here; a successful attempt's response is not
	// written by this function.
	if state.result.Error != nil {
		state.result.Error.Error.Message = middleware.MessageWithRequestID(c, state.result.Error.Error.Message)
		c.JSON(state.result.Error.StatusCode, state.result.Error)
	}
}
  516. func getRetryChannel(state *retryState) (*model.Channel, error) {
  517. if state.exhausted {
  518. if state.lastHasPermissionChannel == nil {
  519. return nil, ErrChannelsExhausted
  520. }
  521. return state.lastHasPermissionChannel, nil
  522. }
  523. newChannel, err := getRandomChannel(state.migratedChannels, state.errorRates, state.ignoreChannelIDs...)
  524. if err != nil {
  525. if !errors.Is(err, ErrChannelsExhausted) || state.lastHasPermissionChannel == nil {
  526. return nil, err
  527. }
  528. state.exhausted = true
  529. return state.lastHasPermissionChannel, nil
  530. }
  531. return newChannel, nil
  532. }
  533. func prepareRetry(c *gin.Context) error {
  534. requestBody, err := common.GetRequestBody(c.Request)
  535. if err != nil {
  536. return fmt.Errorf("get request body failed in prepare retry: %w", err)
  537. }
  538. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  539. return nil
  540. }
  541. func handleRetryResult(ctx *gin.Context, retry bool, newChannel *model.Channel, state *retryState) (done bool) {
  542. if ctx.Request.Context().Err() != nil {
  543. return true
  544. }
  545. if !retry || state.result.Error == nil {
  546. return true
  547. }
  548. hasPermission := channelHasPermission(*state.result.Error)
  549. if state.exhausted {
  550. if !hasPermission {
  551. return true
  552. }
  553. } else {
  554. if !hasPermission {
  555. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(newChannel.ID))
  556. state.retryTimes++
  557. } else {
  558. state.lastHasPermissionChannel = newChannel
  559. }
  560. }
  561. return false
  562. }
// channelNoRetryStatusCodesMap lists upstream status codes that are treated as
// caused by the request itself rather than the channel. shouldRetry returns
// false for them, so they are not retried, and RelayHelper does not record
// them against the channel in the monitor.
var channelNoRetryStatusCodesMap = map[int]struct{}{
	http.StatusBadRequest:                 {},
	http.StatusRequestEntityTooLarge:      {},
	http.StatusUnprocessableEntity:        {},
	http.StatusUnavailableForLegalReasons: {},
}
  569. // 仅当是channel错误时,才需要记录,用户请求参数错误时,不需要记录
  570. func shouldRetry(_ *gin.Context, relayErr relaymodel.ErrorWithStatusCode) bool {
  571. if relayErr.Error.Code == controller.ErrInvalidChannelTypeCode {
  572. return false
  573. }
  574. _, ok := channelNoRetryStatusCodesMap[relayErr.StatusCode]
  575. return !ok
  576. }
// channelNoPermissionStatusCodesMap lists upstream status codes that
// channelHasPermission treats as the channel lacking access: bad credentials,
// payment required, forbidden, and 404 (presumably the model being
// unavailable on that channel — confirm).
var channelNoPermissionStatusCodesMap = map[int]struct{}{
	http.StatusUnauthorized:    {},
	http.StatusPaymentRequired: {},
	http.StatusForbidden:       {},
	http.StatusNotFound:        {},
}
  583. func channelHasPermission(relayErr relaymodel.ErrorWithStatusCode) bool {
  584. if relayErr.Error.Code == controller.ErrInvalidChannelTypeCode {
  585. return false
  586. }
  587. _, ok := channelNoPermissionStatusCodesMap[relayErr.StatusCode]
  588. return !ok
  589. }
  590. // shouldDelay checks if we need to add a delay before retrying
  591. // Only adds delay when retrying with the same channel for rate limiting issues
  592. func shouldDelay(statusCode int, lastChannelID, newChannelID int) bool {
  593. if lastChannelID != newChannelID {
  594. return false
  595. }
  596. // Only delay for rate limiting or service unavailable errors
  597. return statusCode == http.StatusTooManyRequests ||
  598. statusCode == http.StatusServiceUnavailable
  599. }
  600. func relayDelay() {
  601. //nolint:gosec
  602. time.Sleep(time.Duration(rand.Float64()*float64(time.Second)) + time.Second)
  603. }
  604. func RelayNotImplemented(c *gin.Context) {
  605. c.JSON(http.StatusNotImplemented, gin.H{
  606. "error": &relaymodel.Error{
  607. Message: "API not implemented",
  608. Type: middleware.ErrorTypeAIPROXY,
  609. Code: "api_not_implemented",
  610. },
  611. })
  612. }