relay-controller.go 20 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/rand/v2"
  9. "net/http"
  10. "slices"
  11. "strconv"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. "github.com/labring/aiproxy/core/common"
  15. "github.com/labring/aiproxy/core/common/config"
  16. "github.com/labring/aiproxy/core/common/consume"
  17. "github.com/labring/aiproxy/core/common/notify"
  18. "github.com/labring/aiproxy/core/common/trylock"
  19. "github.com/labring/aiproxy/core/middleware"
  20. "github.com/labring/aiproxy/core/model"
  21. "github.com/labring/aiproxy/core/monitor"
  22. "github.com/labring/aiproxy/core/relay/controller"
  23. "github.com/labring/aiproxy/core/relay/meta"
  24. "github.com/labring/aiproxy/core/relay/mode"
  25. relaymodel "github.com/labring/aiproxy/core/relay/model"
  26. log "github.com/sirupsen/logrus"
  27. )
  28. // https://platform.openai.com/docs/api-reference/chat
  29. type (
  30. RelayHandler func(*meta.Meta, *gin.Context) *controller.HandleResult
  31. GetRequestUsage func(*gin.Context, *model.ModelConfig) (model.Usage, error)
  32. GetRequestPrice func(*gin.Context, *model.ModelConfig) (model.Price, error)
  33. )
// RelayController bundles the per-mode hooks used by relay.
type RelayController struct {
	// GetRequestUsage estimates request usage for the pre-flight balance
	// check when billing is enabled; when nil, the check is skipped.
	GetRequestUsage GetRequestUsage
	// GetRequestPrice resolves the request price when billing is enabled;
	// when nil, a zero model.Price is used.
	GetRequestPrice GetRequestPrice
	// Handler performs one relay attempt and is invoked again for each retry.
	Handler RelayHandler
}
  39. func relayHandler(meta *meta.Meta, c *gin.Context) *controller.HandleResult {
  40. log := middleware.GetLogger(c)
  41. middleware.SetLogFieldsFromMeta(meta, log.Data)
  42. return controller.Handle(meta, c)
  43. }
  44. func relayController(m mode.Mode) RelayController {
  45. c := RelayController{
  46. Handler: relayHandler,
  47. }
  48. switch m {
  49. case mode.ImagesGenerations:
  50. c.GetRequestPrice = controller.GetImagesRequestPrice
  51. c.GetRequestUsage = controller.GetImagesRequestUsage
  52. case mode.ImagesEdits:
  53. c.GetRequestPrice = controller.GetImagesEditsRequestPrice
  54. c.GetRequestUsage = controller.GetImagesEditsRequestUsage
  55. case mode.AudioSpeech:
  56. c.GetRequestPrice = controller.GetTTSRequestPrice
  57. c.GetRequestUsage = controller.GetTTSRequestUsage
  58. case mode.AudioTranslation, mode.AudioTranscription:
  59. c.GetRequestPrice = controller.GetSTTRequestPrice
  60. c.GetRequestUsage = controller.GetSTTRequestUsage
  61. case mode.ParsePdf:
  62. c.GetRequestPrice = controller.GetPdfRequestPrice
  63. c.GetRequestUsage = controller.GetPdfRequestUsage
  64. case mode.Rerank:
  65. c.GetRequestPrice = controller.GetRerankRequestPrice
  66. c.GetRequestUsage = controller.GetRerankRequestUsage
  67. case mode.Anthropic:
  68. c.GetRequestPrice = controller.GetAnthropicRequestPrice
  69. c.GetRequestUsage = controller.GetAnthropicRequestUsage
  70. case mode.ChatCompletions:
  71. c.GetRequestPrice = controller.GetChatRequestPrice
  72. c.GetRequestUsage = controller.GetChatRequestUsage
  73. case mode.Embeddings:
  74. c.GetRequestPrice = controller.GetEmbedRequestPrice
  75. c.GetRequestUsage = controller.GetEmbedRequestUsage
  76. case mode.Completions:
  77. c.GetRequestPrice = controller.GetCompletionsRequestPrice
  78. c.GetRequestUsage = controller.GetCompletionsRequestUsage
  79. }
  80. return c
  81. }
  82. func RelayHelper(meta *meta.Meta, c *gin.Context, handel RelayHandler) (*controller.HandleResult, bool) {
  83. result := handel(meta, c)
  84. if result.Error == nil {
  85. if _, _, err := monitor.AddRequest(
  86. context.Background(),
  87. meta.OriginModel,
  88. int64(meta.Channel.ID),
  89. false,
  90. false,
  91. ); err != nil {
  92. log.Errorf("add request failed: %+v", err)
  93. }
  94. return result, false
  95. }
  96. shouldRetry := shouldRetry(c, *result.Error)
  97. if shouldRetry {
  98. hasPermission := channelHasPermission(*result.Error)
  99. beyondThreshold, banExecution, err := monitor.AddRequest(
  100. context.Background(),
  101. meta.OriginModel,
  102. int64(meta.Channel.ID),
  103. true,
  104. !hasPermission,
  105. )
  106. if err != nil {
  107. log.Errorf("add request failed: %+v", err)
  108. }
  109. switch {
  110. case banExecution:
  111. notifyChannelIssue(meta, "autoBanned", "Auto Banned", *result.Error)
  112. case beyondThreshold:
  113. notifyChannelIssue(meta, "beyondThreshold", "Error Rate Beyond Threshold", *result.Error)
  114. case !hasPermission:
  115. notifyChannelIssue(meta, "channelHasPermission", "No Permission", *result.Error)
  116. }
  117. }
  118. return result, shouldRetry
  119. }
  120. func notifyChannelIssue(meta *meta.Meta, issueType string, titleSuffix string, err relaymodel.ErrorWithStatusCode) {
  121. var notifyFunc func(title string, message string)
  122. lockKey := fmt.Sprintf("%s:%d:%s", issueType, meta.Channel.ID, meta.OriginModel)
  123. switch issueType {
  124. case "beyondThreshold":
  125. notifyFunc = func(title string, message string) {
  126. notify.WarnThrottle(lockKey, time.Minute, title, message)
  127. }
  128. default:
  129. notifyFunc = func(title string, message string) {
  130. notify.ErrorThrottle(lockKey, time.Minute, title, message)
  131. }
  132. }
  133. message := fmt.Sprintf(
  134. "channel: %s (type: %d, type name: %s, id: %d)\nmodel: %s\nmode: %s\nstatus code: %d\ndetail: %s\nrequest id: %s",
  135. meta.Channel.Name,
  136. meta.Channel.Type,
  137. meta.Channel.Type.String(),
  138. meta.Channel.ID,
  139. meta.OriginModel,
  140. meta.Mode,
  141. err.StatusCode,
  142. err.JSONOrEmpty(),
  143. meta.RequestID,
  144. )
  145. if err.StatusCode == http.StatusTooManyRequests {
  146. if !trylock.Lock(lockKey, time.Minute) {
  147. return
  148. }
  149. switch issueType {
  150. case "beyondThreshold":
  151. notifyFunc = notify.Warn
  152. default:
  153. notifyFunc = notify.Error
  154. }
  155. now := time.Now()
  156. group := "*"
  157. rpm, rpmErr := model.GetRPM(group, now, "", meta.OriginModel, meta.Channel.ID)
  158. tpm, tpmErr := model.GetTPM(group, now, "", meta.OriginModel, meta.Channel.ID)
  159. if rpmErr != nil {
  160. message += fmt.Sprintf("\nrpm: %v", rpmErr)
  161. } else {
  162. message += fmt.Sprintf("\nrpm: %d", rpm)
  163. }
  164. if tpmErr != nil {
  165. message += fmt.Sprintf("\ntpm: %v", tpmErr)
  166. } else {
  167. message += fmt.Sprintf("\ntpm: %d", tpm)
  168. }
  169. }
  170. notifyFunc(
  171. fmt.Sprintf("%s `%s` %s", meta.Channel.Name, meta.OriginModel, titleSuffix),
  172. message,
  173. )
  174. }
  175. func filterChannels(channels []*model.Channel, ignoreChannel ...int64) []*model.Channel {
  176. filtered := make([]*model.Channel, 0)
  177. for _, channel := range channels {
  178. if channel.Status != model.ChannelStatusEnabled {
  179. continue
  180. }
  181. if slices.Contains(ignoreChannel, int64(channel.ID)) {
  182. continue
  183. }
  184. filtered = append(filtered, channel)
  185. }
  186. return filtered
  187. }
  188. var (
  189. ErrChannelsNotFound = errors.New("channels not found")
  190. ErrChannelsExhausted = errors.New("channels exhausted")
  191. )
  192. func GetRandomChannel(mc *model.ModelCaches, availableSet []string, modelName string, errorRates map[int64]float64, ignoreChannel ...int64) (*model.Channel, []*model.Channel, error) {
  193. channelMap := make(map[int]*model.Channel)
  194. for _, set := range availableSet {
  195. for _, channel := range mc.EnabledModel2ChannelsBySet[set][modelName] {
  196. channelMap[channel.ID] = channel
  197. }
  198. }
  199. migratedChannels := make([]*model.Channel, 0, len(channelMap))
  200. for _, channel := range channelMap {
  201. migratedChannels = append(migratedChannels, channel)
  202. }
  203. channel, err := getRandomChannel(migratedChannels, errorRates, ignoreChannel...)
  204. return channel, migratedChannels, err
  205. }
  206. func getPriority(channel *model.Channel, errorRate float64) int32 {
  207. priority := channel.GetPriority()
  208. if errorRate > 1 {
  209. errorRate = 1
  210. } else if errorRate < 0.1 {
  211. errorRate = 0.1
  212. }
  213. return int32(float64(priority) / errorRate)
  214. }
  215. //nolint:gosec
  216. func getRandomChannel(channels []*model.Channel, errorRates map[int64]float64, ignoreChannel ...int64) (*model.Channel, error) {
  217. if len(channels) == 0 {
  218. return nil, ErrChannelsNotFound
  219. }
  220. channels = filterChannels(channels, ignoreChannel...)
  221. if len(channels) == 0 {
  222. return nil, ErrChannelsExhausted
  223. }
  224. if len(channels) == 1 {
  225. return channels[0], nil
  226. }
  227. var totalWeight int32
  228. cachedPrioritys := make([]int32, len(channels))
  229. for i, ch := range channels {
  230. priority := getPriority(ch, errorRates[int64(ch.ID)])
  231. totalWeight += priority
  232. cachedPrioritys[i] = priority
  233. }
  234. if totalWeight == 0 {
  235. return channels[rand.IntN(len(channels))], nil
  236. }
  237. r := rand.Int32N(totalWeight)
  238. for i, ch := range channels {
  239. r -= cachedPrioritys[i]
  240. if r < 0 {
  241. return ch, nil
  242. }
  243. }
  244. return channels[rand.IntN(len(channels))], nil
  245. }
  246. func getChannelWithFallback(cache *model.ModelCaches, availableSet []string, modelName string, errorRates map[int64]float64, ignoreChannelIDs ...int64) (*model.Channel, []*model.Channel, error) {
  247. channel, migratedChannels, err := GetRandomChannel(cache, availableSet, modelName, errorRates, ignoreChannelIDs...)
  248. if err == nil {
  249. return channel, migratedChannels, nil
  250. }
  251. if !errors.Is(err, ErrChannelsExhausted) {
  252. return nil, migratedChannels, err
  253. }
  254. channel, migratedChannels, err = GetRandomChannel(cache, availableSet, modelName, errorRates)
  255. return channel, migratedChannels, err
  256. }
  257. func NewRelay(mode mode.Mode) func(c *gin.Context) {
  258. relayController := relayController(mode)
  259. return func(c *gin.Context) {
  260. relay(c, mode, relayController)
  261. }
  262. }
  263. func NewMetaByContext(c *gin.Context, channel *model.Channel, mode mode.Mode, opts ...meta.Option) *meta.Meta {
  264. return middleware.NewMetaByContext(c, channel, mode, opts...)
  265. }
  266. func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
  267. log := middleware.GetLogger(c)
  268. requestModel := middleware.GetRequestModel(c)
  269. mc := middleware.GetModelConfig(c)
  270. // Get initial channel
  271. initialChannel, err := getInitialChannel(c, requestModel, log)
  272. if err != nil || initialChannel == nil || initialChannel.channel == nil {
  273. middleware.AbortLogWithMessage(c,
  274. http.StatusServiceUnavailable,
  275. "the upstream load is saturated, please try again later",
  276. )
  277. return
  278. }
  279. billingEnabled := config.GetBillingEnabled()
  280. price := model.Price{}
  281. if billingEnabled && relayController.GetRequestPrice != nil {
  282. price, err = relayController.GetRequestPrice(c, mc)
  283. if err != nil {
  284. middleware.AbortLogWithMessage(c,
  285. http.StatusInternalServerError,
  286. "get request price failed: "+err.Error(),
  287. )
  288. return
  289. }
  290. }
  291. meta := NewMetaByContext(c, initialChannel.channel, mode)
  292. if billingEnabled && relayController.GetRequestUsage != nil {
  293. requestUsage, err := relayController.GetRequestUsage(c, mc)
  294. if err != nil {
  295. middleware.AbortLogWithMessage(c,
  296. http.StatusInternalServerError,
  297. "get request usage failed: "+err.Error(),
  298. )
  299. return
  300. }
  301. gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  302. if !gbc.CheckBalance(consume.CalculateAmount(http.StatusOK, requestUsage, price)) {
  303. middleware.AbortLogWithMessage(c,
  304. http.StatusForbidden,
  305. fmt.Sprintf("group (%s) balance not enough", gbc.Group),
  306. &middleware.ErrorField{
  307. Code: middleware.GroupBalanceNotEnough,
  308. },
  309. )
  310. return
  311. }
  312. meta.RequestUsage = requestUsage
  313. }
  314. // First attempt
  315. result, retry := RelayHelper(meta, c, relayController.Handler)
  316. retryTimes := int(config.GetRetryTimes())
  317. if mc.RetryTimes > 0 {
  318. retryTimes = int(mc.RetryTimes)
  319. }
  320. if handleRelayResult(c, result.Error, retry, retryTimes) {
  321. recordResult(
  322. c,
  323. meta,
  324. price,
  325. result,
  326. 0,
  327. true,
  328. middleware.GetRequestUser(c),
  329. middleware.GetRequestMetadata(c),
  330. )
  331. return
  332. }
  333. // Setup retry state
  334. retryState := initRetryState(
  335. retryTimes,
  336. initialChannel,
  337. meta,
  338. result,
  339. price,
  340. )
  341. // Retry loop
  342. retryLoop(c, mode, retryState, relayController.Handler, log)
  343. }
  344. // recordResult records the consumption for the final result
  345. func recordResult(
  346. c *gin.Context,
  347. meta *meta.Meta,
  348. price model.Price,
  349. result *controller.HandleResult,
  350. retryTimes int,
  351. downstreamResult bool,
  352. user string,
  353. metadata map[string]string,
  354. ) {
  355. code := http.StatusOK
  356. content := ""
  357. if result.Error != nil {
  358. code = result.Error.StatusCode
  359. content = result.Error.JSONOrEmpty()
  360. }
  361. var detail *model.RequestDetail
  362. firstByteAt := result.Detail.FirstByteAt
  363. if code == http.StatusOK && !config.GetSaveAllLogDetail() {
  364. detail = nil
  365. } else {
  366. detail = &model.RequestDetail{
  367. RequestBody: result.Detail.RequestBody,
  368. ResponseBody: result.Detail.ResponseBody,
  369. }
  370. }
  371. gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  372. amount := consume.CalculateAmount(
  373. code,
  374. result.Usage,
  375. price,
  376. )
  377. if amount > 0 {
  378. log := middleware.GetLogger(c)
  379. log.Data["amount"] = strconv.FormatFloat(amount, 'f', -1, 64)
  380. }
  381. consume.AsyncConsume(
  382. gbc.Consumer,
  383. code,
  384. firstByteAt,
  385. meta,
  386. result.Usage,
  387. price,
  388. content,
  389. c.ClientIP(),
  390. retryTimes,
  391. detail,
  392. downstreamResult,
  393. user,
  394. metadata,
  395. )
  396. }
// retryState is the mutable bookkeeping shared by retryLoop and its helpers
// across retry attempts.
type retryState struct {
	// retryTimes is the retry budget. handleRetryResult increments it for each
	// channel that is skipped because it lacks permission.
	retryTimes int
	// lastHasPermissionChannel is the most recent channel whose failure was not
	// a permission error. It is reused once no other channel can be picked.
	lastHasPermissionChannel *model.Channel
	// ignoreChannelIDs lists channels excluded from random selection: the
	// model's auto-banned channels plus channels that failed with a permission
	// error.
	ignoreChannelIDs []int64
	// errorRates maps channel IDs to error rates; used to weight selection.
	errorRates map[int64]float64
	// exhausted is set once no fresh channel can be picked, or when the channel
	// was designated. From then on only lastHasPermissionChannel is retried.
	exhausted bool
	// meta and result describe the latest attempt. retryLoop resets them to
	// nil once that attempt has been recorded as an intermediate result.
	meta  *meta.Meta
	price model.Price
	// requestUsage is the pre-flight usage estimate copied into each retry's meta.
	requestUsage model.Usage
	result       *controller.HandleResult
	// migratedChannels is the candidate pool that retries pick from.
	migratedChannels []*model.Channel
}
// initialChannel is the outcome of getInitialChannel: the channel for the
// first attempt, plus the selection context that retries start from.
type initialChannel struct {
	channel *model.Channel
	// designatedChannel is true when the channel came from the request context
	// (middleware.GetChannel). Retries then never switch to another channel.
	designatedChannel bool
	// ignoreChannelIDs holds the model's auto-banned channel IDs.
	ignoreChannelIDs []int64
	// errorRates maps channel IDs to their error rate for the model.
	errorRates map[int64]float64
	// migratedChannels is the deduplicated candidate pool for the model.
	migratedChannels []*model.Channel
}
  416. func getInitialChannel(c *gin.Context, modelName string, log *log.Entry) (*initialChannel, error) {
  417. if channel := middleware.GetChannel(c); channel != nil {
  418. log.Data["designated_channel"] = "true"
  419. return &initialChannel{channel: channel, designatedChannel: true}, nil
  420. }
  421. mc := middleware.GetModelCaches(c)
  422. ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
  423. if err != nil {
  424. log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
  425. }
  426. log.Debugf("%s model banned channels: %+v", modelName, ids)
  427. errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
  428. if err != nil {
  429. log.Errorf("get channel model error rates failed: %+v", err)
  430. }
  431. group := middleware.GetGroup(c)
  432. availableSet := group.GetAvailableSets()
  433. channel, migratedChannels, err := getChannelWithFallback(mc, availableSet, modelName, errorRates, ids...)
  434. if err != nil {
  435. return nil, err
  436. }
  437. return &initialChannel{
  438. channel: channel,
  439. ignoreChannelIDs: ids,
  440. errorRates: errorRates,
  441. migratedChannels: migratedChannels,
  442. }, nil
  443. }
  444. func handleRelayResult(c *gin.Context, bizErr *relaymodel.ErrorWithStatusCode, retry bool, retryTimes int) (done bool) {
  445. if bizErr == nil {
  446. return true
  447. }
  448. if !retry ||
  449. retryTimes == 0 ||
  450. c.Request.Context().Err() != nil {
  451. bizErr.Error.Message = middleware.MessageWithRequestID(c, bizErr.Error.Message)
  452. c.JSON(bizErr.StatusCode, bizErr)
  453. return true
  454. }
  455. return false
  456. }
  457. func initRetryState(retryTimes int, channel *initialChannel, meta *meta.Meta, result *controller.HandleResult, price model.Price) *retryState {
  458. state := &retryState{
  459. retryTimes: retryTimes,
  460. ignoreChannelIDs: channel.ignoreChannelIDs,
  461. errorRates: channel.errorRates,
  462. meta: meta,
  463. result: result,
  464. price: price,
  465. requestUsage: meta.RequestUsage,
  466. migratedChannels: channel.migratedChannels,
  467. }
  468. if channel.designatedChannel {
  469. state.exhausted = true
  470. }
  471. if !channelHasPermission(*result.Error) {
  472. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(channel.channel.ID))
  473. } else {
  474. state.lastHasPermissionChannel = channel.channel
  475. }
  476. return state
  477. }
// retryLoop re-runs a failed request until one of these happens:
//   - an attempt ends the request (success, a non-retryable error, client
//     cancellation, or a permission error on the exhausted fallback channel);
//   - no channel is left;
//   - the retry budget in state.retryTimes is used up.
//
// Every attempt's consumption is recorded, and only the final attempt is
// recorded with downstreamResult=true. If the final attempt failed, its error
// is tagged with the request ID and written to the client as JSON.
func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayController RelayHandler, log *log.Entry) {
	// state.retryTimes can grow inside the loop (handleRetryResult adds one for
	// each channel skipped for lacking permission), so the bound is re-read on
	// every iteration. `for i := range state.retryTimes` would evaluate it once.
	i := 0
	for {
		// Remember the previous attempt so shouldDelay can detect a retry on
		// the same channel.
		lastStatusCode := state.result.Error.StatusCode
		lastChannelID := state.meta.Channel.ID
		newChannel, err := getRetryChannel(state)
		if err == nil {
			// Make the request body readable again for the next attempt.
			err = prepareRetry(c)
		}
		if err != nil {
			if !errors.Is(err, ErrChannelsExhausted) {
				log.Errorf("prepare retry failed: %+v", err)
			}
			// No further attempt will be made, so the previous attempt is the
			// final one: record it as the downstream result.
			if state.meta != nil && state.result != nil {
				recordResult(
					c,
					state.meta,
					state.price,
					state.result,
					i,
					true,
					middleware.GetRequestUser(c),
					middleware.GetRequestMetadata(c),
				)
			}
			break
		}
		// Another attempt follows, so record the previous attempt as an
		// intermediate (non-downstream) result and clear it.
		if state.meta != nil && state.result != nil {
			recordResult(
				c,
				state.meta,
				state.price,
				state.result,
				i,
				false,
				middleware.GetRequestUser(c),
				middleware.GetRequestMetadata(c),
			)
			state.meta = nil
			state.result = nil
		}
		log.Data["retry"] = strconv.Itoa(i + 1)
		log.Warnf("using channel %s (type: %d, id: %d) to retry (remain times %d)",
			newChannel.Name,
			newChannel.Type,
			newChannel.ID,
			state.retryTimes-i,
		)
		// Pause only when retrying the same channel after a 429 or 503.
		if shouldDelay(lastStatusCode, lastChannelID, newChannel.ID) {
			relayDelay()
		}
		state.meta = NewMetaByContext(
			c,
			newChannel,
			mode,
			meta.WithRequestUsage(state.requestUsage),
			meta.WithRetryAt(time.Now()),
		)
		var retry bool
		state.result, retry = RelayHelper(state.meta, c, relayController)
		done := handleRetryResult(c, retry, newChannel, state)
		// This attempt is final if it ended the request or used the last retry.
		if done || i == state.retryTimes-1 {
			recordResult(
				c,
				state.meta,
				state.price,
				state.result,
				i+1,
				true,
				middleware.GetRequestUser(c),
				middleware.GetRequestMetadata(c),
			)
			break
		}
		i++
	}
	// Only failures are written here; on success the handler presumably wrote
	// the response already.
	if state.result.Error != nil {
		state.result.Error.Error.Message = middleware.MessageWithRequestID(c, state.result.Error.Error.Message)
		c.JSON(state.result.Error.StatusCode, state.result.Error)
	}
}
  563. func getRetryChannel(state *retryState) (*model.Channel, error) {
  564. if state.exhausted {
  565. if state.lastHasPermissionChannel == nil {
  566. return nil, ErrChannelsExhausted
  567. }
  568. return state.lastHasPermissionChannel, nil
  569. }
  570. newChannel, err := getRandomChannel(state.migratedChannels, state.errorRates, state.ignoreChannelIDs...)
  571. if err != nil {
  572. if !errors.Is(err, ErrChannelsExhausted) || state.lastHasPermissionChannel == nil {
  573. return nil, err
  574. }
  575. state.exhausted = true
  576. return state.lastHasPermissionChannel, nil
  577. }
  578. return newChannel, nil
  579. }
  580. func prepareRetry(c *gin.Context) error {
  581. requestBody, err := common.GetRequestBody(c.Request)
  582. if err != nil {
  583. return fmt.Errorf("get request body failed in prepare retry: %w", err)
  584. }
  585. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  586. return nil
  587. }
  588. func handleRetryResult(ctx *gin.Context, retry bool, newChannel *model.Channel, state *retryState) (done bool) {
  589. if ctx.Request.Context().Err() != nil {
  590. return true
  591. }
  592. if !retry || state.result.Error == nil {
  593. return true
  594. }
  595. hasPermission := channelHasPermission(*state.result.Error)
  596. if state.exhausted {
  597. if !hasPermission {
  598. return true
  599. }
  600. } else {
  601. if !hasPermission {
  602. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(newChannel.ID))
  603. state.retryTimes++
  604. } else {
  605. state.lastHasPermissionChannel = newChannel
  606. }
  607. }
  608. return false
  609. }
  610. var channelNoRetryStatusCodesMap = map[int]struct{}{
  611. http.StatusBadRequest: {},
  612. http.StatusRequestEntityTooLarge: {},
  613. http.StatusUnprocessableEntity: {},
  614. http.StatusUnavailableForLegalReasons: {},
  615. }
  616. // 仅当是channel错误时,才需要记录,用户请求参数错误时,不需要记录
  617. func shouldRetry(_ *gin.Context, relayErr relaymodel.ErrorWithStatusCode) bool {
  618. if relayErr.Error.Code == controller.ErrInvalidChannelTypeCode {
  619. return false
  620. }
  621. _, ok := channelNoRetryStatusCodesMap[relayErr.StatusCode]
  622. return !ok
  623. }
  624. var channelNoPermissionStatusCodesMap = map[int]struct{}{
  625. http.StatusUnauthorized: {},
  626. http.StatusPaymentRequired: {},
  627. http.StatusForbidden: {},
  628. http.StatusNotFound: {},
  629. }
  630. func channelHasPermission(relayErr relaymodel.ErrorWithStatusCode) bool {
  631. if relayErr.Error.Code == controller.ErrInvalidChannelTypeCode {
  632. return false
  633. }
  634. _, ok := channelNoPermissionStatusCodesMap[relayErr.StatusCode]
  635. return !ok
  636. }
  637. // shouldDelay checks if we need to add a delay before retrying
  638. // Only adds delay when retrying with the same channel for rate limiting issues
  639. func shouldDelay(statusCode int, lastChannelID, newChannelID int) bool {
  640. if lastChannelID != newChannelID {
  641. return false
  642. }
  643. // Only delay for rate limiting or service unavailable errors
  644. return statusCode == http.StatusTooManyRequests ||
  645. statusCode == http.StatusServiceUnavailable
  646. }
  647. func relayDelay() {
  648. //nolint:gosec
  649. time.Sleep(time.Duration(rand.Float64()*float64(time.Second)) + time.Second)
  650. }
  651. func RelayNotImplemented(c *gin.Context) {
  652. c.JSON(http.StatusNotImplemented, gin.H{
  653. "error": &relaymodel.Error{
  654. Message: "API not implemented",
  655. Type: middleware.ErrorTypeAIPROXY,
  656. Code: "api_not_implemented",
  657. },
  658. })
  659. }