relay-controller.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569
  1. package controller
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "math/rand/v2"
  8. "net/http"
  9. "strconv"
  10. "time"
  11. "github.com/bytedance/sonic"
  12. "github.com/bytedance/sonic/ast"
  13. "github.com/gin-gonic/gin"
  14. "github.com/labring/aiproxy/core/common"
  15. "github.com/labring/aiproxy/core/common/config"
  16. "github.com/labring/aiproxy/core/common/consume"
  17. "github.com/labring/aiproxy/core/common/conv"
  18. "github.com/labring/aiproxy/core/middleware"
  19. "github.com/labring/aiproxy/core/model"
  20. "github.com/labring/aiproxy/core/relay/adaptor"
  21. "github.com/labring/aiproxy/core/relay/adaptors"
  22. "github.com/labring/aiproxy/core/relay/controller"
  23. "github.com/labring/aiproxy/core/relay/meta"
  24. "github.com/labring/aiproxy/core/relay/mode"
  25. relaymodel "github.com/labring/aiproxy/core/relay/model"
  26. "github.com/labring/aiproxy/core/relay/plugin"
  27. "github.com/labring/aiproxy/core/relay/plugin/cache"
  28. monitorplugin "github.com/labring/aiproxy/core/relay/plugin/monitor"
  29. "github.com/labring/aiproxy/core/relay/plugin/thinksplit"
  30. websearch "github.com/labring/aiproxy/core/relay/plugin/web-search"
  31. )
  32. // https://platform.openai.com/docs/api-reference/chat
  33. type (
  34. RelayHandler func(*gin.Context, *meta.Meta) *controller.HandleResult
  35. GetRequestUsage func(*gin.Context, model.ModelConfig) (model.Usage, error)
  36. GetRequestPrice func(*gin.Context, model.ModelConfig) (model.Price, error)
  37. )
  38. type RelayController struct {
  39. GetRequestUsage GetRequestUsage
  40. GetRequestPrice GetRequestPrice
  41. Handler RelayHandler
  42. }
  43. var adaptorStore adaptor.Store = &storeImpl{}
  44. type storeImpl struct{}
  45. func (s *storeImpl) GetStore(id string) (adaptor.StoreCache, error) {
  46. store, err := model.CacheGetStore(id)
  47. if err != nil {
  48. return adaptor.StoreCache{}, err
  49. }
  50. return adaptor.StoreCache{
  51. ID: store.ID,
  52. GroupID: store.GroupID,
  53. TokenID: store.TokenID,
  54. ChannelID: store.ChannelID,
  55. Model: store.Model,
  56. ExpiresAt: store.ExpiresAt,
  57. }, nil
  58. }
  59. func (s *storeImpl) SaveStore(store adaptor.StoreCache) error {
  60. _, err := model.SaveStore(&model.Store{
  61. ID: store.ID,
  62. GroupID: store.GroupID,
  63. TokenID: store.TokenID,
  64. ChannelID: store.ChannelID,
  65. Model: store.Model,
  66. ExpiresAt: store.ExpiresAt,
  67. })
  68. return err
  69. }
  70. func relayHandler(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
  71. log := common.GetLogger(c)
  72. middleware.SetLogFieldsFromMeta(meta, log.Data)
  73. adaptor, ok := adaptors.GetAdaptor(meta.Channel.Type)
  74. if !ok {
  75. return &controller.HandleResult{
  76. Error: relaymodel.WrapperOpenAIErrorWithMessage(
  77. fmt.Sprintf("invalid channel type: %d", meta.Channel.Type),
  78. "invalid_channel_type",
  79. http.StatusInternalServerError,
  80. ),
  81. }
  82. }
  83. a := plugin.WrapperAdaptor(adaptor,
  84. monitorplugin.NewGroupMonitorPlugin(),
  85. cache.NewCachePlugin(common.RDB),
  86. websearch.NewWebSearchPlugin(func(modelName string) (*model.Channel, error) {
  87. return getWebSearchChannel(c, modelName)
  88. }),
  89. thinksplit.NewThinkPlugin(),
  90. monitorplugin.NewChannelMonitorPlugin(),
  91. )
  92. return controller.Handle(a, c, meta, adaptorStore)
  93. }
  94. func relayController(m mode.Mode) RelayController {
  95. c := RelayController{
  96. Handler: relayHandler,
  97. }
  98. switch m {
  99. case mode.ImagesGenerations:
  100. c.GetRequestPrice = controller.GetImagesRequestPrice
  101. c.GetRequestUsage = controller.GetImagesRequestUsage
  102. case mode.ImagesEdits:
  103. c.GetRequestPrice = controller.GetImagesEditsRequestPrice
  104. c.GetRequestUsage = controller.GetImagesEditsRequestUsage
  105. case mode.AudioSpeech:
  106. c.GetRequestPrice = controller.GetTTSRequestPrice
  107. c.GetRequestUsage = controller.GetTTSRequestUsage
  108. case mode.AudioTranslation, mode.AudioTranscription:
  109. c.GetRequestPrice = controller.GetSTTRequestPrice
  110. c.GetRequestUsage = controller.GetSTTRequestUsage
  111. case mode.ParsePdf:
  112. c.GetRequestPrice = controller.GetPdfRequestPrice
  113. c.GetRequestUsage = controller.GetPdfRequestUsage
  114. case mode.Rerank:
  115. c.GetRequestPrice = controller.GetRerankRequestPrice
  116. c.GetRequestUsage = controller.GetRerankRequestUsage
  117. case mode.Anthropic:
  118. c.GetRequestPrice = controller.GetAnthropicRequestPrice
  119. c.GetRequestUsage = controller.GetAnthropicRequestUsage
  120. case mode.ChatCompletions:
  121. c.GetRequestPrice = controller.GetChatRequestPrice
  122. c.GetRequestUsage = controller.GetChatRequestUsage
  123. case mode.Embeddings:
  124. c.GetRequestPrice = controller.GetEmbedRequestPrice
  125. c.GetRequestUsage = controller.GetEmbedRequestUsage
  126. case mode.Completions:
  127. c.GetRequestPrice = controller.GetCompletionsRequestPrice
  128. c.GetRequestUsage = controller.GetCompletionsRequestUsage
  129. case mode.VideoGenerationsJobs:
  130. c.GetRequestPrice = controller.GetVideoGenerationJobRequestPrice
  131. c.GetRequestUsage = controller.GetVideoGenerationJobRequestUsage
  132. }
  133. return c
  134. }
  135. func RelayHelper(
  136. c *gin.Context,
  137. meta *meta.Meta,
  138. handel RelayHandler,
  139. ) (*controller.HandleResult, bool) {
  140. result := handel(c, meta)
  141. if result.Error == nil {
  142. return result, false
  143. }
  144. return result, monitorplugin.ShouldRetry(result.Error)
  145. }
  146. func NewRelay(mode mode.Mode) func(c *gin.Context) {
  147. relayController := relayController(mode)
  148. return func(c *gin.Context) {
  149. relay(c, mode, relayController)
  150. }
  151. }
  152. func NewMetaByContext(
  153. c *gin.Context,
  154. channel *model.Channel,
  155. mode mode.Mode,
  156. opts ...meta.Option,
  157. ) *meta.Meta {
  158. return middleware.NewMetaByContext(c, channel, mode, opts...)
  159. }
  160. func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
  161. requestModel := middleware.GetRequestModel(c)
  162. mc := middleware.GetModelConfig(c)
  163. // Get initial channel
  164. initialChannel, err := getInitialChannel(c, requestModel, mode)
  165. if err != nil || initialChannel == nil || initialChannel.channel == nil {
  166. middleware.AbortLogWithMessageWithMode(mode, c,
  167. http.StatusServiceUnavailable,
  168. "the upstream load is saturated, please try again later",
  169. )
  170. return
  171. }
  172. price := model.Price{}
  173. if relayController.GetRequestPrice != nil {
  174. price, err = relayController.GetRequestPrice(c, mc)
  175. if err != nil {
  176. middleware.AbortLogWithMessageWithMode(mode, c,
  177. http.StatusInternalServerError,
  178. "get request price failed: "+err.Error(),
  179. )
  180. return
  181. }
  182. }
  183. meta := NewMetaByContext(c, initialChannel.channel, mode)
  184. if relayController.GetRequestUsage != nil {
  185. requestUsage, err := relayController.GetRequestUsage(c, mc)
  186. if err != nil {
  187. middleware.AbortLogWithMessageWithMode(mode, c,
  188. http.StatusInternalServerError,
  189. "get request usage failed: "+err.Error(),
  190. )
  191. return
  192. }
  193. gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  194. if !gbc.CheckBalance(consume.CalculateAmount(http.StatusOK, requestUsage, price)) {
  195. middleware.AbortLogWithMessageWithMode(mode, c,
  196. http.StatusForbidden,
  197. fmt.Sprintf("group (%s) balance not enough", gbc.Group),
  198. relaymodel.WithType(middleware.GroupBalanceNotEnough),
  199. )
  200. return
  201. }
  202. meta.RequestUsage = requestUsage
  203. }
  204. // First attempt
  205. result, retry := RelayHelper(c, meta, relayController.Handler)
  206. retryTimes := int(config.GetRetryTimes())
  207. if mc.RetryTimes > 0 {
  208. retryTimes = int(mc.RetryTimes)
  209. }
  210. if handleRelayResult(c, result.Error, retry, retryTimes) {
  211. recordResult(
  212. c,
  213. meta,
  214. price,
  215. result,
  216. 0,
  217. true,
  218. middleware.GetRequestUser(c),
  219. middleware.GetRequestMetadata(c),
  220. )
  221. return
  222. }
  223. // Setup retry state
  224. retryState := initRetryState(
  225. retryTimes,
  226. initialChannel,
  227. meta,
  228. result,
  229. price,
  230. )
  231. // Retry loop
  232. retryLoop(c, mode, retryState, relayController.Handler)
  233. }
  234. // recordResult records the consumption for the final result
  235. func recordResult(
  236. c *gin.Context,
  237. meta *meta.Meta,
  238. price model.Price,
  239. result *controller.HandleResult,
  240. retryTimes int,
  241. downstreamResult bool,
  242. user string,
  243. metadata map[string]string,
  244. ) {
  245. code := http.StatusOK
  246. content := ""
  247. if result.Error != nil {
  248. code = result.Error.StatusCode()
  249. respBody, _ := result.Error.MarshalJSON()
  250. content = conv.BytesToString(respBody)
  251. }
  252. var detail *model.RequestDetail
  253. firstByteAt := result.Detail.FirstByteAt
  254. if config.GetSaveAllLogDetail() || meta.ModelConfig.ForceSaveDetail || code != http.StatusOK {
  255. detail = &model.RequestDetail{
  256. RequestBody: result.Detail.RequestBody,
  257. ResponseBody: result.Detail.ResponseBody,
  258. }
  259. }
  260. gbc := middleware.GetGroupBalanceConsumerFromContext(c)
  261. amount := consume.CalculateAmount(
  262. code,
  263. result.Usage,
  264. price,
  265. )
  266. if amount > 0 {
  267. log := common.GetLogger(c)
  268. log.Data["amount"] = strconv.FormatFloat(amount, 'f', -1, 64)
  269. }
  270. consume.AsyncConsume(
  271. gbc.Consumer,
  272. code,
  273. firstByteAt,
  274. meta,
  275. result.Usage,
  276. price,
  277. content,
  278. c.ClientIP(),
  279. retryTimes,
  280. detail,
  281. downstreamResult,
  282. user,
  283. metadata,
  284. )
  285. }
  286. type retryState struct {
  287. retryTimes int
  288. lastHasPermissionChannel *model.Channel
  289. ignoreChannelIDs []int64
  290. errorRates map[int64]float64
  291. exhausted bool
  292. meta *meta.Meta
  293. price model.Price
  294. requestUsage model.Usage
  295. result *controller.HandleResult
  296. migratedChannels []*model.Channel
  297. }
  298. func handleRelayResult(
  299. c *gin.Context,
  300. bizErr adaptor.Error,
  301. retry bool,
  302. retryTimes int,
  303. ) (done bool) {
  304. if bizErr == nil {
  305. return true
  306. }
  307. if !retry ||
  308. retryTimes == 0 ||
  309. c.Request.Context().Err() != nil {
  310. ErrorWithRequestID(c, bizErr)
  311. return true
  312. }
  313. return false
  314. }
  315. func initRetryState(
  316. retryTimes int,
  317. channel *initialChannel,
  318. meta *meta.Meta,
  319. result *controller.HandleResult,
  320. price model.Price,
  321. ) *retryState {
  322. state := &retryState{
  323. retryTimes: retryTimes,
  324. ignoreChannelIDs: channel.ignoreChannelIDs,
  325. errorRates: channel.errorRates,
  326. meta: meta,
  327. result: result,
  328. price: price,
  329. requestUsage: meta.RequestUsage,
  330. migratedChannels: channel.migratedChannels,
  331. }
  332. if channel.designatedChannel {
  333. state.exhausted = true
  334. }
  335. if !monitorplugin.ChannelHasPermission(result.Error) {
  336. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(channel.channel.ID))
  337. } else {
  338. state.lastHasPermissionChannel = channel.channel
  339. }
  340. return state
  341. }
  342. func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayController RelayHandler) {
  343. log := common.GetLogger(c)
  344. // do not use for i := range state.retryTimes, because the retryTimes is constant
  345. i := 0
  346. for {
  347. lastStatusCode := state.result.Error.StatusCode()
  348. lastChannelID := state.meta.Channel.ID
  349. newChannel, err := getRetryChannel(state)
  350. if err == nil {
  351. err = prepareRetry(c)
  352. }
  353. if err != nil {
  354. if !errors.Is(err, ErrChannelsExhausted) {
  355. log.Errorf("prepare retry failed: %+v", err)
  356. }
  357. // when the last request has not recorded the result, record the result
  358. if state.meta != nil && state.result != nil {
  359. recordResult(
  360. c,
  361. state.meta,
  362. state.price,
  363. state.result,
  364. i,
  365. true,
  366. middleware.GetRequestUser(c),
  367. middleware.GetRequestMetadata(c),
  368. )
  369. }
  370. break
  371. }
  372. // when the last request has not recorded the result, record the result
  373. if state.meta != nil && state.result != nil {
  374. recordResult(
  375. c,
  376. state.meta,
  377. state.price,
  378. state.result,
  379. i,
  380. false,
  381. middleware.GetRequestUser(c),
  382. middleware.GetRequestMetadata(c),
  383. )
  384. state.meta = nil
  385. state.result = nil
  386. }
  387. log.Data["retry"] = strconv.Itoa(i + 1)
  388. log.Warnf("using channel %s (type: %d, id: %d) to retry (remain times %d)",
  389. newChannel.Name,
  390. newChannel.Type,
  391. newChannel.ID,
  392. state.retryTimes-i,
  393. )
  394. // Check if we should delay (using the same channel)
  395. if shouldDelay(lastStatusCode, lastChannelID, newChannel.ID) {
  396. relayDelay()
  397. }
  398. state.meta = NewMetaByContext(
  399. c,
  400. newChannel,
  401. mode,
  402. meta.WithRequestUsage(state.requestUsage),
  403. meta.WithRetryAt(time.Now()),
  404. )
  405. var retry bool
  406. state.result, retry = RelayHelper(c, state.meta, relayController)
  407. done := handleRetryResult(c, retry, newChannel, state)
  408. if done || i == state.retryTimes-1 {
  409. recordResult(
  410. c,
  411. state.meta,
  412. state.price,
  413. state.result,
  414. i+1,
  415. true,
  416. middleware.GetRequestUser(c),
  417. middleware.GetRequestMetadata(c),
  418. )
  419. break
  420. }
  421. i++
  422. }
  423. if state.result.Error != nil {
  424. ErrorWithRequestID(c, state.result.Error)
  425. }
  426. }
  427. func prepareRetry(c *gin.Context) error {
  428. requestBody, err := common.GetRequestBodyReusable(c.Request)
  429. if err != nil {
  430. return fmt.Errorf("get request body failed in prepare retry: %w", err)
  431. }
  432. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  433. return nil
  434. }
  435. func handleRetryResult(
  436. ctx *gin.Context,
  437. retry bool,
  438. newChannel *model.Channel,
  439. state *retryState,
  440. ) (done bool) {
  441. if ctx.Request.Context().Err() != nil {
  442. return true
  443. }
  444. if !retry || state.result.Error == nil {
  445. return true
  446. }
  447. hasPermission := monitorplugin.ChannelHasPermission(state.result.Error)
  448. if state.exhausted {
  449. if !hasPermission {
  450. return true
  451. }
  452. } else {
  453. if !hasPermission {
  454. state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(newChannel.ID))
  455. state.retryTimes++
  456. } else {
  457. state.lastHasPermissionChannel = newChannel
  458. }
  459. }
  460. return false
  461. }
  // shouldDelay reports whether to pause before the next attempt. It returns true
// only when the retry reuses the previous channel and the previous attempt
// failed with 429 Too Many Requests or 503 Service Unavailable.
func shouldDelay(statusCode, lastChannelID, newChannelID int) bool {
	sameChannel := lastChannelID == newChannelID
	switch statusCode {
	case http.StatusTooManyRequests, http.StatusServiceUnavailable:
		return sameChannel
	default:
		return false
	}
}
  472. func relayDelay() {
  473. time.Sleep(time.Duration(rand.Float64()*float64(time.Second)) + time.Second)
  474. }
  475. func RelayNotImplemented(c *gin.Context) {
  476. ErrorWithRequestID(c,
  477. relaymodel.NewOpenAIError(http.StatusNotImplemented, relaymodel.OpenAIError{
  478. Message: "API not implemented",
  479. Type: relaymodel.ErrorTypeAIPROXY,
  480. Code: "api_not_implemented",
  481. }),
  482. )
  483. }
  484. func ErrorWithRequestID(c *gin.Context, relayErr adaptor.Error) {
  485. requestID := middleware.GetRequestID(c)
  486. if requestID == "" {
  487. c.JSON(relayErr.StatusCode(), relayErr)
  488. return
  489. }
  490. log := common.GetLogger(c)
  491. data, err := relayErr.MarshalJSON()
  492. if err != nil {
  493. log.Errorf("marshal error failed: %+v", err)
  494. c.JSON(relayErr.StatusCode(), relayErr)
  495. return
  496. }
  497. node, err := sonic.Get(data)
  498. if err != nil {
  499. log.Errorf("get node failed: %+v", err)
  500. c.JSON(relayErr.StatusCode(), relayErr)
  501. return
  502. }
  503. _, err = node.Set("aiproxy", ast.NewString(requestID))
  504. if err != nil {
  505. log.Errorf("set request id failed: %+v", err)
  506. c.JSON(relayErr.StatusCode(), relayErr)
  507. return
  508. }
  509. c.JSON(relayErr.StatusCode(), &node)
  510. }