distributor.go 17 KB


  1. package middleware
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "github.com/bytedance/sonic"
  10. "github.com/bytedance/sonic/ast"
  11. "github.com/gin-gonic/gin"
  12. "github.com/labring/aiproxy/core/common"
  13. "github.com/labring/aiproxy/core/common/balance"
  14. "github.com/labring/aiproxy/core/common/config"
  15. "github.com/labring/aiproxy/core/common/consume"
  16. "github.com/labring/aiproxy/core/common/notify"
  17. "github.com/labring/aiproxy/core/common/reqlimit"
  18. "github.com/labring/aiproxy/core/model"
  19. "github.com/labring/aiproxy/core/relay/meta"
  20. "github.com/labring/aiproxy/core/relay/mode"
  21. relaymodel "github.com/labring/aiproxy/core/relay/model"
  22. monitorplugin "github.com/labring/aiproxy/core/relay/plugin/monitor"
  23. )
  24. func calculateGroupConsumeLevelRatio(usedAmount float64) float64 {
  25. v := config.GetGroupConsumeLevelRatio()
  26. if len(v) == 0 {
  27. return 1
  28. }
  29. var (
  30. maxConsumeLevel float64 = -1
  31. groupConsumeLevelRatio float64
  32. )
  33. for consumeLevel, ratio := range v {
  34. if usedAmount < consumeLevel {
  35. continue
  36. }
  37. if consumeLevel > maxConsumeLevel {
  38. maxConsumeLevel = consumeLevel
  39. groupConsumeLevelRatio = ratio
  40. }
  41. }
  42. if groupConsumeLevelRatio <= 0 {
  43. groupConsumeLevelRatio = 1
  44. }
  45. return groupConsumeLevelRatio
  46. }
  47. func getGroupPMRatio(group model.GroupCache) (float64, float64) {
  48. groupRPMRatio := group.RPMRatio
  49. if groupRPMRatio <= 0 {
  50. groupRPMRatio = 1
  51. }
  52. groupTPMRatio := group.TPMRatio
  53. if groupTPMRatio <= 0 {
  54. groupTPMRatio = 1
  55. }
  56. return groupRPMRatio, groupTPMRatio
  57. }
  58. func GetGroupAdjustedModelConfig(group model.GroupCache, mc model.ModelConfig) model.ModelConfig {
  59. if groupModelConfig, ok := group.ModelConfigs[mc.Model]; ok {
  60. mc = mc.LoadFromGroupModelConfig(groupModelConfig)
  61. }
  62. rpmRatio, tpmRatio := getGroupPMRatio(group)
  63. groupConsumeLevelRatio := calculateGroupConsumeLevelRatio(group.UsedAmount)
  64. mc.RPM = int64(float64(mc.RPM) * rpmRatio * groupConsumeLevelRatio)
  65. mc.TPM = int64(float64(mc.TPM) * tpmRatio * groupConsumeLevelRatio)
  66. return mc
  67. }
  // ErrRequestRateLimitExceeded is returned by checkGroupModelRPMAndTPM when
  // the group's request count for a model is above the model's RPM limit.
  var ErrRequestRateLimitExceeded = errors.New(
  	"request rate limit exceeded, please try again later",
  )

  // ErrRequestTpmLimitExceeded is returned by checkGroupModelRPMAndTPM when
  // the group's token count for a model has reached the model's TPM limit.
  var ErrRequestTpmLimitExceeded = errors.New(
  	"request tpm limit exceeded, please try again later",
  )
  // Names of the response headers that report rate-limit state to the client.
  const (
  	// Request-count headers, written by setRpmHeaders.
  	XRateLimitLimitRequests     = "X-RateLimit-Limit-Requests"
  	XRateLimitRemainingRequests = "X-RateLimit-Remaining-Requests"
  	XRateLimitResetRequests     = "X-RateLimit-Reset-Requests"

  	// Token-count headers, written by setTpmHeaders.
  	//nolint:gosec
  	XRateLimitLimitTokens = "X-RateLimit-Limit-Tokens"
  	//nolint:gosec
  	XRateLimitRemainingTokens = "X-RateLimit-Remaining-Tokens"
  	//nolint:gosec
  	XRateLimitResetTokens = "X-RateLimit-Reset-Tokens"
  )
  83. func setRpmHeaders(c *gin.Context, rpm, remainingRequests int64) {
  84. c.Header(XRateLimitLimitRequests, strconv.FormatInt(rpm, 10))
  85. c.Header(XRateLimitRemainingRequests, strconv.FormatInt(remainingRequests, 10))
  86. c.Header(XRateLimitResetRequests, "1m0s")
  87. }
  88. func setTpmHeaders(c *gin.Context, tpm, remainingRequests int64) {
  89. c.Header(XRateLimitLimitTokens, strconv.FormatInt(tpm, 10))
  90. c.Header(XRateLimitRemainingTokens, strconv.FormatInt(remainingRequests, 10))
  91. c.Header(XRateLimitResetTokens, "1m0s")
  92. }
  93. func checkGroupModelRPMAndTPM(
  94. c *gin.Context,
  95. group model.GroupCache,
  96. mc model.ModelConfig,
  97. tokenName string,
  98. ) error {
  99. log := common.GetLogger(c)
  100. groupModelCount, groupModelOverLimitCount, groupModelSecondCount := reqlimit.PushGroupModelRequest(
  101. c.Request.Context(),
  102. group.ID,
  103. mc.Model,
  104. mc.RPM,
  105. )
  106. monitorplugin.UpdateGroupModelRequest(
  107. c,
  108. group,
  109. groupModelCount+groupModelOverLimitCount,
  110. groupModelSecondCount,
  111. )
  112. groupModelTokenCount, groupModelTokenOverLimitCount, groupModelTokenSecondCount := reqlimit.PushGroupModelTokennameRequest(
  113. c.Request.Context(),
  114. group.ID,
  115. mc.Model,
  116. tokenName,
  117. )
  118. monitorplugin.UpdateGroupModelTokennameRequest(
  119. c,
  120. groupModelTokenCount+groupModelTokenOverLimitCount,
  121. groupModelTokenSecondCount,
  122. )
  123. if group.Status != model.GroupStatusInternal &&
  124. mc.RPM > 0 {
  125. log.Data["group_rpm_limit"] = strconv.FormatInt(mc.RPM, 10)
  126. if groupModelCount > mc.RPM {
  127. setRpmHeaders(c, mc.RPM, 0)
  128. return ErrRequestRateLimitExceeded
  129. }
  130. setRpmHeaders(c, mc.RPM, mc.RPM-groupModelCount)
  131. }
  132. groupModelCountTPM, groupModelCountTPS := reqlimit.GetGroupModelTokensRequest(
  133. c.Request.Context(),
  134. group.ID,
  135. mc.Model,
  136. )
  137. monitorplugin.UpdateGroupModelTokensRequest(c, group, groupModelCountTPM, groupModelCountTPS)
  138. groupModelTokenCountTPM, groupModelTokenCountTPS := reqlimit.GetGroupModelTokennameTokensRequest(
  139. c.Request.Context(),
  140. group.ID,
  141. mc.Model,
  142. tokenName,
  143. )
  144. monitorplugin.UpdateGroupModelTokennameTokensRequest(
  145. c,
  146. groupModelTokenCountTPM,
  147. groupModelTokenCountTPS,
  148. )
  149. if group.Status != model.GroupStatusInternal &&
  150. mc.TPM > 0 {
  151. log.Data["group_tpm_limit"] = strconv.FormatInt(mc.TPM, 10)
  152. if groupModelCountTPM >= mc.TPM {
  153. setTpmHeaders(c, mc.TPM, 0)
  154. return ErrRequestTpmLimitExceeded
  155. }
  156. setTpmHeaders(c, mc.TPM, mc.TPM-groupModelCountTPM)
  157. }
  158. return nil
  159. }
  160. type GroupBalanceConsumer struct {
  161. Group string
  162. balance float64
  163. CheckBalance func(amount float64) bool
  164. Consumer balance.PostGroupConsumer
  165. }
  166. func GetGroupBalanceConsumerFromContext(c *gin.Context) *GroupBalanceConsumer {
  167. gbcI, ok := c.Get(GroupBalance)
  168. if ok {
  169. groupBalanceConsumer, ok := gbcI.(*GroupBalanceConsumer)
  170. if !ok {
  171. panic("internal error: group balance consumer unavailable")
  172. }
  173. return groupBalanceConsumer
  174. }
  175. return nil
  176. }
  177. func GetGroupBalanceConsumer(
  178. c *gin.Context,
  179. group model.GroupCache,
  180. ) (*GroupBalanceConsumer, error) {
  181. gbc := GetGroupBalanceConsumerFromContext(c)
  182. if gbc != nil {
  183. return gbc, nil
  184. }
  185. if group.Status == model.GroupStatusInternal {
  186. gbc = &GroupBalanceConsumer{
  187. Group: group.ID,
  188. CheckBalance: func(_ float64) bool {
  189. return true
  190. },
  191. Consumer: nil,
  192. }
  193. } else {
  194. log := common.GetLogger(c)
  195. groupBalance, consumer, err := balance.GetGroupRemainBalance(c.Request.Context(), group)
  196. if err != nil {
  197. return nil, err
  198. }
  199. log.Data["balance"] = strconv.FormatFloat(groupBalance, 'f', -1, 64)
  200. gbc = &GroupBalanceConsumer{
  201. Group: group.ID,
  202. balance: groupBalance,
  203. CheckBalance: func(amount float64) bool {
  204. return groupBalance >= amount
  205. },
  206. Consumer: consumer,
  207. }
  208. }
  209. c.Set(GroupBalance, gbc)
  210. return gbc, nil
  211. }
  // GroupBalanceNotEnough is the error type attached to the response that
  // checkGroupBalance sends when a group's balance check against 0 fails.
  const GroupBalanceNotEnough = "group_balance_not_enough"
  215. func checkGroupBalance(c *gin.Context, group model.GroupCache) bool {
  216. gbc, err := GetGroupBalanceConsumer(c, group)
  217. if err != nil {
  218. if errors.Is(err, balance.ErrNoRealNameUsedAmountLimit) {
  219. AbortLogWithMessage(
  220. c,
  221. http.StatusForbidden,
  222. err.Error(),
  223. )
  224. return false
  225. }
  226. notify.ErrorThrottle(
  227. "getGroupBalanceError",
  228. time.Minute*3,
  229. fmt.Sprintf("Get group `%s` balance error", group.ID),
  230. err.Error(),
  231. )
  232. AbortWithMessage(
  233. c,
  234. http.StatusInternalServerError,
  235. fmt.Sprintf("get group `%s` balance error", group.ID),
  236. )
  237. return false
  238. }
  239. if group.Status != model.GroupStatusInternal &&
  240. group.BalanceAlertEnabled &&
  241. !gbc.CheckBalance(group.BalanceAlertThreshold) {
  242. notify.ErrorThrottle(
  243. "groupBalanceAlert:"+group.ID,
  244. time.Minute*30,
  245. fmt.Sprintf("Group `%s` balance below threshold", group.ID),
  246. fmt.Sprintf(
  247. "Group `%s` balance has fallen below the threshold\nCurrent balance: %.2f",
  248. group.ID,
  249. gbc.balance,
  250. ),
  251. )
  252. }
  253. if !gbc.CheckBalance(0) {
  254. AbortLogWithMessage(
  255. c,
  256. http.StatusForbidden,
  257. fmt.Sprintf("group `%s` balance not enough", group.ID),
  258. relaymodel.WithType(GroupBalanceNotEnough),
  259. )
  260. return false
  261. }
  262. return true
  263. }
  264. func NewDistribute(mode mode.Mode) gin.HandlerFunc {
  265. return func(c *gin.Context) {
  266. distribute(c, mode)
  267. }
  268. }
  269. func CheckRelayMode(requestMode, modelMode mode.Mode) bool {
  270. if modelMode == mode.Unknown {
  271. return true
  272. }
  273. switch requestMode {
  274. case mode.ChatCompletions, mode.Completions, mode.Anthropic,
  275. mode.Responses, mode.ResponsesGet, mode.ResponsesDelete, mode.ResponsesCancel, mode.ResponsesInputItems:
  276. return modelMode == mode.ChatCompletions ||
  277. modelMode == mode.Completions ||
  278. modelMode == mode.Anthropic ||
  279. modelMode == mode.Responses ||
  280. modelMode == mode.ResponsesGet ||
  281. modelMode == mode.ResponsesDelete ||
  282. modelMode == mode.ResponsesCancel ||
  283. modelMode == mode.ResponsesInputItems
  284. case mode.ImagesGenerations, mode.ImagesEdits:
  285. return modelMode == mode.ImagesGenerations ||
  286. modelMode == mode.ImagesEdits
  287. case mode.VideoGenerationsJobs, mode.VideoGenerationsGetJobs, mode.VideoGenerationsContent:
  288. return modelMode == mode.VideoGenerationsJobs ||
  289. modelMode == mode.VideoGenerationsGetJobs ||
  290. modelMode == mode.VideoGenerationsContent
  291. default:
  292. return requestMode == modelMode
  293. }
  294. }
  295. func distribute(c *gin.Context, mode mode.Mode) {
  296. c.Set(Mode, mode)
  297. if config.GetDisableServe() {
  298. AbortLogWithMessage(c, http.StatusServiceUnavailable, "service is under maintenance")
  299. return
  300. }
  301. log := common.GetLogger(c)
  302. group := GetGroup(c)
  303. token := GetToken(c)
  304. if !checkGroupBalance(c, group) {
  305. return
  306. }
  307. requestModel, err := getRequestModel(c, mode, group.ID, token.ID)
  308. if err != nil {
  309. AbortLogWithMessage(
  310. c,
  311. http.StatusInternalServerError,
  312. err.Error(),
  313. )
  314. return
  315. }
  316. if requestModel == "" {
  317. AbortLogWithMessage(c, http.StatusBadRequest, "no model provided")
  318. return
  319. }
  320. findModel := token.FindModel(requestModel)
  321. if findModel == "" {
  322. AbortLogWithMessage(
  323. c,
  324. http.StatusNotFound,
  325. fmt.Sprintf(
  326. "The model `%s` does not exist or you do not have access to it.",
  327. requestModel,
  328. ),
  329. )
  330. return
  331. }
  332. SetLogModelFields(log.Data, findModel)
  333. mc, ok := GetModelCaches(c).ModelConfig.GetModelConfig(findModel)
  334. if !ok {
  335. AbortLogWithMessage(
  336. c,
  337. http.StatusNotFound,
  338. fmt.Sprintf(
  339. "The model `%s` does not exist or you do not have access to it.",
  340. findModel,
  341. ),
  342. )
  343. return
  344. }
  345. mc = GetGroupAdjustedModelConfig(group, mc)
  346. c.Set(RequestModel, findModel)
  347. c.Set(ModelConfig, mc)
  348. if !CheckRelayMode(mode, mc.Type) {
  349. AbortLogWithMessage(
  350. c,
  351. http.StatusNotFound,
  352. fmt.Sprintf(
  353. "The model `%s` does not exist on this endpoint.",
  354. findModel,
  355. ),
  356. )
  357. return
  358. }
  359. user, err := getRequestUser(c, mode)
  360. if err != nil {
  361. AbortLogWithMessage(
  362. c,
  363. http.StatusInternalServerError,
  364. err.Error(),
  365. )
  366. return
  367. }
  368. c.Set(RequestUser, user)
  369. metadata, err := getRequestMetadata(c, mode)
  370. if err != nil {
  371. AbortLogWithMessage(
  372. c,
  373. http.StatusInternalServerError,
  374. err.Error(),
  375. )
  376. return
  377. }
  378. c.Set(RequestMetadata, metadata)
  379. if err := checkGroupModelRPMAndTPM(c, group, mc, token.Name); err != nil {
  380. errMsg := err.Error()
  381. consume.AsyncConsume(
  382. nil,
  383. http.StatusTooManyRequests,
  384. time.Time{},
  385. NewMetaByContext(c, nil, mode),
  386. model.Usage{},
  387. model.Price{},
  388. errMsg,
  389. c.ClientIP(),
  390. 0,
  391. nil,
  392. true,
  393. user,
  394. metadata,
  395. )
  396. AbortLogWithMessage(c, http.StatusTooManyRequests, errMsg)
  397. return
  398. }
  399. c.Next()
  400. }
  401. func GetRequestModel(c *gin.Context) string {
  402. return c.GetString(RequestModel)
  403. }
  404. func GetRequestUser(c *gin.Context) string {
  405. return c.GetString(RequestUser)
  406. }
  407. func GetChannelID(c *gin.Context) int {
  408. return c.GetInt(ChannelID)
  409. }
  410. func GetJobID(c *gin.Context) string {
  411. return c.GetString(JobID)
  412. }
  413. func GetGenerationID(c *gin.Context) string {
  414. return c.GetString(GenerationID)
  415. }
  416. func GetResponseID(c *gin.Context) string {
  417. return c.GetString(ResponseID)
  418. }
  419. func GetRequestMetadata(c *gin.Context) map[string]string {
  420. return c.GetStringMapString(RequestMetadata)
  421. }
  422. func GetModelConfig(c *gin.Context) model.ModelConfig {
  423. v, ok := c.MustGet(ModelConfig).(model.ModelConfig)
  424. if !ok {
  425. panic(fmt.Sprintf("model config type error: %T, %v", v, v))
  426. }
  427. return v
  428. }
  429. func NewMetaByContext(c *gin.Context,
  430. channel *model.Channel,
  431. mode mode.Mode,
  432. opts ...meta.Option,
  433. ) *meta.Meta {
  434. requestID := GetRequestID(c)
  435. group := GetGroup(c)
  436. token := GetToken(c)
  437. modelName := GetRequestModel(c)
  438. modelConfig := GetModelConfig(c)
  439. requestAt := GetRequestAt(c)
  440. jobID := GetJobID(c)
  441. generationID := GetGenerationID(c)
  442. responseID := GetResponseID(c)
  443. opts = append(
  444. opts,
  445. meta.WithRequestAt(requestAt),
  446. meta.WithRequestID(requestID),
  447. meta.WithGroup(group),
  448. meta.WithToken(token),
  449. meta.WithEndpoint(c.Request.URL.Path),
  450. meta.WithJobID(jobID),
  451. meta.WithGenerationID(generationID),
  452. meta.WithResponseID(responseID),
  453. )
  454. return meta.NewMeta(
  455. channel,
  456. mode,
  457. modelName,
  458. modelConfig,
  459. opts...,
  460. )
  461. }
  462. // https://platform.openai.com/docs/api-reference/chat
  463. func getRequestModel(c *gin.Context, m mode.Mode, group string, tokenID int) (string, error) {
  464. path := c.Request.URL.Path
  465. switch {
  466. case m == mode.ParsePdf:
  467. query := c.Request.URL.Query()
  468. model := query.Get("model")
  469. if model != "" {
  470. return model, nil
  471. }
  472. fallthrough
  473. case m == mode.AudioTranscription,
  474. m == mode.AudioTranslation,
  475. m == mode.ImagesEdits:
  476. return c.Request.FormValue("model"), nil
  477. case strings.HasPrefix(path, "/v1/engines") && strings.HasSuffix(path, "/embeddings"):
  478. // /engines/:model/embeddings
  479. return c.Param("model"), nil
  480. case m == mode.VideoGenerationsGetJobs:
  481. jobID := c.Param("id")
  482. store, err := model.CacheGetStore(group, tokenID, jobID)
  483. if err != nil {
  484. return "", fmt.Errorf("get request model failed: %w", err)
  485. }
  486. c.Set(JobID, store.ID)
  487. c.Set(ChannelID, store.ChannelID)
  488. return store.Model, nil
  489. case m == mode.VideoGenerationsContent:
  490. generationID := c.Param("id")
  491. store, err := model.CacheGetStore(group, tokenID, generationID)
  492. if err != nil {
  493. return "", fmt.Errorf("get request model failed: %w", err)
  494. }
  495. c.Set(GenerationID, store.ID)
  496. c.Set(ChannelID, store.ChannelID)
  497. return store.Model, nil
  498. case m == mode.ResponsesGet || m == mode.ResponsesDelete ||
  499. m == mode.ResponsesCancel || m == mode.ResponsesInputItems:
  500. responseID := c.Param("response_id")
  501. store, err := model.CacheGetStore(group, tokenID, responseID)
  502. if err != nil {
  503. return "", fmt.Errorf("get request model failed: %w", err)
  504. }
  505. c.Set(ResponseID, store.ID)
  506. c.Set(ChannelID, store.ChannelID)
  507. return store.Model, nil
  508. case m == mode.Responses:
  509. body, err := common.GetRequestBodyReusable(c.Request)
  510. if err != nil {
  511. return "", fmt.Errorf("get request model failed: %w", err)
  512. }
  513. responseID, err := GetPreviousResponseIDFromJSON(body)
  514. if err != nil {
  515. return "", fmt.Errorf("get request previous response id failed: %w", err)
  516. }
  517. modelName, err := GetModelFromJSON(body)
  518. if err != nil {
  519. return "", err
  520. }
  521. if responseID != "" {
  522. store, err := model.CacheGetStore(group, tokenID, responseID)
  523. if err != nil {
  524. return "", fmt.Errorf("get request model failed: %w", err)
  525. }
  526. c.Set(ResponseID, store.ID)
  527. c.Set(ChannelID, store.ChannelID)
  528. }
  529. return modelName, nil
  530. default:
  531. body, err := common.GetRequestBodyReusable(c.Request)
  532. if err != nil {
  533. return "", fmt.Errorf("get request model failed: %w", err)
  534. }
  535. return GetModelFromJSON(body)
  536. }
  537. }
  538. func GetModelFromJSON(body []byte) (string, error) {
  539. node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "model")
  540. if err != nil {
  541. if errors.Is(err, ast.ErrNotExist) {
  542. return "", nil
  543. }
  544. return "", fmt.Errorf("get request model failed: %w", err)
  545. }
  546. return node.String()
  547. }
  548. func GetPreviousResponseIDFromJSON(body []byte) (string, error) {
  549. node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "previous_response_id")
  550. if err != nil {
  551. if errors.Is(err, ast.ErrNotExist) {
  552. return "", nil
  553. }
  554. return "", fmt.Errorf("get request model failed: %w", err)
  555. }
  556. return node.String()
  557. }
  558. // https://platform.openai.com/docs/api-reference/chat
  559. func getRequestUser(c *gin.Context, m mode.Mode) (string, error) {
  560. switch m {
  561. case mode.ChatCompletions,
  562. mode.Completions,
  563. mode.Embeddings,
  564. mode.ImagesGenerations,
  565. mode.AudioSpeech,
  566. mode.Rerank,
  567. mode.Anthropic:
  568. body, err := common.GetRequestBodyReusable(c.Request)
  569. if err != nil {
  570. return "", fmt.Errorf("get request model failed: %w", err)
  571. }
  572. return GetRequestUserFromJSON(body)
  573. default:
  574. return "", nil
  575. }
  576. }
  577. func GetRequestUserFromJSON(body []byte) (string, error) {
  578. node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "user")
  579. if err != nil {
  580. if errors.Is(err, ast.ErrNotExist) {
  581. return "", nil
  582. }
  583. return "", fmt.Errorf("get request user failed: %w", err)
  584. }
  585. if node.Exists() {
  586. return node.String()
  587. }
  588. return "", nil
  589. }
  590. func getRequestMetadata(c *gin.Context, m mode.Mode) (map[string]string, error) {
  591. switch m {
  592. case mode.ChatCompletions,
  593. mode.Completions,
  594. mode.Embeddings,
  595. mode.ImagesGenerations,
  596. mode.AudioSpeech,
  597. mode.Rerank,
  598. mode.Anthropic:
  599. body, err := common.GetRequestBodyReusable(c.Request)
  600. if err != nil {
  601. return nil, fmt.Errorf("get request metadata failed: %w", err)
  602. }
  603. return GetRequestMetadataFromJSON(body)
  604. default:
  605. return nil, nil
  606. }
  607. }
  // RequestWithMetadata is the minimal view of a request body that
  // GetRequestMetadataFromJSON decodes into: only the "metadata" object, as a
  // map of string keys to string values.
  type RequestWithMetadata struct {
  	// Metadata holds the request's "metadata" object; nil when absent.
  	Metadata map[string]string `json:"metadata,omitempty"`
  }
  611. func GetRequestMetadataFromJSON(body []byte) (map[string]string, error) {
  612. var requestWithMetadata RequestWithMetadata
  613. if err := sonic.Unmarshal(body, &requestWithMetadata); err != nil {
  614. return nil, fmt.Errorf("get request metadata failed: %w", err)
  615. }
  616. return requestWithMetadata.Metadata, nil
  617. }