relay-controller.go 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/rand/v2"
  9. "net/http"
  10. "slices"
  11. "strconv"
  12. "time"
  13. "github.com/bytedance/sonic"
  14. "github.com/bytedance/sonic/ast"
  15. "github.com/gin-gonic/gin"
  16. "github.com/labring/aiproxy/core/common"
  17. "github.com/labring/aiproxy/core/common/config"
  18. "github.com/labring/aiproxy/core/common/consume"
  19. "github.com/labring/aiproxy/core/common/conv"
  20. "github.com/labring/aiproxy/core/common/notify"
  21. "github.com/labring/aiproxy/core/common/reqlimit"
  22. "github.com/labring/aiproxy/core/common/trylock"
  23. "github.com/labring/aiproxy/core/middleware"
  24. "github.com/labring/aiproxy/core/model"
  25. "github.com/labring/aiproxy/core/monitor"
  26. "github.com/labring/aiproxy/core/relay/adaptor"
  27. "github.com/labring/aiproxy/core/relay/adaptors"
  28. "github.com/labring/aiproxy/core/relay/controller"
  29. "github.com/labring/aiproxy/core/relay/meta"
  30. "github.com/labring/aiproxy/core/relay/mode"
  31. relaymodel "github.com/labring/aiproxy/core/relay/model"
  32. "github.com/labring/aiproxy/core/relay/plugin"
  33. "github.com/labring/aiproxy/core/relay/plugin/cache"
  34. "github.com/labring/aiproxy/core/relay/plugin/thinksplit"
  35. websearch "github.com/labring/aiproxy/core/relay/plugin/web-search"
  36. log "github.com/sirupsen/logrus"
  37. )
  38. // https://platform.openai.com/docs/api-reference/chat
  39. type (
  40. RelayHandler func(*gin.Context, *meta.Meta) *controller.HandleResult
  41. GetRequestUsage func(*gin.Context, model.ModelConfig) (model.Usage, error)
  42. GetRequestPrice func(*gin.Context, model.ModelConfig) (model.Price, error)
  43. )
  44. type RelayController struct {
  45. GetRequestUsage GetRequestUsage
  46. GetRequestPrice GetRequestPrice
  47. Handler RelayHandler
  48. }
  49. // TODO: convert to plugin
  50. type wrapAdaptor struct {
  51. adaptor.Adaptor
  52. }
  53. const (
  54. MetaChannelModelKeyRPM = "channel_model_rpm"
  55. MetaChannelModelKeyRPS = "channel_model_rps"
  56. MetaChannelModelKeyTPM = "channel_model_tpm"
  57. MetaChannelModelKeyTPS = "channel_model_tps"
  58. )
  59. func getChannelModelRequestRate(c *gin.Context, meta *meta.Meta) model.RequestRate {
  60. rate := model.RequestRate{}
  61. if rpm, ok := meta.Get(MetaChannelModelKeyRPM); ok {
  62. rate.RPM, _ = rpm.(int64)
  63. rate.RPS = meta.GetInt64(MetaChannelModelKeyRPS)
  64. } else {
  65. rpm, rps := reqlimit.GetChannelModelRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
  66. rate.RPM = rpm
  67. rate.RPS = rps
  68. updateChannelModelRequestRate(c, meta, rpm, rps)
  69. }
  70. if tpm, ok := meta.Get(MetaChannelModelKeyTPM); ok {
  71. rate.TPM, _ = tpm.(int64)
  72. rate.TPS = meta.GetInt64(MetaChannelModelKeyTPS)
  73. } else {
  74. tpm, tps := reqlimit.GetChannelModelTokensRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
  75. rate.TPM = tpm
  76. rate.TPS = tps
  77. updateChannelModelTokensRequestRate(c, meta, tpm, tps)
  78. }
  79. return rate
  80. }
  81. func updateChannelModelRequestRate(c *gin.Context, meta *meta.Meta, rpm, rps int64) {
  82. meta.Set(MetaChannelModelKeyRPM, rpm)
  83. meta.Set(MetaChannelModelKeyRPS, rps)
  84. log := middleware.GetLogger(c)
  85. log.Data["ch_rpm"] = rpm
  86. log.Data["ch_rps"] = rps
  87. }
  88. func updateChannelModelTokensRequestRate(c *gin.Context, meta *meta.Meta, tpm, tps int64) {
  89. meta.Set(MetaChannelModelKeyTPM, tpm)
  90. meta.Set(MetaChannelModelKeyTPS, tps)
  91. log := middleware.GetLogger(c)
  92. log.Data["ch_tpm"] = tpm
  93. log.Data["ch_tps"] = tps
  94. }
  95. func (w *wrapAdaptor) DoRequest(
  96. meta *meta.Meta,
  97. c *gin.Context,
  98. req *http.Request,
  99. ) (*http.Response, error) {
  100. count, overLimitCount, secondCount := reqlimit.PushChannelModelRequest(
  101. context.Background(),
  102. strconv.Itoa(meta.Channel.ID),
  103. meta.OriginModel,
  104. )
  105. updateChannelModelRequestRate(c, meta, count+overLimitCount, secondCount)
  106. return w.Adaptor.DoRequest(meta, c, req)
  107. }
  108. func (w *wrapAdaptor) DoResponse(
  109. meta *meta.Meta,
  110. c *gin.Context,
  111. resp *http.Response,
  112. ) (*model.Usage, adaptor.Error) {
  113. usage, relayErr := w.Adaptor.DoResponse(meta, c, resp)
  114. if usage == nil {
  115. return nil, relayErr
  116. }
  117. if usage.TotalTokens > 0 {
  118. count, overLimitCount, secondCount := reqlimit.PushChannelModelTokensRequest(
  119. context.Background(),
  120. strconv.Itoa(meta.Channel.ID),
  121. meta.OriginModel,
  122. int64(usage.TotalTokens),
  123. )
  124. updateChannelModelTokensRequestRate(c, meta, count+overLimitCount, secondCount)
  125. count, overLimitCount, secondCount = reqlimit.PushGroupModelTokensRequest(
  126. context.Background(),
  127. meta.Group.ID,
  128. meta.OriginModel,
  129. meta.ModelConfig.TPM,
  130. int64(usage.TotalTokens),
  131. )
  132. middleware.UpdateGroupModelTokensRequest(c, meta.Group, count+overLimitCount, secondCount)
  133. count, overLimitCount, secondCount = reqlimit.PushGroupModelTokennameTokensRequest(
  134. context.Background(),
  135. meta.Group.ID,
  136. meta.OriginModel,
  137. meta.Token.Name,
  138. int64(usage.TotalTokens),
  139. )
  140. middleware.UpdateGroupModelTokennameTokensRequest(c, count+overLimitCount, secondCount)
  141. }
  142. return usage, relayErr
  143. }
  144. func relayHandler(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
  145. log := middleware.GetLogger(c)
  146. middleware.SetLogFieldsFromMeta(meta, log.Data)
  147. adaptor, ok := adaptors.GetAdaptor(meta.Channel.Type)
  148. if !ok {
  149. return &controller.HandleResult{
  150. Error: relaymodel.WrapperOpenAIErrorWithMessage(
  151. fmt.Sprintf("invalid channel type: %d", meta.Channel.Type),
  152. "invalid_channel_type",
  153. http.StatusInternalServerError,
  154. ),
  155. }
  156. }
  157. a := plugin.WrapperAdaptor(&wrapAdaptor{adaptor},
  158. cache.NewCachePlugin(common.RDB),
  159. websearch.NewWebSearchPlugin(func(modelName string) (*model.Channel, error) {
  160. return getWebSearchChannel(c, modelName)
  161. }),
  162. thinksplit.NewThinkPlugin(),
  163. )
  164. return controller.Handle(a, c, meta)
  165. }
  166. func relayController(m mode.Mode) RelayController {
  167. c := RelayController{
  168. Handler: relayHandler,
  169. }
  170. switch m {
  171. case mode.ImagesGenerations:
  172. c.GetRequestPrice = controller.GetImagesRequestPrice
  173. c.GetRequestUsage = controller.GetImagesRequestUsage
  174. case mode.ImagesEdits:
  175. c.GetRequestPrice = controller.GetImagesEditsRequestPrice
  176. c.GetRequestUsage = controller.GetImagesEditsRequestUsage
  177. case mode.AudioSpeech:
  178. c.GetRequestPrice = controller.GetTTSRequestPrice
  179. c.GetRequestUsage = controller.GetTTSRequestUsage
  180. case mode.AudioTranslation, mode.AudioTranscription:
  181. c.GetRequestPrice = controller.GetSTTRequestPrice
  182. c.GetRequestUsage = controller.GetSTTRequestUsage
  183. case mode.ParsePdf:
  184. c.GetRequestPrice = controller.GetPdfRequestPrice
  185. c.GetRequestUsage = controller.GetPdfRequestUsage
  186. case mode.Rerank:
  187. c.GetRequestPrice = controller.GetRerankRequestPrice
  188. c.GetRequestUsage = controller.GetRerankRequestUsage
  189. case mode.Anthropic:
  190. c.GetRequestPrice = controller.GetAnthropicRequestPrice
  191. c.GetRequestUsage = controller.GetAnthropicRequestUsage
  192. case mode.ChatCompletions:
  193. c.GetRequestPrice = controller.GetChatRequestPrice
  194. c.GetRequestUsage = controller.GetChatRequestUsage
  195. case mode.Embeddings:
  196. c.GetRequestPrice = controller.GetEmbedRequestPrice
  197. c.GetRequestUsage = controller.GetEmbedRequestUsage
  198. case mode.Completions:
  199. c.GetRequestPrice = controller.GetCompletionsRequestPrice
  200. c.GetRequestUsage = controller.GetCompletionsRequestUsage
  201. }
  202. return c
  203. }
  204. func RelayHelper(
  205. c *gin.Context,
  206. meta *meta.Meta,
  207. handel RelayHandler,
  208. ) (*controller.HandleResult, bool) {
  209. result := handel(c, meta)
  210. if result.Error == nil {
  211. if _, _, err := monitor.AddRequest(
  212. context.Background(),
  213. meta.OriginModel,
  214. int64(meta.Channel.ID),
  215. false,
  216. false,
  217. meta.ModelConfig.MaxErrorRate,
  218. ); err != nil {
  219. log.Errorf("add request failed: %+v", err)
  220. }
  221. return result, false
  222. }
  223. shouldRetry := shouldRetry(c, result.Error)
  224. if shouldRetry {
  225. hasPermission := channelHasPermission(result.Error)
  226. beyondThreshold, banExecution, err := monitor.AddRequest(
  227. context.Background(),
  228. meta.OriginModel,
  229. int64(meta.Channel.ID),
  230. true,
  231. !hasPermission,
  232. meta.ModelConfig.MaxErrorRate,
  233. )
  234. if err != nil {
  235. log.Errorf("add request failed: %+v", err)
  236. }
  237. switch {
  238. case banExecution:
  239. notifyChannelIssue(c, meta, "autoBanned", "Auto Banned", result.Error)
  240. case beyondThreshold:
  241. notifyChannelIssue(
  242. c,
  243. meta,
  244. "beyondThreshold",
  245. "Error Rate Beyond Threshold",
  246. result.Error,
  247. )
  248. case !hasPermission:
  249. notifyChannelIssue(c, meta, "channelHasPermission", "No Permission", result.Error)
  250. }
  251. }
  252. return result, shouldRetry
  253. }
  254. func notifyChannelIssue(
  255. c *gin.Context,
  256. meta *meta.Meta,
  257. issueType, titleSuffix string,
  258. err adaptor.Error,
  259. ) {
  260. var notifyFunc func(title, message string)
  261. lockKey := fmt.Sprintf("%s:%d:%s", issueType, meta.Channel.ID, meta.OriginModel)
  262. switch issueType {
  263. case "beyondThreshold":
  264. notifyFunc = func(title, message string) {
  265. notify.WarnThrottle(lockKey, time.Minute, title, message)
  266. }
  267. default:
  268. notifyFunc = func(title, message string) {
  269. notify.ErrorThrottle(lockKey, time.Minute, title, message)
  270. }
  271. }
  272. respBody, _ := err.MarshalJSON()
  273. message := fmt.Sprintf(
  274. "channel: %s (type: %d, type name: %s, id: %d)\nmodel: %s\nmode: %s\nstatus code: %d\ndetail: %s\nrequest id: %s",
  275. meta.Channel.Name,
  276. meta.Channel.Type,
  277. meta.Channel.Type.String(),
  278. meta.Channel.ID,
  279. meta.OriginModel,
  280. meta.Mode,
  281. err.StatusCode(),
  282. conv.BytesToString(respBody),
  283. meta.RequestID,
  284. )
  285. if err.StatusCode() == http.StatusTooManyRequests {
  286. if !trylock.Lock(lockKey, time.Minute) {
  287. return
  288. }
  289. switch issueType {
  290. case "beyondThreshold":
  291. notifyFunc = notify.Warn
  292. default:
  293. notifyFunc = notify.Error
  294. }
  295. rate := getChannelModelRequestRate(c, meta)
  296. message += fmt.Sprintf(
  297. "\nrpm: %d\nrps: %d\ntpm: %d\ntps: %d",
  298. rate.RPM,
  299. rate.RPS,
  300. rate.TPM,
  301. rate.TPS,
  302. )
  303. }
  304. notifyFunc(
  305. fmt.Sprintf("%s `%s` %s", meta.Channel.Name, meta.OriginModel, titleSuffix),
  306. message,
  307. )
  308. }
  309. func filterChannels(channels []*model.Channel, ignoreChannel ...int64) []*model.Channel {
  310. filtered := make([]*model.Channel, 0)
  311. for _, channel := range channels {
  312. if channel.Status != model.ChannelStatusEnabled {
  313. continue
  314. }
  315. if slices.Contains(ignoreChannel, int64(channel.ID)) {
  316. continue
  317. }
  318. filtered = append(filtered, channel)
  319. }
  320. return filtered
  321. }
  322. var (
  323. ErrChannelsNotFound = errors.New("channels not found")
  324. ErrChannelsExhausted = errors.New("channels exhausted")
  325. )
  326. func GetRandomChannel(
  327. mc *model.ModelCaches,
  328. availableSet []string,
  329. modelName string,
  330. errorRates map[int64]float64,
  331. ignoreChannel ...int64,
  332. ) (*model.Channel, []*model.Channel, error) {
  333. channelMap := make(map[int]*model.Channel)
  334. if len(availableSet) != 0 {
  335. for _, set := range availableSet {
  336. for _, channel := range mc.EnabledModel2ChannelsBySet[set][modelName] {
  337. channelMap[channel.ID] = channel
  338. }
  339. }
  340. } else {
  341. for _, sets := range mc.EnabledModel2ChannelsBySet {
  342. for _, channel := range sets[modelName] {
  343. channelMap[channel.ID] = channel
  344. }
  345. }
  346. }
  347. migratedChannels := make([]*model.Channel, 0, len(channelMap))
  348. for _, channel := range channelMap {
  349. migratedChannels = append(migratedChannels, channel)
  350. }
  351. channel, err := getRandomChannel(migratedChannels, errorRates, ignoreChannel...)
  352. return channel, migratedChannels, err
  353. }
  354. func getPriority(channel *model.Channel, errorRate float64) int32 {
  355. priority := channel.GetPriority()
  356. if errorRate > 1 {
  357. errorRate = 1
  358. } else if errorRate < 0.1 {
  359. errorRate = 0.1
  360. }
  361. return int32(float64(priority) / errorRate)
  362. }
  363. //
  364. func getRandomChannel(
  365. channels []*model.Channel,
  366. errorRates map[int64]float64,
  367. ignoreChannel ...int64,
  368. ) (*model.Channel, error) {
  369. if len(channels) == 0 {
  370. return nil, ErrChannelsNotFound
  371. }
  372. channels = filterChannels(channels, ignoreChannel...)
  373. if len(channels) == 0 {
  374. return nil, ErrChannelsExhausted
  375. }
  376. if len(channels) == 1 {
  377. return channels[0], nil
  378. }
  379. var totalWeight int32
  380. cachedPrioritys := make([]int32, len(channels))
  381. for i, ch := range channels {
  382. priority := getPriority(ch, errorRates[int64(ch.ID)])
  383. totalWeight += priority
  384. cachedPrioritys[i] = priority
  385. }
  386. if totalWeight == 0 {
  387. return channels[rand.IntN(len(channels))], nil
  388. }
  389. r := rand.Int32N(totalWeight)
  390. for i, ch := range channels {
  391. r -= cachedPrioritys[i]
  392. if r < 0 {
  393. return ch, nil
  394. }
  395. }
  396. return channels[rand.IntN(len(channels))], nil
  397. }
  398. func getChannelWithFallback(
  399. cache *model.ModelCaches,
  400. availableSet []string,
  401. modelName string,
  402. errorRates map[int64]float64,
  403. ignoreChannelIDs ...int64,
  404. ) (*model.Channel, []*model.Channel, error) {
  405. channel, migratedChannels, err := GetRandomChannel(
  406. cache,
  407. availableSet,
  408. modelName,
  409. errorRates,
  410. ignoreChannelIDs...)
  411. if err == nil {
  412. return channel, migratedChannels, nil
  413. }
  414. if !errors.Is(err, ErrChannelsExhausted) {
  415. return nil, migratedChannels, err
  416. }
  417. channel, migratedChannels, err = GetRandomChannel(cache, availableSet, modelName, errorRates)
  418. return channel, migratedChannels, err
  419. }
  420. func NewRelay(mode mode.Mode) func(c *gin.Context) {
  421. relayController := relayController(mode)
  422. return func(c *gin.Context) {
  423. relay(c, mode, relayController)
  424. }
  425. }
  426. func NewMetaByContext(
  427. c *gin.Context,
  428. channel *model.Channel,
  429. mode mode.Mode,
  430. opts ...meta.Option,
  431. ) *meta.Meta {
  432. return middleware.NewMetaByContext(c, channel, mode, opts...)
  433. }
  434. func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
  435. requestModel := middleware.GetRequestModel(c)
  436. mc := middleware.GetModelConfig(c)
  437. // Get initial channel
  438. initialChannel, err := getInitialChannel(c, requestModel)
  439. if err != nil || initialChannel == nil || initialChannel.channel == nil {
  440. middleware.AbortLogWithMessageWithMode(mode, c,
  441. http.StatusServiceUnavailable,
  442. "the upstream load is saturated, please try again later",
  443. )
  444. return
  445. }
  446. price := model.Price{}
  447. if relayController.GetRequestPrice != nil {
  448. price, err = relayController.GetRequestPrice(c, mc)
  449. if err != nil {
  450. middleware.AbortLogWithMessageWithMode(mode, c,
  451. http.StatusInternalServerError,
  452. "get request price failed: "+err.Error(),
  453. )
  454. return
  455. }
  456. }
  457. meta := NewMetaByContext(c, initialChannel.channel, mode)
  458. if relayController.GetRequestUsage != nil {
  459. requestUsage, err := relayController.GetRequestUsage(c, mc)
  460. if err != nil {
  461. middleware.AbortLogWithMessageWithMode(mode, c,
  462. http.StatusInternalServerError,
  463. "get request usage failed: "+err.Error(),
  464. )
  465. return
  466. }
  467. gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  468. if !gbc.CheckBalance(consume.CalculateAmount(http.StatusOK, requestUsage, price)) {
  469. middleware.AbortLogWithMessageWithMode(mode, c,
  470. http.StatusForbidden,
  471. fmt.Sprintf("group (%s) balance not enough", gbc.Group),
  472. middleware.GroupBalanceNotEnough,
  473. )
  474. return
  475. }
  476. meta.RequestUsage = requestUsage
  477. }
  478. // First attempt
  479. result, retry := RelayHelper(c, meta, relayController.Handler)
  480. retryTimes := int(config.GetRetryTimes())
  481. if mc.RetryTimes > 0 {
  482. retryTimes = int(mc.RetryTimes)
  483. }
  484. if handleRelayResult(c, result.Error, retry, retryTimes) {
  485. recordResult(
  486. c,
  487. meta,
  488. price,
  489. result,
  490. 0,
  491. true,
  492. middleware.GetRequestUser(c),
  493. middleware.GetRequestMetadata(c),
  494. )
  495. return
  496. }
  497. // Setup retry state
  498. retryState := initRetryState(
  499. retryTimes,
  500. initialChannel,
  501. meta,
  502. result,
  503. price,
  504. )
  505. // Retry loop
  506. retryLoop(c, mode, retryState, relayController.Handler)
  507. }
  508. // recordResult records the consumption for the final result
  509. func recordResult(
  510. c *gin.Context,
  511. meta *meta.Meta,
  512. price model.Price,
  513. result *controller.HandleResult,
  514. retryTimes int,
  515. downstreamResult bool,
  516. user string,
  517. metadata map[string]string,
  518. ) {
  519. code := http.StatusOK
  520. content := ""
  521. if result.Error != nil {
  522. code = result.Error.StatusCode()
  523. respBody, _ := result.Error.MarshalJSON()
  524. content = conv.BytesToString(respBody)
  525. }
  526. var detail *model.RequestDetail
  527. firstByteAt := result.Detail.FirstByteAt
  528. if code == http.StatusOK && !config.GetSaveAllLogDetail() {
  529. detail = nil
  530. } else {
  531. detail = &model.RequestDetail{
  532. RequestBody: result.Detail.RequestBody,
  533. ResponseBody: result.Detail.ResponseBody,
  534. }
  535. }
  536. gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  537. amount := consume.CalculateAmount(
  538. code,
  539. result.Usage,
  540. price,
  541. )
  542. if amount > 0 {
  543. log := middleware.GetLogger(c)
  544. log.Data["amount"] = strconv.FormatFloat(amount, 'f', -1, 64)
  545. }
  546. consume.AsyncConsume(
  547. gbc.Consumer,
  548. code,
  549. firstByteAt,
  550. meta,
  551. result.Usage,
  552. price,
  553. content,
  554. c.ClientIP(),
  555. retryTimes,
  556. detail,
  557. downstreamResult,
  558. user,
  559. metadata,
  560. getChannelModelRequestRate(c, meta),
  561. middleware.GetGroupModelTokenRequestRate(c),
  562. )
  563. }
  564. type retryState struct {
  565. retryTimes int
  566. lastHasPermissionChannel *model.Channel
  567. ignoreChannelIDs []int64
  568. errorRates map[int64]float64
  569. exhausted bool
  570. meta *meta.Meta
  571. price model.Price
  572. requestUsage model.Usage
  573. result *controller.HandleResult
  574. migratedChannels []*model.Channel
  575. }
  576. type initialChannel struct {
  577. channel *model.Channel
  578. designatedChannel bool
  579. ignoreChannelIDs []int64
  580. errorRates map[int64]float64
  581. migratedChannels []*model.Channel
  582. }
  583. func getInitialChannel(c *gin.Context, modelName string) (*initialChannel, error) {
  584. log := middleware.GetLogger(c)
  585. if channel := middleware.GetChannel(c); channel != nil {
  586. log.Data["designated_channel"] = "true"
  587. return &initialChannel{channel: channel, designatedChannel: true}, nil
  588. }
  589. mc := middleware.GetModelCaches(c)
  590. ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
  591. if err != nil {
  592. log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
  593. }
  594. log.Debugf("%s model banned channels: %+v", modelName, ids)
  595. errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
  596. if err != nil {
  597. log.Errorf("get channel model error rates failed: %+v", err)
  598. }
  599. group := middleware.GetGroup(c)
  600. availableSet := group.GetAvailableSets()
  601. channel, migratedChannels, err := getChannelWithFallback(
  602. mc,
  603. availableSet,
  604. modelName,
  605. errorRates,
  606. ids...)
  607. if err != nil {
  608. return nil, err
  609. }
  610. return &initialChannel{
  611. channel: channel,
  612. ignoreChannelIDs: ids,
  613. errorRates: errorRates,
  614. migratedChannels: migratedChannels,
  615. }, nil
  616. }
  617. func getWebSearchChannel(c *gin.Context, modelName string) (*model.Channel, error) {
  618. log := middleware.GetLogger(c)
  619. mc := middleware.GetModelCaches(c)
  620. ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
  621. if err != nil {
  622. log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
  623. }
  624. log.Debugf("%s model banned channels: %+v", modelName, ids)
  625. errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
  626. if err != nil {
  627. log.Errorf("get channel model error rates failed: %+v", err)
  628. }
  629. channel, _, err := getChannelWithFallback(mc, nil, modelName, errorRates, ids...)
  630. if err != nil {
  631. return nil, err
  632. }
  633. return channel, nil
  634. }
  635. func handleRelayResult(
  636. c *gin.Context,
  637. bizErr adaptor.Error,
  638. retry bool,
  639. retryTimes int,
  640. ) (done bool) {
  641. if bizErr == nil {
  642. return true
  643. }
  644. if !retry ||
  645. retryTimes == 0 ||
  646. c.Request.Context().Err() != nil {
  647. ErrorWithRequestID(c, bizErr)
  648. return true
  649. }
  650. return false
  651. }
  652. func initRetryState(
  653. retryTimes int,
  654. channel *initialChannel,
  655. meta *meta.Meta,
  656. result *controller.HandleResult,
  657. price model.Price,
  658. ) *retryState {
  659. state := &retryState{
  660. retryTimes: retryTimes,
  661. ignoreChannelIDs: channel.ignoreChannelIDs,
  662. errorRates: channel.errorRates,
  663. meta: meta,
  664. result: result,
  665. price: price,
  666. requestUsage: meta.RequestUsage,
  667. migratedChannels: channel.migratedChannels,
  668. }
  669. if channel.designatedChannel {
  670. state.exhausted = true
  671. }
  672. if !channelHasPermission(result.Error) {
  673. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(channel.channel.ID))
  674. } else {
  675. state.lastHasPermissionChannel = channel.channel
  676. }
  677. return state
  678. }
  679. func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayController RelayHandler) {
  680. log := middleware.GetLogger(c)
  681. // do not use for i := range state.retryTimes, because the retryTimes is constant
  682. i := 0
  683. for {
  684. lastStatusCode := state.result.Error.StatusCode()
  685. lastChannelID := state.meta.Channel.ID
  686. newChannel, err := getRetryChannel(state)
  687. if err == nil {
  688. err = prepareRetry(c)
  689. }
  690. if err != nil {
  691. if !errors.Is(err, ErrChannelsExhausted) {
  692. log.Errorf("prepare retry failed: %+v", err)
  693. }
  694. // when the last request has not recorded the result, record the result
  695. if state.meta != nil && state.result != nil {
  696. recordResult(
  697. c,
  698. state.meta,
  699. state.price,
  700. state.result,
  701. i,
  702. true,
  703. middleware.GetRequestUser(c),
  704. middleware.GetRequestMetadata(c),
  705. )
  706. }
  707. break
  708. }
  709. // when the last request has not recorded the result, record the result
  710. if state.meta != nil && state.result != nil {
  711. recordResult(
  712. c,
  713. state.meta,
  714. state.price,
  715. state.result,
  716. i,
  717. false,
  718. middleware.GetRequestUser(c),
  719. middleware.GetRequestMetadata(c),
  720. )
  721. state.meta = nil
  722. state.result = nil
  723. }
  724. log.Data["retry"] = strconv.Itoa(i + 1)
  725. log.Warnf("using channel %s (type: %d, id: %d) to retry (remain times %d)",
  726. newChannel.Name,
  727. newChannel.Type,
  728. newChannel.ID,
  729. state.retryTimes-i,
  730. )
  731. // Check if we should delay (using the same channel)
  732. if shouldDelay(lastStatusCode, lastChannelID, newChannel.ID) {
  733. relayDelay()
  734. }
  735. state.meta = NewMetaByContext(
  736. c,
  737. newChannel,
  738. mode,
  739. meta.WithRequestUsage(state.requestUsage),
  740. meta.WithRetryAt(time.Now()),
  741. )
  742. var retry bool
  743. state.result, retry = RelayHelper(c, state.meta, relayController)
  744. done := handleRetryResult(c, retry, newChannel, state)
  745. if done || i == state.retryTimes-1 {
  746. recordResult(
  747. c,
  748. state.meta,
  749. state.price,
  750. state.result,
  751. i+1,
  752. true,
  753. middleware.GetRequestUser(c),
  754. middleware.GetRequestMetadata(c),
  755. )
  756. break
  757. }
  758. i++
  759. }
  760. if state.result.Error != nil {
  761. ErrorWithRequestID(c, state.result.Error)
  762. }
  763. }
  764. func getRetryChannel(state *retryState) (*model.Channel, error) {
  765. if state.exhausted {
  766. if state.lastHasPermissionChannel == nil {
  767. return nil, ErrChannelsExhausted
  768. }
  769. return state.lastHasPermissionChannel, nil
  770. }
  771. newChannel, err := getRandomChannel(
  772. state.migratedChannels,
  773. state.errorRates,
  774. state.ignoreChannelIDs...)
  775. if err != nil {
  776. if !errors.Is(err, ErrChannelsExhausted) || state.lastHasPermissionChannel == nil {
  777. return nil, err
  778. }
  779. state.exhausted = true
  780. return state.lastHasPermissionChannel, nil
  781. }
  782. return newChannel, nil
  783. }
  784. func prepareRetry(c *gin.Context) error {
  785. requestBody, err := common.GetRequestBody(c.Request)
  786. if err != nil {
  787. return fmt.Errorf("get request body failed in prepare retry: %w", err)
  788. }
  789. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  790. return nil
  791. }
  792. func handleRetryResult(
  793. ctx *gin.Context,
  794. retry bool,
  795. newChannel *model.Channel,
  796. state *retryState,
  797. ) (done bool) {
  798. if ctx.Request.Context().Err() != nil {
  799. return true
  800. }
  801. if !retry || state.result.Error == nil {
  802. return true
  803. }
  804. hasPermission := channelHasPermission(state.result.Error)
  805. if state.exhausted {
  806. if !hasPermission {
  807. return true
  808. }
  809. } else {
  810. if !hasPermission {
  811. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(newChannel.ID))
  812. state.retryTimes++
  813. } else {
  814. state.lastHasPermissionChannel = newChannel
  815. }
  816. }
  817. return false
  818. }
  819. var channelNoRetryStatusCodesMap = map[int]struct{}{
  820. http.StatusBadRequest: {},
  821. http.StatusRequestEntityTooLarge: {},
  822. http.StatusUnprocessableEntity: {},
  823. http.StatusUnavailableForLegalReasons: {},
  824. }
  825. // 仅当是channel错误时,才需要记录,用户请求参数错误时,不需要记录
  826. func shouldRetry(_ *gin.Context, relayErr adaptor.Error) bool {
  827. _, ok := channelNoRetryStatusCodesMap[relayErr.StatusCode()]
  828. return !ok
  829. }
  830. var channelNoPermissionStatusCodesMap = map[int]struct{}{
  831. http.StatusUnauthorized: {},
  832. http.StatusPaymentRequired: {},
  833. http.StatusForbidden: {},
  834. http.StatusNotFound: {},
  835. }
  836. func channelHasPermission(relayErr adaptor.Error) bool {
  837. _, ok := channelNoPermissionStatusCodesMap[relayErr.StatusCode()]
  838. return !ok
  839. }
  840. // shouldDelay checks if we need to add a delay before retrying
  841. // Only adds delay when retrying with the same channel for rate limiting issues
  842. func shouldDelay(statusCode, lastChannelID, newChannelID int) bool {
  843. if lastChannelID != newChannelID {
  844. return false
  845. }
  846. // Only delay for rate limiting or service unavailable errors
  847. return statusCode == http.StatusTooManyRequests ||
  848. statusCode == http.StatusServiceUnavailable
  849. }
  850. func relayDelay() {
  851. time.Sleep(time.Duration(rand.Float64()*float64(time.Second)) + time.Second)
  852. }
  853. func RelayNotImplemented(c *gin.Context) {
  854. ErrorWithRequestID(c,
  855. relaymodel.NewOpenAIError(http.StatusNotImplemented, relaymodel.OpenAIError{
  856. Message: "API not implemented",
  857. Type: relaymodel.ErrorTypeAIPROXY,
  858. Code: "api_not_implemented",
  859. }),
  860. )
  861. }
  862. func ErrorWithRequestID(c *gin.Context, relayErr adaptor.Error) {
  863. requestID := middleware.GetRequestID(c)
  864. if requestID == "" {
  865. c.JSON(relayErr.StatusCode(), relayErr)
  866. return
  867. }
  868. log := middleware.GetLogger(c)
  869. data, err := relayErr.MarshalJSON()
  870. if err != nil {
  871. log.Errorf("marshal error failed: %+v", err)
  872. c.JSON(relayErr.StatusCode(), relayErr)
  873. return
  874. }
  875. node, err := sonic.Get(data)
  876. if err != nil {
  877. log.Errorf("get node failed: %+v", err)
  878. c.JSON(relayErr.StatusCode(), relayErr)
  879. return
  880. }
  881. _, err = node.Set("aiproxy", ast.NewString(requestID))
  882. if err != nil {
  883. log.Errorf("set request id failed: %+v", err)
  884. c.JSON(relayErr.StatusCode(), relayErr)
  885. return
  886. }
  887. c.JSON(relayErr.StatusCode(), &node)
  888. }