relay-controller.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/rand/v2"
  9. "net/http"
  10. "strconv"
  11. "time"
  12. "github.com/bytedance/sonic"
  13. "github.com/bytedance/sonic/ast"
  14. "github.com/gin-gonic/gin"
  15. "github.com/labring/aiproxy/core/common"
  16. "github.com/labring/aiproxy/core/common/config"
  17. "github.com/labring/aiproxy/core/common/consume"
  18. "github.com/labring/aiproxy/core/common/conv"
  19. "github.com/labring/aiproxy/core/middleware"
  20. "github.com/labring/aiproxy/core/model"
  21. "github.com/labring/aiproxy/core/relay/adaptor"
  22. "github.com/labring/aiproxy/core/relay/adaptors"
  23. "github.com/labring/aiproxy/core/relay/controller"
  24. "github.com/labring/aiproxy/core/relay/meta"
  25. "github.com/labring/aiproxy/core/relay/mode"
  26. relaymodel "github.com/labring/aiproxy/core/relay/model"
  27. "github.com/labring/aiproxy/core/relay/plugin"
  28. "github.com/labring/aiproxy/core/relay/plugin/cache"
  29. monitorplugin "github.com/labring/aiproxy/core/relay/plugin/monitor"
  30. "github.com/labring/aiproxy/core/relay/plugin/patch"
  31. "github.com/labring/aiproxy/core/relay/plugin/streamfake"
  32. "github.com/labring/aiproxy/core/relay/plugin/thinksplit"
  33. "github.com/labring/aiproxy/core/relay/plugin/timeout"
  34. websearch "github.com/labring/aiproxy/core/relay/plugin/web-search"
  35. )
// Upstream API reference: https://platform.openai.com/docs/api-reference/chat
type (
	// RelayHandler performs a single relay attempt for the request held in the
	// gin context, using the channel described by the meta, and returns the
	// attempt's outcome (error, usage and request/response detail).
	RelayHandler func(*gin.Context, *meta.Meta) *controller.HandleResult
	// GetRequestUsage estimates the usage of an incoming request before it is
	// relayed. relay stores the estimate on meta.RequestUsage and uses it for
	// the pre-flight group balance check.
	GetRequestUsage func(*gin.Context, model.ModelConfig) (model.Usage, error)
	// GetRequestPrice resolves the price the request is billed at for the
	// given model config.
	GetRequestPrice func(*gin.Context, model.ModelConfig) (model.Price, error)
)
// RelayController bundles the per-mode hooks used by relay: an optional
// request-usage estimator, the price resolver, and the handler that performs
// each relay attempt. relay skips GetRequestUsage and GetRequestPrice when
// they are nil.
type RelayController struct {
	// GetRequestUsage, when non-nil, estimates request usage before relaying.
	GetRequestUsage GetRequestUsage
	// GetRequestPrice, when non-nil, resolves the price used for the balance
	// check and for billing.
	GetRequestPrice GetRequestPrice
	// Handler performs one relay attempt; it is invoked again for every retry.
	Handler         RelayHandler
}
// adaptorStore is the adaptor.Store that relayHandler passes to
// controller.Handle; it is backed by the model package's store persistence.
var adaptorStore adaptor.Store = &storeImpl{}

// storeImpl implements adaptor.Store by translating between adaptor.StoreCache
// and the model package's store records (model.CacheGetStore / model.SaveStore).
type storeImpl struct{}
  49. func (s *storeImpl) GetStore(group string, tokenID int, id string) (adaptor.StoreCache, error) {
  50. store, err := model.CacheGetStore(group, tokenID, id)
  51. if err != nil {
  52. return adaptor.StoreCache{}, err
  53. }
  54. return adaptor.StoreCache{
  55. ID: store.ID,
  56. GroupID: store.GroupID,
  57. TokenID: store.TokenID,
  58. ChannelID: store.ChannelID,
  59. Model: store.Model,
  60. ExpiresAt: store.ExpiresAt,
  61. }, nil
  62. }
  63. func (s *storeImpl) SaveStore(store adaptor.StoreCache) error {
  64. _, err := model.SaveStore(&model.StoreV2{
  65. ID: store.ID,
  66. GroupID: store.GroupID,
  67. TokenID: store.TokenID,
  68. ChannelID: store.ChannelID,
  69. Model: store.Model,
  70. ExpiresAt: store.ExpiresAt,
  71. })
  72. return err
  73. }
  74. func wrapPlugin(ctx context.Context, mc *model.ModelCaches, a adaptor.Adaptor) adaptor.Adaptor {
  75. return plugin.WrapperAdaptor(a,
  76. monitorplugin.NewGroupMonitorPlugin(),
  77. cache.NewCachePlugin(common.RDB),
  78. streamfake.NewStreamFakePlugin(),
  79. patch.NewPatchPlugin(),
  80. timeout.NewTimeoutPlugin(),
  81. websearch.NewWebSearchPlugin(func(modelName string) (*model.Channel, error) {
  82. return getWebSearchChannel(ctx, mc, modelName)
  83. }),
  84. thinksplit.NewThinkPlugin(),
  85. monitorplugin.NewChannelMonitorPlugin(),
  86. )
  87. }
  88. func relayHandler(c *gin.Context, meta *meta.Meta, mc *model.ModelCaches) *controller.HandleResult {
  89. log := common.GetLogger(c)
  90. middleware.SetLogFieldsFromMeta(meta, log.Data)
  91. adaptor, ok := adaptors.GetAdaptor(meta.Channel.Type)
  92. if !ok {
  93. return &controller.HandleResult{
  94. Error: relaymodel.WrapperOpenAIErrorWithMessage(
  95. fmt.Sprintf("invalid channel type: %d", meta.Channel.Type),
  96. "invalid_channel_type",
  97. http.StatusInternalServerError,
  98. ),
  99. }
  100. }
  101. adaptor = wrapPlugin(c.Request.Context(), mc, adaptor)
  102. return controller.Handle(adaptor, c, meta, adaptorStore)
  103. }
  104. func defaultPriceFunc(_ *gin.Context, mc model.ModelConfig) (model.Price, error) {
  105. return mc.Price, nil
  106. }
  107. func relayController(m mode.Mode) RelayController {
  108. c := RelayController{
  109. Handler: func(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
  110. return relayHandler(c, meta, middleware.GetModelCaches(c))
  111. },
  112. GetRequestPrice: defaultPriceFunc,
  113. }
  114. switch m {
  115. case mode.ImagesGenerations:
  116. c.GetRequestPrice = controller.GetImagesRequestPrice
  117. c.GetRequestUsage = controller.GetImagesRequestUsage
  118. case mode.ImagesEdits:
  119. c.GetRequestPrice = controller.GetImagesEditsRequestPrice
  120. c.GetRequestUsage = controller.GetImagesEditsRequestUsage
  121. case mode.AudioSpeech:
  122. c.GetRequestUsage = controller.GetTTSRequestUsage
  123. case mode.AudioTranslation, mode.AudioTranscription:
  124. c.GetRequestUsage = controller.GetSTTRequestUsage
  125. case mode.ParsePdf:
  126. c.GetRequestUsage = controller.GetPdfRequestUsage
  127. case mode.Rerank:
  128. c.GetRequestUsage = controller.GetRerankRequestUsage
  129. case mode.Anthropic:
  130. c.GetRequestUsage = controller.GetAnthropicRequestUsage
  131. case mode.ChatCompletions:
  132. c.GetRequestUsage = controller.GetChatRequestUsage
  133. case mode.Embeddings:
  134. c.GetRequestUsage = controller.GetEmbedRequestUsage
  135. case mode.Completions:
  136. c.GetRequestUsage = controller.GetCompletionsRequestUsage
  137. case mode.VideoGenerationsJobs:
  138. c.GetRequestUsage = controller.GetVideoGenerationJobRequestUsage
  139. case mode.Responses:
  140. c.GetRequestUsage = controller.GetResponsesRequestUsage
  141. }
  142. return c
  143. }
  144. func RelayHelper(
  145. c *gin.Context,
  146. meta *meta.Meta,
  147. handel RelayHandler,
  148. ) (*controller.HandleResult, bool) {
  149. result := handel(c, meta)
  150. if result.Error == nil {
  151. return result, false
  152. }
  153. return result, monitorplugin.ShouldRetry(result.Error)
  154. }
  155. func NewRelay(mode mode.Mode) func(c *gin.Context) {
  156. relayController := relayController(mode)
  157. return func(c *gin.Context) {
  158. relay(c, mode, relayController)
  159. }
  160. }
  161. func NewMetaByContext(
  162. c *gin.Context,
  163. channel *model.Channel,
  164. mode mode.Mode,
  165. opts ...meta.Option,
  166. ) *meta.Meta {
  167. return middleware.NewMetaByContext(c, channel, mode, opts...)
  168. }
// relay is the shared request pipeline behind every handler returned by
// NewRelay. It:
//  1. selects an initial channel for the requested model (503 if selection
//     fails or yields no channel),
//  2. resolves the request price and, when the controller provides an
//     estimator, the request usage (500 if either fails),
//  3. rejects the request with 403 when the group balance cannot cover the
//     estimated cost,
//  4. performs the first relay attempt, and
//  5. either records that attempt as final or hands off to retryLoop.
//
// Every abort path writes its response via middleware.AbortLogWithMessageWithMode.
func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
	requestModel := middleware.GetRequestModel(c)
	mc := middleware.GetModelConfig(c)
	// Get initial channel
	initialChannel, err := getInitialChannel(c, requestModel, mode)
	if err != nil || initialChannel == nil || initialChannel.channel == nil {
		// NOTE(review): err itself is not passed to the abort call; every
		// selection failure is reported to the client as upstream saturation.
		middleware.AbortLogWithMessageWithMode(mode, c,
			http.StatusServiceUnavailable,
			"the upstream load is saturated, please try again later",
		)
		return
	}

	// Zero price unless the controller supplies a resolver.
	price := model.Price{}
	if relayController.GetRequestPrice != nil {
		price, err = relayController.GetRequestPrice(c, mc)
		if err != nil {
			middleware.AbortLogWithMessageWithMode(mode, c,
				http.StatusInternalServerError,
				"get request price failed: "+err.Error(),
			)
			return
		}
	}

	// From here on, `meta` names this request's meta and shadows the meta package.
	meta := NewMetaByContext(c, initialChannel.channel, mode)
	if relayController.GetRequestUsage != nil {
		requestUsage, err := relayController.GetRequestUsage(c, mc)
		if err != nil {
			middleware.AbortLogWithMessageWithMode(mode, c,
				http.StatusInternalServerError,
				"get request usage failed: "+err.Error(),
			)
			return
		}
		meta.RequestUsage = requestUsage
	}

	// Pre-flight balance check: price the estimated usage as if the request
	// succeeds (status 200) and make sure the group can afford it.
	gbc := middleware.GetGroupBalanceConsumerFromContext(c)
	if !gbc.CheckBalance(consume.CalculateAmount(http.StatusOK, meta.RequestUsage, price)) {
		middleware.AbortLogWithMessageWithMode(mode, c,
			http.StatusForbidden,
			fmt.Sprintf("group (%s) balance not enough", gbc.Group),
			relaymodel.WithType(middleware.GroupBalanceNotEnough),
		)
		return
	}

	// First attempt
	result, retry := RelayHelper(c, meta, relayController.Handler)

	// A positive per-model RetryTimes overrides the global retry count.
	retryTimes := int(config.GetRetryTimes())
	if mc.RetryTimes > 0 {
		retryTimes = int(mc.RetryTimes)
	}

	// The first attempt is final when it succeeded or cannot be retried (the
	// error response, if any, has already been written by handleRelayResult).
	// Record it as the attempt the client sees: retry index 0, downstream=true.
	if handleRelayResult(c, result.Error, retry, retryTimes) {
		recordResult(
			c,
			meta,
			price,
			result,
			0,
			true,
			middleware.GetRequestUser(c),
			middleware.GetRequestMetadata(c),
		)
		return
	}

	// Setup retry state
	retryState := initRetryState(
		retryTimes,
		initialChannel,
		meta,
		result,
		price,
	)
	// Retry loop; it records the first attempt's result itself.
	retryLoop(c, mode, retryState, relayController.Handler)
}
  243. // recordResult records the consumption for the final result
  244. func recordResult(
  245. c *gin.Context,
  246. meta *meta.Meta,
  247. price model.Price,
  248. result *controller.HandleResult,
  249. retryTimes int,
  250. downstreamResult bool,
  251. user string,
  252. metadata map[string]string,
  253. ) {
  254. code := http.StatusOK
  255. content := ""
  256. if result.Error != nil {
  257. code = result.Error.StatusCode()
  258. respBody, _ := result.Error.MarshalJSON()
  259. content = conv.BytesToString(respBody)
  260. }
  261. var detail *model.RequestDetail
  262. firstByteAt := result.Detail.FirstByteAt
  263. if config.GetSaveAllLogDetail() || meta.ModelConfig.ForceSaveDetail || code != http.StatusOK {
  264. detail = &model.RequestDetail{
  265. RequestBody: result.Detail.RequestBody,
  266. ResponseBody: result.Detail.ResponseBody,
  267. }
  268. }
  269. gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  270. amount := consume.CalculateAmount(
  271. code,
  272. result.Usage,
  273. price,
  274. )
  275. if amount > 0 {
  276. log := common.GetLogger(c)
  277. log.Data["amount"] = strconv.FormatFloat(amount, 'f', -1, 64)
  278. }
  279. consume.AsyncConsume(
  280. gbc.Consumer,
  281. code,
  282. firstByteAt,
  283. meta,
  284. result.Usage,
  285. price,
  286. content,
  287. c.ClientIP(),
  288. retryTimes,
  289. detail,
  290. downstreamResult,
  291. user,
  292. metadata,
  293. )
  294. }
// retryState is the bookkeeping shared by retryLoop and its helpers across
// the retry attempts of a single request. It is created by initRetryState.
type retryState struct {
	// retryTimes is the retry budget. handleRetryResult raises it by one when
	// a channel reports no permission, so that attempt does not consume budget.
	retryTimes               int
	// lastHasPermissionChannel is the most recent channel whose failure still
	// indicated permission, per monitorplugin.ChannelHasPermission.
	lastHasPermissionChannel *model.Channel
	// ignoreChannelIDs holds channels that reported no permission. It starts
	// as the initial selection's map (it may be nil) and is lazily allocated.
	ignoreChannelIDs         map[int64]struct{}
	// errorRates is taken from the initial channel selection.
	errorRates               map[int64]float64
	// exhausted is set by initRetryState when the initial channel was
	// designated. While it is set, a no-permission error ends retrying
	// instead of excluding the channel.
	exhausted                bool
	failedChannelIDs         map[int64]struct{} // Track all failed channels in this request
	// meta and result describe the latest attempt. retryLoop clears both to
	// nil once that attempt has been recorded.
	meta                     *meta.Meta
	// price is the resolved request price, reused for every attempt's record.
	price                    model.Price
	// requestUsage is the pre-computed request usage applied to each retry's meta.
	requestUsage             model.Usage
	result                   *controller.HandleResult
	// migratedChannels is taken from the initial channel selection.
	migratedChannels         []*model.Channel
}
  308. func handleRelayResult(
  309. c *gin.Context,
  310. bizErr adaptor.Error,
  311. retry bool,
  312. retryTimes int,
  313. ) (done bool) {
  314. if bizErr == nil {
  315. return true
  316. }
  317. if !retry ||
  318. retryTimes == 0 ||
  319. c.Request.Context().Err() != nil {
  320. ErrorWithRequestID(c, bizErr)
  321. return true
  322. }
  323. return false
  324. }
  325. func initRetryState(
  326. retryTimes int,
  327. channel *initialChannel,
  328. meta *meta.Meta,
  329. result *controller.HandleResult,
  330. price model.Price,
  331. ) *retryState {
  332. state := &retryState{
  333. retryTimes: retryTimes,
  334. ignoreChannelIDs: channel.ignoreChannelIDs,
  335. errorRates: channel.errorRates,
  336. meta: meta,
  337. result: result,
  338. price: price,
  339. requestUsage: meta.RequestUsage,
  340. migratedChannels: channel.migratedChannels,
  341. failedChannelIDs: make(map[int64]struct{}),
  342. }
  343. // Record initial failed channel
  344. state.failedChannelIDs[int64(meta.Channel.ID)] = struct{}{}
  345. if channel.designatedChannel {
  346. state.exhausted = true
  347. }
  348. if !monitorplugin.ChannelHasPermission(result.Error) {
  349. if state.ignoreChannelIDs == nil {
  350. state.ignoreChannelIDs = make(map[int64]struct{})
  351. }
  352. state.ignoreChannelIDs[int64(channel.channel.ID)] = struct{}{}
  353. } else {
  354. state.lastHasPermissionChannel = channel.channel
  355. }
  356. return state
  357. }
// retryLoop re-issues a request whose first attempt failed with a retryable
// error. Each iteration picks a channel via getRetryChannel. The loop stops
// when any of these happens:
//   - an attempt succeeds or is not retryable,
//   - the request context is done,
//   - no channel can be obtained,
//   - the retry budget runs out (state.retryTimes; handleRetryResult may raise it).
//
// Every attempt is recorded through recordResult exactly once. Superseded
// attempts use downstreamResult=false; the attempt whose outcome reaches the
// client uses downstreamResult=true. If that final result carries an error,
// it is written to the client via ErrorWithRequestID.
//
// relayController is the per-attempt RelayHandler, despite its name.
func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayController RelayHandler) {
	log := common.GetLogger(c)
	// do not use for i := range state.retryTimes, because the retryTimes is constant
	// (a range bound is evaluated once, but handleRetryResult may increase
	// state.retryTimes while the loop runs; the check is done manually below).
	i := 0
	for {
		// Capture the previous attempt's status and channel before state.meta
		// and state.result are cleared; shouldDelay uses them below.
		lastStatusCode := state.result.Error.StatusCode()
		lastChannelID := state.meta.Channel.ID
		newChannel, err := getRetryChannel(state, i, state.retryTimes)
		if err == nil {
			err = prepareRetry(c)
		}
		if err != nil {
			// ErrChannelsExhausted is an expected way to stop; anything else is
			// logged. NOTE(review): the message says "prepare retry" even when
			// the error came from getRetryChannel.
			if !errors.Is(err, ErrChannelsExhausted) {
				log.Errorf("prepare retry failed: %+v", err)
			}
			// when the last request has not recorded the result, record the result
			// (it is the response the client will get, hence downstreamResult=true)
			if state.meta != nil && state.result != nil {
				recordResult(
					c,
					state.meta,
					state.price,
					state.result,
					i,
					true,
					middleware.GetRequestUser(c),
					middleware.GetRequestMetadata(c),
				)
			}
			break
		}
		// when the last request has not recorded the result, record the result
		// as a superseded attempt, then clear it so it is never recorded twice.
		if state.meta != nil && state.result != nil {
			recordResult(
				c,
				state.meta,
				state.price,
				state.result,
				i,
				false,
				middleware.GetRequestUser(c),
				middleware.GetRequestMetadata(c),
			)
			state.meta = nil
			state.result = nil
		}
		log.Data["retry"] = strconv.Itoa(i + 1)
		log.Warnf("using channel %s (type: %d, id: %d) to retry (remain times %d)",
			newChannel.Name,
			newChannel.Type,
			newChannel.ID,
			state.retryTimes-i,
		)
		// Check if we should delay (using the same channel)
		// NOTE(review): relayDelay sleeps without observing request
		// cancellation; a disconnected client still waits 1-2s and one more
		// upstream attempt is made before handleRetryResult notices.
		if shouldDelay(lastStatusCode, lastChannelID, newChannel.ID) {
			relayDelay()
		}
		// `meta` here is the meta package (not shadowed in this function).
		state.meta = NewMetaByContext(
			c,
			newChannel,
			mode,
			meta.WithRequestUsage(state.requestUsage),
			meta.WithRetryAt(time.Now()),
		)
		var retry bool
		state.result, retry = RelayHelper(c, state.meta, relayController)
		done := handleRetryResult(c, retry, newChannel, state)
		// Record failed channel if retry is needed
		if !done && state.result.Error != nil {
			state.failedChannelIDs[int64(newChannel.ID)] = struct{}{}
		}
		// Stop on a terminal outcome or after the last budgeted retry. This
		// attempt's result is what the client receives.
		if done || i == state.retryTimes-1 {
			recordResult(
				c,
				state.meta,
				state.price,
				state.result,
				i+1,
				true,
				middleware.GetRequestUser(c),
				middleware.GetRequestMetadata(c),
			)
			break
		}
		i++
	}
	if state.result.Error != nil {
		ErrorWithRequestID(c, state.result.Error)
	}
}
  447. func prepareRetry(c *gin.Context) error {
  448. requestBody, err := common.GetRequestBodyReusable(c.Request)
  449. if err != nil {
  450. return fmt.Errorf("get request body failed in prepare retry: %w", err)
  451. }
  452. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  453. return nil
  454. }
  455. func handleRetryResult(
  456. ctx *gin.Context,
  457. retry bool,
  458. newChannel *model.Channel,
  459. state *retryState,
  460. ) (done bool) {
  461. if ctx.Request.Context().Err() != nil {
  462. return true
  463. }
  464. if !retry || state.result.Error == nil {
  465. return true
  466. }
  467. hasPermission := monitorplugin.ChannelHasPermission(state.result.Error)
  468. if state.exhausted {
  469. if !hasPermission {
  470. return true
  471. }
  472. } else {
  473. if !hasPermission {
  474. if state.ignoreChannelIDs == nil {
  475. state.ignoreChannelIDs = make(map[int64]struct{})
  476. }
  477. state.ignoreChannelIDs[int64(newChannel.ID)] = struct{}{}
  478. state.retryTimes++
  479. } else {
  480. state.lastHasPermissionChannel = newChannel
  481. }
  482. }
  483. return false
  484. }
  485. // shouldDelay checks if we need to add a delay before retrying
  486. // Only adds delay when retrying with the same channel for rate limiting issues
  487. func shouldDelay(statusCode, lastChannelID, newChannelID int) bool {
  488. if lastChannelID != newChannelID {
  489. return false
  490. }
  491. // Only delay for rate limiting or service unavailable errors
  492. return statusCode == http.StatusTooManyRequests ||
  493. statusCode == http.StatusServiceUnavailable
  494. }
  495. func relayDelay() {
  496. time.Sleep(time.Duration(rand.Float64()*float64(time.Second)) + time.Second)
  497. }
  498. func RelayNotImplemented(c *gin.Context) {
  499. ErrorWithRequestID(c,
  500. relaymodel.NewOpenAIError(http.StatusNotImplemented, relaymodel.OpenAIError{
  501. Message: "API not implemented",
  502. Type: relaymodel.ErrorTypeAIPROXY,
  503. Code: "api_not_implemented",
  504. }),
  505. )
  506. }
  507. func ErrorWithRequestID(c *gin.Context, relayErr adaptor.Error) {
  508. requestID := middleware.GetRequestID(c)
  509. if requestID == "" {
  510. c.JSON(relayErr.StatusCode(), relayErr)
  511. return
  512. }
  513. log := common.GetLogger(c)
  514. data, err := relayErr.MarshalJSON()
  515. if err != nil {
  516. log.Errorf("marshal error failed: %+v", err)
  517. c.JSON(relayErr.StatusCode(), relayErr)
  518. return
  519. }
  520. node, err := sonic.Get(data)
  521. if err != nil {
  522. log.Errorf("get node failed: %+v", err)
  523. c.JSON(relayErr.StatusCode(), relayErr)
  524. return
  525. }
  526. _, err = node.Set("aiproxy", ast.NewString(requestID))
  527. if err != nil {
  528. log.Errorf("set request id failed: %+v", err)
  529. c.JSON(relayErr.StatusCode(), relayErr)
  530. return
  531. }
  532. c.JSON(relayErr.StatusCode(), &node)
  533. }