relay-controller.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/rand/v2"
  9. "net/http"
  10. "strconv"
  11. "time"
  12. "github.com/bytedance/sonic"
  13. "github.com/bytedance/sonic/ast"
  14. "github.com/gin-gonic/gin"
  15. "github.com/labring/aiproxy/core/common"
  16. "github.com/labring/aiproxy/core/common/config"
  17. "github.com/labring/aiproxy/core/common/consume"
  18. "github.com/labring/aiproxy/core/common/conv"
  19. "github.com/labring/aiproxy/core/middleware"
  20. "github.com/labring/aiproxy/core/model"
  21. "github.com/labring/aiproxy/core/relay/adaptor"
  22. "github.com/labring/aiproxy/core/relay/adaptors"
  23. "github.com/labring/aiproxy/core/relay/controller"
  24. "github.com/labring/aiproxy/core/relay/meta"
  25. "github.com/labring/aiproxy/core/relay/mode"
  26. relaymodel "github.com/labring/aiproxy/core/relay/model"
  27. "github.com/labring/aiproxy/core/relay/plugin"
  28. "github.com/labring/aiproxy/core/relay/plugin/cache"
  29. monitorplugin "github.com/labring/aiproxy/core/relay/plugin/monitor"
  30. "github.com/labring/aiproxy/core/relay/plugin/streamfake"
  31. "github.com/labring/aiproxy/core/relay/plugin/thinksplit"
  32. websearch "github.com/labring/aiproxy/core/relay/plugin/web-search"
  33. )
  34. // https://platform.openai.com/docs/api-reference/chat
  35. type (
  36. RelayHandler func(*gin.Context, *meta.Meta) *controller.HandleResult
  37. GetRequestUsage func(*gin.Context, model.ModelConfig) (model.Usage, error)
  38. GetRequestPrice func(*gin.Context, model.ModelConfig) (model.Price, error)
  39. )
  40. type RelayController struct {
  41. GetRequestUsage GetRequestUsage
  42. GetRequestPrice GetRequestPrice
  43. Handler RelayHandler
  44. }
  45. var adaptorStore adaptor.Store = &storeImpl{}
  46. type storeImpl struct{}
  47. func (s *storeImpl) GetStore(id string) (adaptor.StoreCache, error) {
  48. store, err := model.CacheGetStore(id)
  49. if err != nil {
  50. return adaptor.StoreCache{}, err
  51. }
  52. return adaptor.StoreCache{
  53. ID: store.ID,
  54. GroupID: store.GroupID,
  55. TokenID: store.TokenID,
  56. ChannelID: store.ChannelID,
  57. Model: store.Model,
  58. ExpiresAt: store.ExpiresAt,
  59. }, nil
  60. }
  61. func (s *storeImpl) SaveStore(store adaptor.StoreCache) error {
  62. _, err := model.SaveStore(&model.Store{
  63. ID: store.ID,
  64. GroupID: store.GroupID,
  65. TokenID: store.TokenID,
  66. ChannelID: store.ChannelID,
  67. Model: store.Model,
  68. ExpiresAt: store.ExpiresAt,
  69. })
  70. return err
  71. }
  72. func wrapPlugin(ctx context.Context, mc *model.ModelCaches, a adaptor.Adaptor) adaptor.Adaptor {
  73. return plugin.WrapperAdaptor(a,
  74. monitorplugin.NewGroupMonitorPlugin(),
  75. cache.NewCachePlugin(common.RDB),
  76. streamfake.NewStreamFakePlugin(),
  77. websearch.NewWebSearchPlugin(func(modelName string) (*model.Channel, error) {
  78. return getWebSearchChannel(ctx, mc, modelName)
  79. }),
  80. thinksplit.NewThinkPlugin(),
  81. monitorplugin.NewChannelMonitorPlugin(),
  82. )
  83. }
  84. func relayHandler(c *gin.Context, meta *meta.Meta, mc *model.ModelCaches) *controller.HandleResult {
  85. log := common.GetLogger(c)
  86. middleware.SetLogFieldsFromMeta(meta, log.Data)
  87. adaptor, ok := adaptors.GetAdaptor(meta.Channel.Type)
  88. if !ok {
  89. return &controller.HandleResult{
  90. Error: relaymodel.WrapperOpenAIErrorWithMessage(
  91. fmt.Sprintf("invalid channel type: %d", meta.Channel.Type),
  92. "invalid_channel_type",
  93. http.StatusInternalServerError,
  94. ),
  95. }
  96. }
  97. adaptor = wrapPlugin(c.Request.Context(), mc, adaptor)
  98. return controller.Handle(adaptor, c, meta, adaptorStore)
  99. }
  100. func relayController(m mode.Mode) RelayController {
  101. c := RelayController{
  102. Handler: func(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
  103. return relayHandler(c, meta, middleware.GetModelCaches(c))
  104. },
  105. }
  106. switch m {
  107. case mode.ImagesGenerations:
  108. c.GetRequestPrice = controller.GetImagesRequestPrice
  109. c.GetRequestUsage = controller.GetImagesRequestUsage
  110. case mode.ImagesEdits:
  111. c.GetRequestPrice = controller.GetImagesEditsRequestPrice
  112. c.GetRequestUsage = controller.GetImagesEditsRequestUsage
  113. case mode.AudioSpeech:
  114. c.GetRequestPrice = controller.GetTTSRequestPrice
  115. c.GetRequestUsage = controller.GetTTSRequestUsage
  116. case mode.AudioTranslation, mode.AudioTranscription:
  117. c.GetRequestPrice = controller.GetSTTRequestPrice
  118. c.GetRequestUsage = controller.GetSTTRequestUsage
  119. case mode.ParsePdf:
  120. c.GetRequestPrice = controller.GetPdfRequestPrice
  121. c.GetRequestUsage = controller.GetPdfRequestUsage
  122. case mode.Rerank:
  123. c.GetRequestPrice = controller.GetRerankRequestPrice
  124. c.GetRequestUsage = controller.GetRerankRequestUsage
  125. case mode.Anthropic:
  126. c.GetRequestPrice = controller.GetAnthropicRequestPrice
  127. c.GetRequestUsage = controller.GetAnthropicRequestUsage
  128. case mode.ChatCompletions:
  129. c.GetRequestPrice = controller.GetChatRequestPrice
  130. c.GetRequestUsage = controller.GetChatRequestUsage
  131. case mode.Embeddings:
  132. c.GetRequestPrice = controller.GetEmbedRequestPrice
  133. c.GetRequestUsage = controller.GetEmbedRequestUsage
  134. case mode.Completions:
  135. c.GetRequestPrice = controller.GetCompletionsRequestPrice
  136. c.GetRequestUsage = controller.GetCompletionsRequestUsage
  137. case mode.VideoGenerationsJobs:
  138. c.GetRequestPrice = controller.GetVideoGenerationJobRequestPrice
  139. c.GetRequestUsage = controller.GetVideoGenerationJobRequestUsage
  140. }
  141. return c
  142. }
  143. func RelayHelper(
  144. c *gin.Context,
  145. meta *meta.Meta,
  146. handel RelayHandler,
  147. ) (*controller.HandleResult, bool) {
  148. result := handel(c, meta)
  149. if result.Error == nil {
  150. return result, false
  151. }
  152. return result, monitorplugin.ShouldRetry(result.Error)
  153. }
  154. func NewRelay(mode mode.Mode) func(c *gin.Context) {
  155. relayController := relayController(mode)
  156. return func(c *gin.Context) {
  157. relay(c, mode, relayController)
  158. }
  159. }
  160. func NewMetaByContext(
  161. c *gin.Context,
  162. channel *model.Channel,
  163. mode mode.Mode,
  164. opts ...meta.Option,
  165. ) *meta.Meta {
  166. return middleware.NewMetaByContext(c, channel, mode, opts...)
  167. }
  // relay serves one request in the given mode. The steps are:
  //
  //  1. Pick an initial channel.
  //  2. Price the request and check the group balance, when the controller
  //     provides those hooks.
  //  3. Perform the first attempt.
  //  4. Hand off to retryLoop if that attempt failed with a retryable error
  //     and retries are allowed.
  //
  // Every exit path either writes an error response or records the attempt
  // through recordResult (directly or via retryLoop).
  func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
  	requestModel := middleware.GetRequestModel(c)
  	mc := middleware.GetModelConfig(c)
  	// Get initial channel
  	// Any failure to obtain a usable channel is reported to the client as 503.
  	initialChannel, err := getInitialChannel(c, requestModel, mode)
  	if err != nil || initialChannel == nil || initialChannel.channel == nil {
  		middleware.AbortLogWithMessageWithMode(mode, c,
  			http.StatusServiceUnavailable,
  			"the upstream load is saturated, please try again later",
  		)
  		return
  	}
  	// Without a price hook, the request is priced with the zero model.Price.
  	price := model.Price{}
  	if relayController.GetRequestPrice != nil {
  		price, err = relayController.GetRequestPrice(c, mc)
  		if err != nil {
  			middleware.AbortLogWithMessageWithMode(mode, c,
  				http.StatusInternalServerError,
  				"get request price failed: "+err.Error(),
  			)
  			return
  		}
  	}
  	meta := NewMetaByContext(c, initialChannel.channel, mode)
  	if relayController.GetRequestUsage != nil {
  		requestUsage, err := relayController.GetRequestUsage(c, mc)
  		if err != nil {
  			middleware.AbortLogWithMessageWithMode(mode, c,
  				http.StatusInternalServerError,
  				"get request usage failed: "+err.Error(),
  			)
  			return
  		}
  		// Pre-flight balance check. The estimated usage is priced as if the
  		// request succeeded (HTTP 200), and the group balance must cover it.
  		gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  		if !gbc.CheckBalance(consume.CalculateAmount(http.StatusOK, requestUsage, price)) {
  			middleware.AbortLogWithMessageWithMode(mode, c,
  				http.StatusForbidden,
  				fmt.Sprintf("group (%s) balance not enough", gbc.Group),
  				relaymodel.WithType(middleware.GroupBalanceNotEnough),
  			)
  			return
  		}
  		meta.RequestUsage = requestUsage
  	}
  	// First attempt
  	result, retry := RelayHelper(c, meta, relayController.Handler)
  	// The per-model RetryTimes overrides the global setting only when positive.
  	retryTimes := int(config.GetRetryTimes())
  	if mc.RetryTimes > 0 {
  		retryTimes = int(mc.RetryTimes)
  	}
  	// handleRelayResult reports done when the first attempt succeeded or must
  	// not be retried. In the error case it has already written the error
  	// response, so here the attempt is only recorded as final (retry index 0).
  	if handleRelayResult(c, result.Error, retry, retryTimes) {
  		recordResult(
  			c,
  			meta,
  			price,
  			result,
  			0,
  			true,
  			middleware.GetRequestUser(c),
  			middleware.GetRequestMetadata(c),
  		)
  		return
  	}
  	// Setup retry state
  	retryState := initRetryState(
  		retryTimes,
  		initialChannel,
  		meta,
  		result,
  		price,
  	)
  	// Retry loop
  	// retryLoop records the failed first attempt itself before it retries.
  	retryLoop(c, mode, retryState, relayController.Handler)
  }
  242. // recordResult records the consumption for the final result
  243. func recordResult(
  244. c *gin.Context,
  245. meta *meta.Meta,
  246. price model.Price,
  247. result *controller.HandleResult,
  248. retryTimes int,
  249. downstreamResult bool,
  250. user string,
  251. metadata map[string]string,
  252. ) {
  253. code := http.StatusOK
  254. content := ""
  255. if result.Error != nil {
  256. code = result.Error.StatusCode()
  257. respBody, _ := result.Error.MarshalJSON()
  258. content = conv.BytesToString(respBody)
  259. }
  260. var detail *model.RequestDetail
  261. firstByteAt := result.Detail.FirstByteAt
  262. if config.GetSaveAllLogDetail() || meta.ModelConfig.ForceSaveDetail || code != http.StatusOK {
  263. detail = &model.RequestDetail{
  264. RequestBody: result.Detail.RequestBody,
  265. ResponseBody: result.Detail.ResponseBody,
  266. }
  267. }
  268. gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  269. amount := consume.CalculateAmount(
  270. code,
  271. result.Usage,
  272. price,
  273. )
  274. if amount > 0 {
  275. log := common.GetLogger(c)
  276. log.Data["amount"] = strconv.FormatFloat(amount, 'f', -1, 64)
  277. }
  278. consume.AsyncConsume(
  279. gbc.Consumer,
  280. code,
  281. firstByteAt,
  282. meta,
  283. result.Usage,
  284. price,
  285. content,
  286. c.ClientIP(),
  287. retryTimes,
  288. detail,
  289. downstreamResult,
  290. user,
  291. metadata,
  292. )
  293. }
  294. type retryState struct {
  295. retryTimes int
  296. lastHasPermissionChannel *model.Channel
  297. ignoreChannelIDs []int64
  298. errorRates map[int64]float64
  299. exhausted bool
  300. meta *meta.Meta
  301. price model.Price
  302. requestUsage model.Usage
  303. result *controller.HandleResult
  304. migratedChannels []*model.Channel
  305. }
  306. func handleRelayResult(
  307. c *gin.Context,
  308. bizErr adaptor.Error,
  309. retry bool,
  310. retryTimes int,
  311. ) (done bool) {
  312. if bizErr == nil {
  313. return true
  314. }
  315. if !retry ||
  316. retryTimes == 0 ||
  317. c.Request.Context().Err() != nil {
  318. ErrorWithRequestID(c, bizErr)
  319. return true
  320. }
  321. return false
  322. }
  323. func initRetryState(
  324. retryTimes int,
  325. channel *initialChannel,
  326. meta *meta.Meta,
  327. result *controller.HandleResult,
  328. price model.Price,
  329. ) *retryState {
  330. state := &retryState{
  331. retryTimes: retryTimes,
  332. ignoreChannelIDs: channel.ignoreChannelIDs,
  333. errorRates: channel.errorRates,
  334. meta: meta,
  335. result: result,
  336. price: price,
  337. requestUsage: meta.RequestUsage,
  338. migratedChannels: channel.migratedChannels,
  339. }
  340. if channel.designatedChannel {
  341. state.exhausted = true
  342. }
  343. if !monitorplugin.ChannelHasPermission(result.Error) {
  344. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(channel.channel.ID))
  345. } else {
  346. state.lastHasPermissionChannel = channel.channel
  347. }
  348. return state
  349. }
  // retryLoop re-issues a failed request on channels chosen by getRetryChannel.
  // It stops when one of these happens:
  //   - an attempt is final (success, non-retryable error, or the client's
  //     request context has ended);
  //   - the retry budget (state.retryTimes) is spent;
  //   - no further channel can be obtained, or the request body cannot be
  //     replayed.
  //
  // Each attempt is recorded via recordResult exactly once:
  //   - superseded attempts with downstreamResult == false;
  //   - the last attempt with downstreamResult == true.
  //
  // If the last attempt failed, its error is written to the client.
  func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayController RelayHandler) {
  	log := common.GetLogger(c)
  	// do not use `for i := range state.retryTimes`: range evaluates the bound once,
  	// but handleRetryResult may raise state.retryTimes while the loop runs.
  	// i is the number of retries already performed.
  	i := 0
  	for {
  		// state.result.Error is non-nil here. The first attempt failed
  		// retryably, and later iterations only continue after
  		// handleRetryResult saw a non-nil error.
  		lastStatusCode := state.result.Error.StatusCode()
  		lastChannelID := state.meta.Channel.ID
  		newChannel, err := getRetryChannel(state)
  		if err == nil {
  			// Rewind the request body so the next attempt can read it again.
  			err = prepareRetry(c)
  		}
  		if err != nil {
  			// Running out of channels is expected; any other failure is logged.
  			if !errors.Is(err, ErrChannelsExhausted) {
  				log.Errorf("prepare retry failed: %+v", err)
  			}
  			// when the last request has not recorded the result, record the result
  			// (as final, since no further attempt will be made)
  			if state.meta != nil && state.result != nil {
  				recordResult(
  					c,
  					state.meta,
  					state.price,
  					state.result,
  					i,
  					true,
  					middleware.GetRequestUser(c),
  					middleware.GetRequestMetadata(c),
  				)
  			}
  			break
  		}
  		// when the last request has not recorded the result, record the result
  		// (as superseded), then clear it so it is not recorded twice
  		if state.meta != nil && state.result != nil {
  			recordResult(
  				c,
  				state.meta,
  				state.price,
  				state.result,
  				i,
  				false,
  				middleware.GetRequestUser(c),
  				middleware.GetRequestMetadata(c),
  			)
  			state.meta = nil
  			state.result = nil
  		}
  		log.Data["retry"] = strconv.Itoa(i + 1)
  		log.Warnf("using channel %s (type: %d, id: %d) to retry (remain times %d)",
  			newChannel.Name,
  			newChannel.Type,
  			newChannel.ID,
  			state.retryTimes-i,
  		)
  		// Check if we should delay (using the same channel).
  		// relayDelay sleeps between 1s and 2s; it does not observe request
  		// cancellation.
  		if shouldDelay(lastStatusCode, lastChannelID, newChannel.ID) {
  			relayDelay()
  		}
  		// Each retry reuses the pre-flight usage estimate and is stamped with
  		// its retry time.
  		state.meta = NewMetaByContext(
  			c,
  			newChannel,
  			mode,
  			meta.WithRequestUsage(state.requestUsage),
  			meta.WithRetryAt(time.Now()),
  		)
  		var retry bool
  		state.result, retry = RelayHelper(c, state.meta, relayController)
  		done := handleRetryResult(c, retry, newChannel, state)
  		// Stop when the attempt is final, or when this was the last retry the
  		// (possibly raised) budget allows.
  		if done || i == state.retryTimes-1 {
  			recordResult(
  				c,
  				state.meta,
  				state.price,
  				state.result,
  				i+1,
  				true,
  				middleware.GetRequestUser(c),
  				middleware.GetRequestMetadata(c),
  			)
  			break
  		}
  		i++
  	}
  	// Only a failed final attempt is written here; on success, nothing is
  	// written by this function.
  	if state.result.Error != nil {
  		ErrorWithRequestID(c, state.result.Error)
  	}
  }
  435. func prepareRetry(c *gin.Context) error {
  436. requestBody, err := common.GetRequestBodyReusable(c.Request)
  437. if err != nil {
  438. return fmt.Errorf("get request body failed in prepare retry: %w", err)
  439. }
  440. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  441. return nil
  442. }
  443. func handleRetryResult(
  444. ctx *gin.Context,
  445. retry bool,
  446. newChannel *model.Channel,
  447. state *retryState,
  448. ) (done bool) {
  449. if ctx.Request.Context().Err() != nil {
  450. return true
  451. }
  452. if !retry || state.result.Error == nil {
  453. return true
  454. }
  455. hasPermission := monitorplugin.ChannelHasPermission(state.result.Error)
  456. if state.exhausted {
  457. if !hasPermission {
  458. return true
  459. }
  460. } else {
  461. if !hasPermission {
  462. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(newChannel.ID))
  463. state.retryTimes++
  464. } else {
  465. state.lastHasPermissionChannel = newChannel
  466. }
  467. }
  468. return false
  469. }
  // shouldDelay reports whether a retry should back off before it is sent.
  // That is the case only when the retry reuses the previous channel and the
  // previous attempt failed with 429 (Too Many Requests) or 503 (Service
  // Unavailable).
  func shouldDelay(statusCode, lastChannelID, newChannelID int) bool {
  	if lastChannelID != newChannelID {
  		return false
  	}
  	switch statusCode {
  	case http.StatusTooManyRequests, http.StatusServiceUnavailable:
  		return true
  	default:
  		return false
  	}
  }
  480. func relayDelay() {
  481. time.Sleep(time.Duration(rand.Float64()*float64(time.Second)) + time.Second)
  482. }
  483. func RelayNotImplemented(c *gin.Context) {
  484. ErrorWithRequestID(c,
  485. relaymodel.NewOpenAIError(http.StatusNotImplemented, relaymodel.OpenAIError{
  486. Message: "API not implemented",
  487. Type: relaymodel.ErrorTypeAIPROXY,
  488. Code: "api_not_implemented",
  489. }),
  490. )
  491. }
  492. func ErrorWithRequestID(c *gin.Context, relayErr adaptor.Error) {
  493. requestID := middleware.GetRequestID(c)
  494. if requestID == "" {
  495. c.JSON(relayErr.StatusCode(), relayErr)
  496. return
  497. }
  498. log := common.GetLogger(c)
  499. data, err := relayErr.MarshalJSON()
  500. if err != nil {
  501. log.Errorf("marshal error failed: %+v", err)
  502. c.JSON(relayErr.StatusCode(), relayErr)
  503. return
  504. }
  505. node, err := sonic.Get(data)
  506. if err != nil {
  507. log.Errorf("get node failed: %+v", err)
  508. c.JSON(relayErr.StatusCode(), relayErr)
  509. return
  510. }
  511. _, err = node.Set("aiproxy", ast.NewString(requestID))
  512. if err != nil {
  513. log.Errorf("set request id failed: %+v", err)
  514. c.JSON(relayErr.StatusCode(), relayErr)
  515. return
  516. }
  517. c.JSON(relayErr.StatusCode(), &node)
  518. }