relay-controller.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. package controller
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "math/rand/v2"
  8. "net/http"
  9. "strconv"
  10. "time"
  11. "github.com/bytedance/sonic"
  12. "github.com/bytedance/sonic/ast"
  13. "github.com/gin-gonic/gin"
  14. "github.com/labring/aiproxy/core/common"
  15. "github.com/labring/aiproxy/core/common/config"
  16. "github.com/labring/aiproxy/core/common/consume"
  17. "github.com/labring/aiproxy/core/common/conv"
  18. "github.com/labring/aiproxy/core/middleware"
  19. "github.com/labring/aiproxy/core/model"
  20. "github.com/labring/aiproxy/core/relay/adaptor"
  21. "github.com/labring/aiproxy/core/relay/adaptors"
  22. "github.com/labring/aiproxy/core/relay/controller"
  23. "github.com/labring/aiproxy/core/relay/meta"
  24. "github.com/labring/aiproxy/core/relay/mode"
  25. relaymodel "github.com/labring/aiproxy/core/relay/model"
  26. "github.com/labring/aiproxy/core/relay/plugin"
  27. "github.com/labring/aiproxy/core/relay/plugin/cache"
  28. monitorplugin "github.com/labring/aiproxy/core/relay/plugin/monitor"
  29. "github.com/labring/aiproxy/core/relay/plugin/thinksplit"
  30. websearch "github.com/labring/aiproxy/core/relay/plugin/web-search"
  31. )
  32. // https://platform.openai.com/docs/api-reference/chat
  33. type (
  34. RelayHandler func(*gin.Context, *meta.Meta) *controller.HandleResult
  35. GetRequestUsage func(*gin.Context, model.ModelConfig) (model.Usage, error)
  36. GetRequestPrice func(*gin.Context, model.ModelConfig) (model.Price, error)
  37. )
  38. type RelayController struct {
  39. GetRequestUsage GetRequestUsage
  40. GetRequestPrice GetRequestPrice
  41. Handler RelayHandler
  42. }
  43. var adaptorStore adaptor.Store = &storeImpl{}
  44. type storeImpl struct{}
  45. func (s *storeImpl) GetStore(id string) (adaptor.StoreCache, error) {
  46. store, err := model.CacheGetStore(id)
  47. if err != nil {
  48. return adaptor.StoreCache{}, err
  49. }
  50. return adaptor.StoreCache{
  51. ID: store.ID,
  52. GroupID: store.GroupID,
  53. TokenID: store.TokenID,
  54. ChannelID: store.ChannelID,
  55. Model: store.Model,
  56. ExpiresAt: store.ExpiresAt,
  57. }, nil
  58. }
  59. func (s *storeImpl) SaveStore(store adaptor.StoreCache) error {
  60. _, err := model.SaveStore(&model.Store{
  61. ID: store.ID,
  62. GroupID: store.GroupID,
  63. TokenID: store.TokenID,
  64. ChannelID: store.ChannelID,
  65. Model: store.Model,
  66. ExpiresAt: store.ExpiresAt,
  67. })
  68. return err
  69. }
  // relayHandler is the default RelayHandler. It resolves the adaptor for the
  // channel type in meta and wraps it with the plugin chain, in this order:
  // group monitor, cache, web search, think-split, channel monitor. It then
  // delegates the attempt to controller.Handle. An unknown channel type yields
  // an "invalid_channel_type" error with status 500.
  func relayHandler(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
  	log := common.GetLogger(c)
  	middleware.SetLogFieldsFromMeta(meta, log.Data)

  	channelType := meta.Channel.Type
  	base, found := adaptors.GetAdaptor(channelType)
  	if !found {
  		return &controller.HandleResult{
  			Error: relaymodel.WrapperOpenAIErrorWithMessage(
  				fmt.Sprintf("invalid channel type: %d", channelType),
  				"invalid_channel_type",
  				http.StatusInternalServerError,
  			),
  		}
  	}

  	// The web-search plugin resolves its own channel lazily, in the context
  	// of the current request.
  	lookupWebSearchChannel := func(modelName string) (*model.Channel, error) {
  		return getWebSearchChannel(c, modelName)
  	}

  	wrapped := plugin.WrapperAdaptor(base,
  		monitorplugin.NewGroupMonitorPlugin(),
  		cache.NewCachePlugin(common.RDB),
  		websearch.NewWebSearchPlugin(lookupWebSearchChannel),
  		thinksplit.NewThinkPlugin(),
  		monitorplugin.NewChannelMonitorPlugin(),
  	)
  	return controller.Handle(wrapped, c, meta, adaptorStore)
  }
  94. func relayController(m mode.Mode) RelayController {
  95. c := RelayController{
  96. Handler: relayHandler,
  97. }
  98. switch m {
  99. case mode.ImagesGenerations:
  100. c.GetRequestPrice = controller.GetImagesRequestPrice
  101. c.GetRequestUsage = controller.GetImagesRequestUsage
  102. case mode.ImagesEdits:
  103. c.GetRequestPrice = controller.GetImagesEditsRequestPrice
  104. c.GetRequestUsage = controller.GetImagesEditsRequestUsage
  105. case mode.AudioSpeech:
  106. c.GetRequestPrice = controller.GetTTSRequestPrice
  107. c.GetRequestUsage = controller.GetTTSRequestUsage
  108. case mode.AudioTranslation, mode.AudioTranscription:
  109. c.GetRequestPrice = controller.GetSTTRequestPrice
  110. c.GetRequestUsage = controller.GetSTTRequestUsage
  111. case mode.ParsePdf:
  112. c.GetRequestPrice = controller.GetPdfRequestPrice
  113. c.GetRequestUsage = controller.GetPdfRequestUsage
  114. case mode.Rerank:
  115. c.GetRequestPrice = controller.GetRerankRequestPrice
  116. c.GetRequestUsage = controller.GetRerankRequestUsage
  117. case mode.Anthropic:
  118. c.GetRequestPrice = controller.GetAnthropicRequestPrice
  119. c.GetRequestUsage = controller.GetAnthropicRequestUsage
  120. case mode.ChatCompletions:
  121. c.GetRequestPrice = controller.GetChatRequestPrice
  122. c.GetRequestUsage = controller.GetChatRequestUsage
  123. case mode.Embeddings:
  124. c.GetRequestPrice = controller.GetEmbedRequestPrice
  125. c.GetRequestUsage = controller.GetEmbedRequestUsage
  126. case mode.Completions:
  127. c.GetRequestPrice = controller.GetCompletionsRequestPrice
  128. c.GetRequestUsage = controller.GetCompletionsRequestUsage
  129. case mode.VideoGenerationsJobs:
  130. c.GetRequestPrice = controller.GetVideoGenerationJobRequestPrice
  131. c.GetRequestUsage = controller.GetVideoGenerationJobRequestUsage
  132. }
  133. return c
  134. }
  135. func RelayHelper(
  136. c *gin.Context,
  137. meta *meta.Meta,
  138. handel RelayHandler,
  139. ) (*controller.HandleResult, bool) {
  140. result := handel(c, meta)
  141. if result.Error == nil {
  142. return result, false
  143. }
  144. return result, monitorplugin.ShouldRetry(result.Error)
  145. }
  146. func NewRelay(mode mode.Mode) func(c *gin.Context) {
  147. relayController := relayController(mode)
  148. return func(c *gin.Context) {
  149. relay(c, mode, relayController)
  150. }
  151. }
  152. func NewMetaByContext(
  153. c *gin.Context,
  154. channel *model.Channel,
  155. mode mode.Mode,
  156. opts ...meta.Option,
  157. ) *meta.Meta {
  158. return middleware.NewMetaByContext(c, channel, mode, opts...)
  159. }
  160. func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
  161. requestModel := middleware.GetRequestModel(c)
  162. mc := middleware.GetModelConfig(c)
  163. // Get initial channel
  164. initialChannel, err := getInitialChannel(c, requestModel, mode)
  165. if err != nil || initialChannel == nil || initialChannel.channel == nil {
  166. middleware.AbortLogWithMessageWithMode(mode, c,
  167. http.StatusServiceUnavailable,
  168. "the upstream load is saturated, please try again later",
  169. )
  170. return
  171. }
  172. price := model.Price{}
  173. if relayController.GetRequestPrice != nil {
  174. price, err = relayController.GetRequestPrice(c, mc)
  175. if err != nil {
  176. middleware.AbortLogWithMessageWithMode(mode, c,
  177. http.StatusInternalServerError,
  178. "get request price failed: "+err.Error(),
  179. )
  180. return
  181. }
  182. }
  183. meta := NewMetaByContext(c, initialChannel.channel, mode)
  184. if relayController.GetRequestUsage != nil {
  185. requestUsage, err := relayController.GetRequestUsage(c, mc)
  186. if err != nil {
  187. middleware.AbortLogWithMessageWithMode(mode, c,
  188. http.StatusInternalServerError,
  189. "get request usage failed: "+err.Error(),
  190. )
  191. return
  192. }
  193. gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  194. if !gbc.CheckBalance(consume.CalculateAmount(http.StatusOK, requestUsage, price)) {
  195. middleware.AbortLogWithMessageWithMode(mode, c,
  196. http.StatusForbidden,
  197. fmt.Sprintf("group (%s) balance not enough", gbc.Group),
  198. middleware.GroupBalanceNotEnough,
  199. )
  200. return
  201. }
  202. meta.RequestUsage = requestUsage
  203. }
  204. // First attempt
  205. result, retry := RelayHelper(c, meta, relayController.Handler)
  206. retryTimes := int(config.GetRetryTimes())
  207. if mc.RetryTimes > 0 {
  208. retryTimes = int(mc.RetryTimes)
  209. }
  210. if handleRelayResult(c, result.Error, retry, retryTimes) {
  211. recordResult(
  212. c,
  213. meta,
  214. price,
  215. result,
  216. 0,
  217. true,
  218. middleware.GetRequestUser(c),
  219. middleware.GetRequestMetadata(c),
  220. )
  221. return
  222. }
  223. // Setup retry state
  224. retryState := initRetryState(
  225. retryTimes,
  226. initialChannel,
  227. meta,
  228. result,
  229. price,
  230. )
  231. // Retry loop
  232. retryLoop(c, mode, retryState, relayController.Handler)
  233. }
  // recordResult reports the outcome of one relay attempt to
  // consume.AsyncConsume.
  //
  // The status code is http.StatusOK when result.Error is nil. Otherwise it is
  // the error's own status code, and the error's JSON form is passed as the
  // content. If that JSON cannot be produced, the failure is logged and the
  // attempt is still recorded with empty content.
  //
  // Request and response bodies are attached only in three cases: detail
  // saving is enabled globally (config.GetSaveAllLogDetail), the model config
  // forces it, or the attempt did not return 200.
  //
  // retryTimes is the attempt index passed through to the consumer. Callers
  // set downstreamResult to true for the attempt that ends the relay and to
  // false for superseded attempts. A positive computed amount is also added
  // to the request's log fields under "amount".
  func recordResult(
  	c *gin.Context,
  	meta *meta.Meta,
  	price model.Price,
  	result *controller.HandleResult,
  	retryTimes int,
  	downstreamResult bool,
  	user string,
  	metadata map[string]string,
  ) {
  	code := http.StatusOK
  	content := ""
  	if result.Error != nil {
  		code = result.Error.StatusCode()
  		respBody, err := result.Error.MarshalJSON()
  		if err != nil {
  			// Do not drop the consumption record: the status code is still
  			// accurate even when the error body cannot be serialized.
  			common.GetLogger(c).Errorf("marshal relay error failed: %+v", err)
  		} else {
  			content = conv.BytesToString(respBody)
  		}
  	}

  	var detail *model.RequestDetail
  	firstByteAt := result.Detail.FirstByteAt
  	if config.GetSaveAllLogDetail() || meta.ModelConfig.ForceSaveDetail || code != http.StatusOK {
  		detail = &model.RequestDetail{
  			RequestBody:  result.Detail.RequestBody,
  			ResponseBody: result.Detail.ResponseBody,
  		}
  	}

  	gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  	amount := consume.CalculateAmount(
  		code,
  		result.Usage,
  		price,
  	)
  	if amount > 0 {
  		log := common.GetLogger(c)
  		log.Data["amount"] = strconv.FormatFloat(amount, 'f', -1, 64)
  	}

  	consume.AsyncConsume(
  		gbc.Consumer,
  		code,
  		firstByteAt,
  		meta,
  		result.Usage,
  		price,
  		content,
  		c.ClientIP(),
  		retryTimes,
  		detail,
  		downstreamResult,
  		user,
  		metadata,
  		monitorplugin.GetChannelModelRequestRate(c, meta),
  		middleware.GetGroupModelTokenRequestRate(c),
  	)
  }
  288. type retryState struct {
  289. retryTimes int
  290. lastHasPermissionChannel *model.Channel
  291. ignoreChannelIDs []int64
  292. errorRates map[int64]float64
  293. exhausted bool
  294. meta *meta.Meta
  295. price model.Price
  296. requestUsage model.Usage
  297. result *controller.HandleResult
  298. migratedChannels []*model.Channel
  299. }
  300. func handleRelayResult(
  301. c *gin.Context,
  302. bizErr adaptor.Error,
  303. retry bool,
  304. retryTimes int,
  305. ) (done bool) {
  306. if bizErr == nil {
  307. return true
  308. }
  309. if !retry ||
  310. retryTimes == 0 ||
  311. c.Request.Context().Err() != nil {
  312. ErrorWithRequestID(c, bizErr)
  313. return true
  314. }
  315. return false
  316. }
  317. func initRetryState(
  318. retryTimes int,
  319. channel *initialChannel,
  320. meta *meta.Meta,
  321. result *controller.HandleResult,
  322. price model.Price,
  323. ) *retryState {
  324. state := &retryState{
  325. retryTimes: retryTimes,
  326. ignoreChannelIDs: channel.ignoreChannelIDs,
  327. errorRates: channel.errorRates,
  328. meta: meta,
  329. result: result,
  330. price: price,
  331. requestUsage: meta.RequestUsage,
  332. migratedChannels: channel.migratedChannels,
  333. }
  334. if channel.designatedChannel {
  335. state.exhausted = true
  336. }
  337. if !monitorplugin.ChannelHasPermission(result.Error) {
  338. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(channel.channel.ID))
  339. } else {
  340. state.lastHasPermissionChannel = channel.channel
  341. }
  342. return state
  343. }
  // retryLoop re-relays a failed request on channels chosen by
  // getRetryChannel. It stops in three situations:
  //   - handleRetryResult says the latest attempt ends the relay;
  //   - the retry budget in state.retryTimes is spent;
  //   - no further channel can be obtained or prepared.
  //
  // Every attempt, including the first one passed in via state, is recorded
  // exactly once. Superseded attempts use downstreamResult=false and the
  // final attempt uses true. If the final attempt still carries an error,
  // that error is written to the client.
  func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayController RelayHandler) {
  	log := common.GetLogger(c)

  	// record reports the attempt currently held in state.
  	record := func(attempt int, downstream bool) {
  		recordResult(
  			c,
  			state.meta,
  			state.price,
  			state.result,
  			attempt,
  			downstream,
  			middleware.GetRequestUser(c),
  			middleware.GetRequestMetadata(c),
  		)
  	}

  	// A range over state.retryTimes would fix the bound at loop entry.
  	// handleRetryResult may grow the budget, so the bound is re-read on
  	// every iteration.
  	for attempt := 0; ; attempt++ {
  		prevStatus := state.result.Error.StatusCode()
  		prevChannelID := state.meta.Channel.ID

  		next, err := getRetryChannel(state)
  		if err == nil {
  			err = prepareRetry(c)
  		}
  		if err != nil {
  			if !errors.Is(err, ErrChannelsExhausted) {
  				log.Errorf("prepare retry failed: %+v", err)
  			}
  			// No further attempt: the pending one becomes the final result.
  			if state.meta != nil && state.result != nil {
  				record(attempt, true)
  			}
  			break
  		}

  		// The pending attempt is superseded by the retry below.
  		if state.meta != nil && state.result != nil {
  			record(attempt, false)
  			state.meta, state.result = nil, nil
  		}

  		log.Data["retry"] = strconv.Itoa(attempt + 1)
  		log.Warnf("using channel %s (type: %d, id: %d) to retry (remain times %d)",
  			next.Name,
  			next.Type,
  			next.ID,
  			state.retryTimes-attempt,
  		)

  		// Back off only when hitting the same channel again after a
  		// rate-limit or unavailable response.
  		if shouldDelay(prevStatus, prevChannelID, next.ID) {
  			relayDelay()
  		}

  		state.meta = NewMetaByContext(
  			c,
  			next,
  			mode,
  			meta.WithRequestUsage(state.requestUsage),
  			meta.WithRetryAt(time.Now()),
  		)

  		var retry bool
  		state.result, retry = RelayHelper(c, state.meta, relayController)

  		if handleRetryResult(c, retry, next, state) || attempt == state.retryTimes-1 {
  			record(attempt+1, true)
  			break
  		}
  	}

  	if state.result.Error != nil {
  		ErrorWithRequestID(c, state.result.Error)
  	}
  }
  // prepareRetry restores c.Request.Body from the bytes returned by
  // common.GetRequestBody, so the next attempt can read the body again.
  func prepareRetry(c *gin.Context) error {
  	body, err := common.GetRequestBody(c.Request)
  	if err != nil {
  		return fmt.Errorf("get request body failed in prepare retry: %w", err)
  	}

  	// The body is only ever read, so a read-only reader is enough.
  	c.Request.Body = io.NopCloser(bytes.NewReader(body))
  	return nil
  }
  437. func handleRetryResult(
  438. ctx *gin.Context,
  439. retry bool,
  440. newChannel *model.Channel,
  441. state *retryState,
  442. ) (done bool) {
  443. if ctx.Request.Context().Err() != nil {
  444. return true
  445. }
  446. if !retry || state.result.Error == nil {
  447. return true
  448. }
  449. hasPermission := monitorplugin.ChannelHasPermission(state.result.Error)
  450. if state.exhausted {
  451. if !hasPermission {
  452. return true
  453. }
  454. } else {
  455. if !hasPermission {
  456. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(newChannel.ID))
  457. state.retryTimes++
  458. } else {
  459. state.lastHasPermissionChannel = newChannel
  460. }
  461. }
  462. return false
  463. }
  // shouldDelay reports whether a retry should wait before it is sent. It does
  // only when the retry targets the same channel as the failed attempt and
  // that attempt ended with 429 Too Many Requests or 503 Service Unavailable.
  func shouldDelay(statusCode, lastChannelID, newChannelID int) bool {
  	if lastChannelID == newChannelID {
  		switch statusCode {
  		case http.StatusTooManyRequests, http.StatusServiceUnavailable:
  			return true
  		}
  	}
  	return false
  }
  // relayDelay blocks the calling goroutine for one second plus a uniformly
  // random jitter in [0, 1s).
  func relayDelay() {
  	jitter := rand.N(time.Second)
  	time.Sleep(time.Second + jitter)
  }
  // RelayNotImplemented is a gin handler for endpoints without an
  // implementation. It answers with 501 and an "api_not_implemented" error
  // tagged with the request ID.
  func RelayNotImplemented(c *gin.Context) {
  	notImplemented := relaymodel.NewOpenAIError(
  		http.StatusNotImplemented,
  		relaymodel.OpenAIError{
  			Message: "API not implemented",
  			Type:    relaymodel.ErrorTypeAIPROXY,
  			Code:    "api_not_implemented",
  		},
  	)
  	ErrorWithRequestID(c, notImplemented)
  }
  486. func ErrorWithRequestID(c *gin.Context, relayErr adaptor.Error) {
  487. requestID := middleware.GetRequestID(c)
  488. if requestID == "" {
  489. c.JSON(relayErr.StatusCode(), relayErr)
  490. return
  491. }
  492. log := common.GetLogger(c)
  493. data, err := relayErr.MarshalJSON()
  494. if err != nil {
  495. log.Errorf("marshal error failed: %+v", err)
  496. c.JSON(relayErr.StatusCode(), relayErr)
  497. return
  498. }
  499. node, err := sonic.Get(data)
  500. if err != nil {
  501. log.Errorf("get node failed: %+v", err)
  502. c.JSON(relayErr.StatusCode(), relayErr)
  503. return
  504. }
  505. _, err = node.Set("aiproxy", ast.NewString(requestID))
  506. if err != nil {
  507. log.Errorf("set request id failed: %+v", err)
  508. c.JSON(relayErr.StatusCode(), relayErr)
  509. return
  510. }
  511. c.JSON(relayErr.StatusCode(), &node)
  512. }