relay-controller.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/rand/v2"
  9. "net/http"
  10. "strconv"
  11. "time"
  12. "github.com/bytedance/sonic"
  13. "github.com/bytedance/sonic/ast"
  14. "github.com/gin-gonic/gin"
  15. "github.com/labring/aiproxy/core/common"
  16. "github.com/labring/aiproxy/core/common/config"
  17. "github.com/labring/aiproxy/core/common/consume"
  18. "github.com/labring/aiproxy/core/common/conv"
  19. "github.com/labring/aiproxy/core/middleware"
  20. "github.com/labring/aiproxy/core/model"
  21. "github.com/labring/aiproxy/core/relay/adaptor"
  22. "github.com/labring/aiproxy/core/relay/adaptors"
  23. "github.com/labring/aiproxy/core/relay/controller"
  24. "github.com/labring/aiproxy/core/relay/meta"
  25. "github.com/labring/aiproxy/core/relay/mode"
  26. relaymodel "github.com/labring/aiproxy/core/relay/model"
  27. "github.com/labring/aiproxy/core/relay/plugin"
  28. "github.com/labring/aiproxy/core/relay/plugin/cache"
  29. monitorplugin "github.com/labring/aiproxy/core/relay/plugin/monitor"
  30. "github.com/labring/aiproxy/core/relay/plugin/patch"
  31. "github.com/labring/aiproxy/core/relay/plugin/streamfake"
  32. "github.com/labring/aiproxy/core/relay/plugin/thinksplit"
  33. "github.com/labring/aiproxy/core/relay/plugin/timeout"
  34. websearch "github.com/labring/aiproxy/core/relay/plugin/web-search"
  35. )
  36. // https://platform.openai.com/docs/api-reference/chat
  37. type (
  38. RelayHandler func(*gin.Context, *meta.Meta) *controller.HandleResult
  39. GetRequestUsage func(*gin.Context, model.ModelConfig) (model.Usage, error)
  40. GetRequestPrice func(*gin.Context, model.ModelConfig) (model.Price, error)
  41. )
  42. type RelayController struct {
  43. GetRequestUsage GetRequestUsage
  44. GetRequestPrice GetRequestPrice
  45. Handler RelayHandler
  46. }
  47. var adaptorStore adaptor.Store = &storeImpl{}
  48. type storeImpl struct{}
  49. func (s *storeImpl) GetStore(group string, tokenID int, id string) (adaptor.StoreCache, error) {
  50. store, err := model.CacheGetStore(group, tokenID, id)
  51. if err != nil {
  52. return adaptor.StoreCache{}, err
  53. }
  54. return adaptor.StoreCache{
  55. ID: store.ID,
  56. GroupID: store.GroupID,
  57. TokenID: store.TokenID,
  58. ChannelID: store.ChannelID,
  59. Model: store.Model,
  60. ExpiresAt: store.ExpiresAt,
  61. }, nil
  62. }
  63. func (s *storeImpl) SaveStore(store adaptor.StoreCache) error {
  64. _, err := model.SaveStore(&model.StoreV2{
  65. ID: store.ID,
  66. GroupID: store.GroupID,
  67. TokenID: store.TokenID,
  68. ChannelID: store.ChannelID,
  69. Model: store.Model,
  70. ExpiresAt: store.ExpiresAt,
  71. })
  72. return err
  73. }
  74. func wrapPlugin(ctx context.Context, mc *model.ModelCaches, a adaptor.Adaptor) adaptor.Adaptor {
  75. return plugin.WrapperAdaptor(a,
  76. monitorplugin.NewGroupMonitorPlugin(),
  77. cache.NewCachePlugin(common.RDB),
  78. streamfake.NewStreamFakePlugin(),
  79. timeout.NewTimeoutPlugin(),
  80. websearch.NewWebSearchPlugin(func(modelName string) (*model.Channel, error) {
  81. return getWebSearchChannel(ctx, mc, modelName)
  82. }),
  83. thinksplit.NewThinkPlugin(),
  84. monitorplugin.NewChannelMonitorPlugin(),
  85. patch.NewPatchPlugin(),
  86. )
  87. }
  88. func relayHandler(c *gin.Context, meta *meta.Meta, mc *model.ModelCaches) *controller.HandleResult {
  89. log := common.GetLogger(c)
  90. middleware.SetLogFieldsFromMeta(meta, log.Data)
  91. adaptor, ok := adaptors.GetAdaptor(meta.Channel.Type)
  92. if !ok {
  93. return &controller.HandleResult{
  94. Error: relaymodel.WrapperOpenAIErrorWithMessage(
  95. fmt.Sprintf("invalid channel type: %d", meta.Channel.Type),
  96. "invalid_channel_type",
  97. http.StatusInternalServerError,
  98. ),
  99. }
  100. }
  101. adaptor = wrapPlugin(c.Request.Context(), mc, adaptor)
  102. return controller.Handle(adaptor, c, meta, adaptorStore)
  103. }
  104. func defaultPriceFunc(_ *gin.Context, mc model.ModelConfig) (model.Price, error) {
  105. return mc.Price, nil
  106. }
  107. func relayController(m mode.Mode) RelayController {
  108. c := RelayController{
  109. Handler: func(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
  110. return relayHandler(c, meta, middleware.GetModelCaches(c))
  111. },
  112. GetRequestPrice: defaultPriceFunc,
  113. }
  114. switch m {
  115. case mode.ImagesGenerations:
  116. c.GetRequestPrice = controller.GetImagesRequestPrice
  117. c.GetRequestUsage = controller.GetImagesRequestUsage
  118. case mode.ImagesEdits:
  119. c.GetRequestPrice = controller.GetImagesEditsRequestPrice
  120. c.GetRequestUsage = controller.GetImagesEditsRequestUsage
  121. case mode.AudioSpeech:
  122. c.GetRequestUsage = controller.GetTTSRequestUsage
  123. case mode.AudioTranslation, mode.AudioTranscription:
  124. c.GetRequestUsage = controller.GetSTTRequestUsage
  125. case mode.ParsePdf:
  126. c.GetRequestUsage = controller.GetPdfRequestUsage
  127. case mode.Rerank:
  128. c.GetRequestUsage = controller.GetRerankRequestUsage
  129. case mode.Anthropic:
  130. c.GetRequestUsage = controller.GetAnthropicRequestUsage
  131. case mode.ChatCompletions:
  132. c.GetRequestUsage = controller.GetChatRequestUsage
  133. case mode.Gemini:
  134. c.GetRequestUsage = controller.GetGeminiRequestUsage
  135. case mode.Embeddings:
  136. c.GetRequestUsage = controller.GetEmbedRequestUsage
  137. case mode.Completions:
  138. c.GetRequestUsage = controller.GetCompletionsRequestUsage
  139. case mode.VideoGenerationsJobs:
  140. c.GetRequestUsage = controller.GetVideoGenerationJobRequestUsage
  141. case mode.Responses:
  142. c.GetRequestUsage = controller.GetResponsesRequestUsage
  143. }
  144. return c
  145. }
  146. func RelayHelper(
  147. c *gin.Context,
  148. meta *meta.Meta,
  149. handel RelayHandler,
  150. ) (*controller.HandleResult, bool) {
  151. result := handel(c, meta)
  152. if result.Error == nil {
  153. return result, false
  154. }
  155. return result, monitorplugin.ShouldRetry(result.Error)
  156. }
  157. func NewRelay(mode mode.Mode) func(c *gin.Context) {
  158. relayController := relayController(mode)
  159. return func(c *gin.Context) {
  160. relay(c, mode, relayController)
  161. }
  162. }
  163. func NewMetaByContext(
  164. c *gin.Context,
  165. channel *model.Channel,
  166. mode mode.Mode,
  167. opts ...meta.Option,
  168. ) *meta.Meta {
  169. return middleware.NewMetaByContext(c, channel, mode, opts...)
  170. }
  171. func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
  172. requestModel := middleware.GetRequestModel(c)
  173. mc := middleware.GetModelConfig(c)
  174. // Get initial channel
  175. initialChannel, err := getInitialChannel(c, requestModel, mode)
  176. if err != nil || initialChannel == nil || initialChannel.channel == nil {
  177. middleware.AbortLogWithMessageWithMode(mode, c,
  178. http.StatusServiceUnavailable,
  179. "the upstream load is saturated, please try again later",
  180. )
  181. return
  182. }
  183. price := model.Price{}
  184. if relayController.GetRequestPrice != nil {
  185. price, err = relayController.GetRequestPrice(c, mc)
  186. if err != nil {
  187. middleware.AbortLogWithMessageWithMode(mode, c,
  188. http.StatusInternalServerError,
  189. "get request price failed: "+err.Error(),
  190. )
  191. return
  192. }
  193. }
  194. meta := NewMetaByContext(c, initialChannel.channel, mode)
  195. if relayController.GetRequestUsage != nil {
  196. requestUsage, err := relayController.GetRequestUsage(c, mc)
  197. if err != nil {
  198. middleware.AbortLogWithMessageWithMode(mode, c,
  199. http.StatusInternalServerError,
  200. "get request usage failed: "+err.Error(),
  201. )
  202. return
  203. }
  204. meta.RequestUsage = requestUsage
  205. }
  206. gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  207. if !gbc.CheckBalance(consume.CalculateAmount(http.StatusOK, meta.RequestUsage, price)) {
  208. middleware.AbortLogWithMessageWithMode(mode, c,
  209. http.StatusForbidden,
  210. fmt.Sprintf("group (%s) balance not enough", gbc.Group),
  211. relaymodel.WithType(middleware.GroupBalanceNotEnough),
  212. )
  213. return
  214. }
  215. // First attempt
  216. result, retry := RelayHelper(c, meta, relayController.Handler)
  217. retryTimes := int(config.GetRetryTimes())
  218. if mc.RetryTimes > 0 {
  219. retryTimes = int(mc.RetryTimes)
  220. }
  221. if handleRelayResult(c, result.Error, retry, retryTimes) {
  222. recordResult(
  223. c,
  224. meta,
  225. price,
  226. result,
  227. 0,
  228. true,
  229. middleware.GetRequestUser(c),
  230. middleware.GetRequestMetadata(c),
  231. )
  232. return
  233. }
  234. // Setup retry state
  235. retryState := initRetryState(
  236. retryTimes,
  237. initialChannel,
  238. meta,
  239. result,
  240. price,
  241. )
  242. // Retry loop
  243. retryLoop(c, mode, retryState, relayController.Handler)
  244. }
  245. // recordResult records the consumption for the final result
  246. func recordResult(
  247. c *gin.Context,
  248. meta *meta.Meta,
  249. price model.Price,
  250. result *controller.HandleResult,
  251. retryTimes int,
  252. downstreamResult bool,
  253. user string,
  254. metadata map[string]string,
  255. ) {
  256. code := http.StatusOK
  257. content := ""
  258. if result.Error != nil {
  259. code = result.Error.StatusCode()
  260. respBody, _ := result.Error.MarshalJSON()
  261. content = conv.BytesToString(respBody)
  262. }
  263. var detail *model.RequestDetail
  264. firstByteAt := result.Detail.FirstByteAt
  265. if config.GetSaveAllLogDetail() || meta.ModelConfig.ForceSaveDetail || code != http.StatusOK {
  266. detail = &model.RequestDetail{
  267. RequestBody: result.Detail.RequestBody,
  268. ResponseBody: result.Detail.ResponseBody,
  269. }
  270. }
  271. gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  272. amount := consume.CalculateAmount(
  273. code,
  274. result.Usage,
  275. price,
  276. )
  277. if amount > 0 {
  278. log := common.GetLogger(c)
  279. log.Data["amount"] = strconv.FormatFloat(amount, 'f', -1, 64)
  280. }
  281. consume.AsyncConsume(
  282. gbc.Consumer,
  283. code,
  284. firstByteAt,
  285. meta,
  286. result.Usage,
  287. price,
  288. content,
  289. c.ClientIP(),
  290. retryTimes,
  291. detail,
  292. downstreamResult,
  293. user,
  294. metadata,
  295. )
  296. }
  297. type retryState struct {
  298. retryTimes int
  299. lastHasPermissionChannel *model.Channel
  300. ignoreChannelIDs map[int64]struct{}
  301. errorRates map[int64]float64
  302. exhausted bool
  303. failedChannelIDs map[int64]struct{} // Track all failed channels in this request
  304. meta *meta.Meta
  305. price model.Price
  306. requestUsage model.Usage
  307. result *controller.HandleResult
  308. migratedChannels []*model.Channel
  309. }
  310. func handleRelayResult(
  311. c *gin.Context,
  312. bizErr adaptor.Error,
  313. retry bool,
  314. retryTimes int,
  315. ) (done bool) {
  316. if bizErr == nil {
  317. return true
  318. }
  319. if !retry ||
  320. retryTimes == 0 ||
  321. c.Request.Context().Err() != nil {
  322. ErrorWithRequestID(c, bizErr)
  323. return true
  324. }
  325. return false
  326. }
  327. func initRetryState(
  328. retryTimes int,
  329. channel *initialChannel,
  330. meta *meta.Meta,
  331. result *controller.HandleResult,
  332. price model.Price,
  333. ) *retryState {
  334. state := &retryState{
  335. retryTimes: retryTimes,
  336. ignoreChannelIDs: channel.ignoreChannelIDs,
  337. errorRates: channel.errorRates,
  338. meta: meta,
  339. result: result,
  340. price: price,
  341. requestUsage: meta.RequestUsage,
  342. migratedChannels: channel.migratedChannels,
  343. failedChannelIDs: make(map[int64]struct{}),
  344. }
  345. // Record initial failed channel
  346. state.failedChannelIDs[int64(meta.Channel.ID)] = struct{}{}
  347. if channel.designatedChannel {
  348. state.exhausted = true
  349. }
  350. if !monitorplugin.ChannelHasPermission(result.Error) {
  351. if state.ignoreChannelIDs == nil {
  352. state.ignoreChannelIDs = make(map[int64]struct{})
  353. }
  354. state.ignoreChannelIDs[int64(channel.channel.ID)] = struct{}{}
  355. } else {
  356. state.lastHasPermissionChannel = channel.channel
  357. }
  358. return state
  359. }
  360. func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayController RelayHandler) {
  361. log := common.GetLogger(c)
  362. // do not use for i := range state.retryTimes, because the retryTimes is constant
  363. i := 0
  364. for {
  365. lastStatusCode := state.result.Error.StatusCode()
  366. lastChannelID := state.meta.Channel.ID
  367. newChannel, err := getRetryChannel(state, i, state.retryTimes)
  368. if err == nil {
  369. err = prepareRetry(c)
  370. }
  371. if err != nil {
  372. if !errors.Is(err, ErrChannelsExhausted) {
  373. log.Errorf("prepare retry failed: %+v", err)
  374. }
  375. // when the last request has not recorded the result, record the result
  376. if state.meta != nil && state.result != nil {
  377. recordResult(
  378. c,
  379. state.meta,
  380. state.price,
  381. state.result,
  382. i,
  383. true,
  384. middleware.GetRequestUser(c),
  385. middleware.GetRequestMetadata(c),
  386. )
  387. }
  388. break
  389. }
  390. // when the last request has not recorded the result, record the result
  391. if state.meta != nil && state.result != nil {
  392. recordResult(
  393. c,
  394. state.meta,
  395. state.price,
  396. state.result,
  397. i,
  398. false,
  399. middleware.GetRequestUser(c),
  400. middleware.GetRequestMetadata(c),
  401. )
  402. state.meta = nil
  403. state.result = nil
  404. }
  405. log.Data["retry"] = strconv.Itoa(i + 1)
  406. log.Warnf("using channel %s (type: %d, id: %d) to retry (remain times %d)",
  407. newChannel.Name,
  408. newChannel.Type,
  409. newChannel.ID,
  410. state.retryTimes-i,
  411. )
  412. // Check if we should delay (using the same channel)
  413. if shouldDelay(lastStatusCode, lastChannelID, newChannel.ID) {
  414. relayDelay()
  415. }
  416. state.meta = NewMetaByContext(
  417. c,
  418. newChannel,
  419. mode,
  420. meta.WithRequestUsage(state.requestUsage),
  421. meta.WithRetryAt(time.Now()),
  422. )
  423. var retry bool
  424. state.result, retry = RelayHelper(c, state.meta, relayController)
  425. done := handleRetryResult(c, retry, newChannel, state)
  426. // Record failed channel if retry is needed
  427. if !done && state.result.Error != nil {
  428. state.failedChannelIDs[int64(newChannel.ID)] = struct{}{}
  429. }
  430. if done || i == state.retryTimes-1 {
  431. recordResult(
  432. c,
  433. state.meta,
  434. state.price,
  435. state.result,
  436. i+1,
  437. true,
  438. middleware.GetRequestUser(c),
  439. middleware.GetRequestMetadata(c),
  440. )
  441. break
  442. }
  443. i++
  444. }
  445. if state.result.Error != nil {
  446. ErrorWithRequestID(c, state.result.Error)
  447. }
  448. }
  449. func prepareRetry(c *gin.Context) error {
  450. requestBody, err := common.GetRequestBodyReusable(c.Request)
  451. if err != nil {
  452. return fmt.Errorf("get request body failed in prepare retry: %w", err)
  453. }
  454. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  455. return nil
  456. }
  457. func handleRetryResult(
  458. ctx *gin.Context,
  459. retry bool,
  460. newChannel *model.Channel,
  461. state *retryState,
  462. ) (done bool) {
  463. if ctx.Request.Context().Err() != nil {
  464. return true
  465. }
  466. if !retry || state.result.Error == nil {
  467. return true
  468. }
  469. hasPermission := monitorplugin.ChannelHasPermission(state.result.Error)
  470. if state.exhausted {
  471. if !hasPermission {
  472. return true
  473. }
  474. } else {
  475. if !hasPermission {
  476. if state.ignoreChannelIDs == nil {
  477. state.ignoreChannelIDs = make(map[int64]struct{})
  478. }
  479. state.ignoreChannelIDs[int64(newChannel.ID)] = struct{}{}
  480. state.retryTimes++
  481. } else {
  482. state.lastHasPermissionChannel = newChannel
  483. }
  484. }
  485. return false
  486. }
  487. // shouldDelay checks if we need to add a delay before retrying
  488. // Only adds delay when retrying with the same channel for rate limiting issues
  489. func shouldDelay(statusCode, lastChannelID, newChannelID int) bool {
  490. if lastChannelID != newChannelID {
  491. return false
  492. }
  493. // Only delay for rate limiting or service unavailable errors
  494. return statusCode == http.StatusTooManyRequests ||
  495. statusCode == http.StatusServiceUnavailable
  496. }
  497. func relayDelay() {
  498. time.Sleep(time.Duration(rand.Float64()*float64(time.Second)) + time.Second)
  499. }
  500. func RelayNotImplemented(c *gin.Context) {
  501. ErrorWithRequestID(c,
  502. relaymodel.NewOpenAIError(http.StatusNotImplemented, relaymodel.OpenAIError{
  503. Message: "API not implemented",
  504. Type: relaymodel.ErrorTypeAIPROXY,
  505. Code: "api_not_implemented",
  506. }),
  507. )
  508. }
  509. func ErrorWithRequestID(c *gin.Context, relayErr adaptor.Error) {
  510. requestID := middleware.GetRequestID(c)
  511. if requestID == "" {
  512. c.JSON(relayErr.StatusCode(), relayErr)
  513. return
  514. }
  515. log := common.GetLogger(c)
  516. data, err := relayErr.MarshalJSON()
  517. if err != nil {
  518. log.Errorf("marshal error failed: %+v", err)
  519. c.JSON(relayErr.StatusCode(), relayErr)
  520. return
  521. }
  522. node, err := sonic.Get(data)
  523. if err != nil {
  524. log.Errorf("get node failed: %+v", err)
  525. c.JSON(relayErr.StatusCode(), relayErr)
  526. return
  527. }
  528. _, err = node.Set("aiproxy", ast.NewString(requestID))
  529. if err != nil {
  530. log.Errorf("set request id failed: %+v", err)
  531. c.JSON(relayErr.StatusCode(), relayErr)
  532. return
  533. }
  534. c.JSON(relayErr.StatusCode(), &node)
  535. }