distributor.go 17 KB


  1. package middleware
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "github.com/bytedance/sonic"
  10. "github.com/bytedance/sonic/ast"
  11. "github.com/gin-gonic/gin"
  12. "github.com/labring/aiproxy/core/common"
  13. "github.com/labring/aiproxy/core/common/balance"
  14. "github.com/labring/aiproxy/core/common/config"
  15. "github.com/labring/aiproxy/core/common/consume"
  16. "github.com/labring/aiproxy/core/common/notify"
  17. "github.com/labring/aiproxy/core/common/reqlimit"
  18. "github.com/labring/aiproxy/core/model"
  19. "github.com/labring/aiproxy/core/relay/meta"
  20. "github.com/labring/aiproxy/core/relay/mode"
  21. relaymodel "github.com/labring/aiproxy/core/relay/model"
  22. monitorplugin "github.com/labring/aiproxy/core/relay/plugin/monitor"
  23. )
  24. func calculateGroupConsumeLevelRatio(usedAmount float64) float64 {
  25. v := config.GetGroupConsumeLevelRatio()
  26. if len(v) == 0 {
  27. return 1
  28. }
  29. var (
  30. maxConsumeLevel float64 = -1
  31. groupConsumeLevelRatio float64
  32. )
  33. for consumeLevel, ratio := range v {
  34. if usedAmount < consumeLevel {
  35. continue
  36. }
  37. if consumeLevel > maxConsumeLevel {
  38. maxConsumeLevel = consumeLevel
  39. groupConsumeLevelRatio = ratio
  40. }
  41. }
  42. if groupConsumeLevelRatio <= 0 {
  43. groupConsumeLevelRatio = 1
  44. }
  45. return groupConsumeLevelRatio
  46. }
  47. func getGroupPMRatio(group model.GroupCache) (float64, float64) {
  48. groupRPMRatio := group.RPMRatio
  49. if groupRPMRatio <= 0 {
  50. groupRPMRatio = 1
  51. }
  52. groupTPMRatio := group.TPMRatio
  53. if groupTPMRatio <= 0 {
  54. groupTPMRatio = 1
  55. }
  56. return groupRPMRatio, groupTPMRatio
  57. }
  58. func GetGroupAdjustedModelConfig(group model.GroupCache, mc model.ModelConfig) model.ModelConfig {
  59. if groupModelConfig, ok := group.ModelConfigs[mc.Model]; ok {
  60. mc = mc.LoadFromGroupModelConfig(groupModelConfig)
  61. }
  62. rpmRatio, tpmRatio := getGroupPMRatio(group)
  63. groupConsumeLevelRatio := calculateGroupConsumeLevelRatio(group.UsedAmount)
  64. mc.RPM = int64(float64(mc.RPM) * rpmRatio * groupConsumeLevelRatio)
  65. mc.TPM = int64(float64(mc.TPM) * tpmRatio * groupConsumeLevelRatio)
  66. return mc
  67. }
  // ErrRequestRateLimitExceeded is returned by checkGroupModelRPMAndTPM when the
  // group's request count for a model is above its adjusted RPM limit.
  var ErrRequestRateLimitExceeded = errors.New(
  	"request rate limit exceeded, please try again later",
  )

  // ErrRequestTpmLimitExceeded is returned by checkGroupModelRPMAndTPM when the
  // group's token count for a model has reached its adjusted TPM limit.
  var ErrRequestTpmLimitExceeded = errors.New(
  	"request tpm limit exceeded, please try again later",
  )
  // Names of the rate-limit response headers written by setRpmHeaders and
  // setTpmHeaders.
  const (
  	// Request-based (RPM) headers.
  	XRateLimitLimitRequests     = "X-RateLimit-Limit-Requests"
  	XRateLimitRemainingRequests = "X-RateLimit-Remaining-Requests"
  	XRateLimitResetRequests     = "X-RateLimit-Reset-Requests"

  	// Token-based (TPM) headers. Each one carries a gosec suppression.

  	//nolint:gosec
  	XRateLimitLimitTokens = "X-RateLimit-Limit-Tokens"
  	//nolint:gosec
  	XRateLimitRemainingTokens = "X-RateLimit-Remaining-Tokens"
  	//nolint:gosec
  	XRateLimitResetTokens = "X-RateLimit-Reset-Tokens"
  )
  83. func setRpmHeaders(c *gin.Context, rpm, remainingRequests int64) {
  84. c.Header(XRateLimitLimitRequests, strconv.FormatInt(rpm, 10))
  85. c.Header(XRateLimitRemainingRequests, strconv.FormatInt(remainingRequests, 10))
  86. c.Header(XRateLimitResetRequests, "1m0s")
  87. }
  88. func setTpmHeaders(c *gin.Context, tpm, remainingRequests int64) {
  89. c.Header(XRateLimitLimitTokens, strconv.FormatInt(tpm, 10))
  90. c.Header(XRateLimitRemainingTokens, strconv.FormatInt(remainingRequests, 10))
  91. c.Header(XRateLimitResetTokens, "1m0s")
  92. }
  93. func checkGroupModelRPMAndTPM(
  94. c *gin.Context,
  95. group model.GroupCache,
  96. mc model.ModelConfig,
  97. tokenName string,
  98. ) error {
  99. log := common.GetLogger(c)
  100. adjustedModelConfig := GetGroupAdjustedModelConfig(group, mc)
  101. groupModelCount, groupModelOverLimitCount, groupModelSecondCount := reqlimit.PushGroupModelRequest(
  102. c.Request.Context(),
  103. group.ID,
  104. mc.Model,
  105. adjustedModelConfig.RPM,
  106. )
  107. monitorplugin.UpdateGroupModelRequest(
  108. c,
  109. group,
  110. groupModelCount+groupModelOverLimitCount,
  111. groupModelSecondCount,
  112. )
  113. groupModelTokenCount, groupModelTokenOverLimitCount, groupModelTokenSecondCount := reqlimit.PushGroupModelTokennameRequest(
  114. c.Request.Context(),
  115. group.ID,
  116. mc.Model,
  117. tokenName,
  118. )
  119. monitorplugin.UpdateGroupModelTokennameRequest(
  120. c,
  121. groupModelTokenCount+groupModelTokenOverLimitCount,
  122. groupModelTokenSecondCount,
  123. )
  124. if group.Status != model.GroupStatusInternal &&
  125. adjustedModelConfig.RPM > 0 {
  126. log.Data["group_rpm_limit"] = strconv.FormatInt(adjustedModelConfig.RPM, 10)
  127. if groupModelCount > adjustedModelConfig.RPM {
  128. setRpmHeaders(c, adjustedModelConfig.RPM, 0)
  129. return ErrRequestRateLimitExceeded
  130. }
  131. setRpmHeaders(c, adjustedModelConfig.RPM, adjustedModelConfig.RPM-groupModelCount)
  132. }
  133. groupModelCountTPM, groupModelCountTPS := reqlimit.GetGroupModelTokensRequest(
  134. c.Request.Context(),
  135. group.ID,
  136. mc.Model,
  137. )
  138. monitorplugin.UpdateGroupModelTokensRequest(c, group, groupModelCountTPM, groupModelCountTPS)
  139. groupModelTokenCountTPM, groupModelTokenCountTPS := reqlimit.GetGroupModelTokennameTokensRequest(
  140. c.Request.Context(),
  141. group.ID,
  142. mc.Model,
  143. tokenName,
  144. )
  145. monitorplugin.UpdateGroupModelTokennameTokensRequest(
  146. c,
  147. groupModelTokenCountTPM,
  148. groupModelTokenCountTPS,
  149. )
  150. if group.Status != model.GroupStatusInternal &&
  151. adjustedModelConfig.TPM > 0 {
  152. log.Data["group_tpm_limit"] = strconv.FormatInt(adjustedModelConfig.TPM, 10)
  153. if groupModelCountTPM >= adjustedModelConfig.TPM {
  154. setTpmHeaders(c, adjustedModelConfig.TPM, 0)
  155. return ErrRequestTpmLimitExceeded
  156. }
  157. setTpmHeaders(c, adjustedModelConfig.TPM, adjustedModelConfig.TPM-groupModelCountTPM)
  158. }
  159. return nil
  160. }
  161. type GroupBalanceConsumer struct {
  162. Group string
  163. balance float64
  164. CheckBalance func(amount float64) bool
  165. Consumer balance.PostGroupConsumer
  166. }
  167. func GetGroupBalanceConsumerFromContext(c *gin.Context) *GroupBalanceConsumer {
  168. gbcI, ok := c.Get(GroupBalance)
  169. if ok {
  170. groupBalanceConsumer, ok := gbcI.(*GroupBalanceConsumer)
  171. if !ok {
  172. panic("internal error: group balance consumer unavailable")
  173. }
  174. return groupBalanceConsumer
  175. }
  176. return nil
  177. }
  178. func GetGroupBalanceConsumer(
  179. c *gin.Context,
  180. group model.GroupCache,
  181. ) (*GroupBalanceConsumer, error) {
  182. gbc := GetGroupBalanceConsumerFromContext(c)
  183. if gbc != nil {
  184. return gbc, nil
  185. }
  186. if group.Status == model.GroupStatusInternal {
  187. gbc = &GroupBalanceConsumer{
  188. Group: group.ID,
  189. CheckBalance: func(_ float64) bool {
  190. return true
  191. },
  192. Consumer: nil,
  193. }
  194. } else {
  195. log := common.GetLogger(c)
  196. groupBalance, consumer, err := balance.GetGroupRemainBalance(c.Request.Context(), group)
  197. if err != nil {
  198. return nil, err
  199. }
  200. log.Data["balance"] = strconv.FormatFloat(groupBalance, 'f', -1, 64)
  201. gbc = &GroupBalanceConsumer{
  202. Group: group.ID,
  203. balance: groupBalance,
  204. CheckBalance: func(amount float64) bool {
  205. return groupBalance >= amount
  206. },
  207. Consumer: consumer,
  208. }
  209. }
  210. c.Set(GroupBalance, gbc)
  211. return gbc, nil
  212. }
  // GroupBalanceNotEnough is the error type attached to the 403 response that
  // checkGroupBalance sends when a group's balance is not enough.
  const GroupBalanceNotEnough = "group_balance_not_enough"
  216. func checkGroupBalance(c *gin.Context, group model.GroupCache) bool {
  217. gbc, err := GetGroupBalanceConsumer(c, group)
  218. if err != nil {
  219. if errors.Is(err, balance.ErrNoRealNameUsedAmountLimit) {
  220. AbortLogWithMessage(
  221. c,
  222. http.StatusForbidden,
  223. err.Error(),
  224. )
  225. return false
  226. }
  227. notify.ErrorThrottle(
  228. "getGroupBalanceError",
  229. time.Minute,
  230. fmt.Sprintf("Get group `%s` balance error", group.ID),
  231. err.Error(),
  232. )
  233. AbortWithMessage(
  234. c,
  235. http.StatusInternalServerError,
  236. fmt.Sprintf("get group `%s` balance error", group.ID),
  237. )
  238. return false
  239. }
  240. if group.Status != model.GroupStatusInternal &&
  241. group.BalanceAlertEnabled &&
  242. !gbc.CheckBalance(group.BalanceAlertThreshold) {
  243. notify.ErrorThrottle(
  244. "groupBalanceAlert:"+group.ID,
  245. time.Minute*15,
  246. fmt.Sprintf("Group `%s` balance below threshold", group.ID),
  247. fmt.Sprintf(
  248. "Group `%s` balance has fallen below the threshold\nCurrent balance: %.2f",
  249. group.ID,
  250. gbc.balance,
  251. ),
  252. )
  253. }
  254. if !gbc.CheckBalance(0) {
  255. AbortLogWithMessage(
  256. c,
  257. http.StatusForbidden,
  258. fmt.Sprintf("group `%s` balance not enough", group.ID),
  259. relaymodel.WithType(GroupBalanceNotEnough),
  260. )
  261. return false
  262. }
  263. return true
  264. }
  265. func NewDistribute(mode mode.Mode) gin.HandlerFunc {
  266. return func(c *gin.Context) {
  267. distribute(c, mode)
  268. }
  269. }
  270. func CheckRelayMode(requestMode, modelMode mode.Mode) bool {
  271. if modelMode == mode.Unknown {
  272. return true
  273. }
  274. switch requestMode {
  275. case mode.ChatCompletions, mode.Completions, mode.Anthropic,
  276. mode.Responses, mode.ResponsesGet, mode.ResponsesDelete, mode.ResponsesCancel, mode.ResponsesInputItems:
  277. return modelMode == mode.ChatCompletions ||
  278. modelMode == mode.Completions ||
  279. modelMode == mode.Anthropic ||
  280. modelMode == mode.Responses ||
  281. modelMode == mode.ResponsesGet ||
  282. modelMode == mode.ResponsesDelete ||
  283. modelMode == mode.ResponsesCancel ||
  284. modelMode == mode.ResponsesInputItems
  285. case mode.ImagesGenerations, mode.ImagesEdits:
  286. return modelMode == mode.ImagesGenerations ||
  287. modelMode == mode.ImagesEdits
  288. case mode.VideoGenerationsJobs, mode.VideoGenerationsGetJobs, mode.VideoGenerationsContent:
  289. return modelMode == mode.VideoGenerationsJobs ||
  290. modelMode == mode.VideoGenerationsGetJobs ||
  291. modelMode == mode.VideoGenerationsContent
  292. default:
  293. return requestMode == modelMode
  294. }
  295. }
  296. func distribute(c *gin.Context, mode mode.Mode) {
  297. c.Set(Mode, mode)
  298. if config.GetDisableServe() {
  299. AbortLogWithMessage(c, http.StatusServiceUnavailable, "service is under maintenance")
  300. return
  301. }
  302. log := common.GetLogger(c)
  303. group := GetGroup(c)
  304. token := GetToken(c)
  305. if !checkGroupBalance(c, group) {
  306. return
  307. }
  308. requestModel, err := getRequestModel(c, mode, group.ID, token.ID)
  309. if err != nil {
  310. AbortLogWithMessage(
  311. c,
  312. http.StatusInternalServerError,
  313. err.Error(),
  314. )
  315. return
  316. }
  317. if requestModel == "" {
  318. AbortLogWithMessage(c, http.StatusBadRequest, "no model provided")
  319. return
  320. }
  321. c.Set(RequestModel, requestModel)
  322. SetLogModelFields(log.Data, requestModel)
  323. mc, ok := GetModelCaches(c).ModelConfig.GetModelConfig(requestModel)
  324. if !ok || !token.ContainsModel(requestModel) {
  325. AbortLogWithMessage(
  326. c,
  327. http.StatusNotFound,
  328. fmt.Sprintf(
  329. "The model `%s` does not exist or you do not have access to it.",
  330. requestModel,
  331. ),
  332. )
  333. return
  334. }
  335. c.Set(ModelConfig, mc)
  336. if !CheckRelayMode(mode, mc.Type) {
  337. AbortLogWithMessage(
  338. c,
  339. http.StatusNotFound,
  340. fmt.Sprintf(
  341. "The model `%s` does not exist on this endpoint.",
  342. requestModel,
  343. ),
  344. )
  345. return
  346. }
  347. user, err := getRequestUser(c, mode)
  348. if err != nil {
  349. AbortLogWithMessage(
  350. c,
  351. http.StatusInternalServerError,
  352. err.Error(),
  353. )
  354. return
  355. }
  356. c.Set(RequestUser, user)
  357. metadata, err := getRequestMetadata(c, mode)
  358. if err != nil {
  359. AbortLogWithMessage(
  360. c,
  361. http.StatusInternalServerError,
  362. err.Error(),
  363. )
  364. return
  365. }
  366. c.Set(RequestMetadata, metadata)
  367. if err := checkGroupModelRPMAndTPM(c, group, mc, token.Name); err != nil {
  368. errMsg := err.Error()
  369. consume.AsyncConsume(
  370. nil,
  371. http.StatusTooManyRequests,
  372. time.Time{},
  373. NewMetaByContext(c, nil, mode),
  374. model.Usage{},
  375. model.Price{},
  376. errMsg,
  377. c.ClientIP(),
  378. 0,
  379. nil,
  380. true,
  381. user,
  382. metadata,
  383. )
  384. AbortLogWithMessage(c, http.StatusTooManyRequests, errMsg)
  385. return
  386. }
  387. c.Next()
  388. }
  389. func GetRequestModel(c *gin.Context) string {
  390. return c.GetString(RequestModel)
  391. }
  392. func GetRequestUser(c *gin.Context) string {
  393. return c.GetString(RequestUser)
  394. }
  395. func GetChannelID(c *gin.Context) int {
  396. return c.GetInt(ChannelID)
  397. }
  398. func GetJobID(c *gin.Context) string {
  399. return c.GetString(JobID)
  400. }
  401. func GetGenerationID(c *gin.Context) string {
  402. return c.GetString(GenerationID)
  403. }
  404. func GetResponseID(c *gin.Context) string {
  405. return c.GetString(ResponseID)
  406. }
  407. func GetRequestMetadata(c *gin.Context) map[string]string {
  408. return c.GetStringMapString(RequestMetadata)
  409. }
  410. func GetModelConfig(c *gin.Context) model.ModelConfig {
  411. v, ok := c.MustGet(ModelConfig).(model.ModelConfig)
  412. if !ok {
  413. panic(fmt.Sprintf("model config type error: %T, %v", v, v))
  414. }
  415. return v
  416. }
  417. func NewMetaByContext(c *gin.Context,
  418. channel *model.Channel,
  419. mode mode.Mode,
  420. opts ...meta.Option,
  421. ) *meta.Meta {
  422. requestID := GetRequestID(c)
  423. group := GetGroup(c)
  424. token := GetToken(c)
  425. modelName := GetRequestModel(c)
  426. modelConfig := GetModelConfig(c)
  427. requestAt := GetRequestAt(c)
  428. jobID := GetJobID(c)
  429. generationID := GetGenerationID(c)
  430. responseID := GetResponseID(c)
  431. opts = append(
  432. opts,
  433. meta.WithRequestAt(requestAt),
  434. meta.WithRequestID(requestID),
  435. meta.WithGroup(group),
  436. meta.WithToken(token),
  437. meta.WithEndpoint(c.Request.URL.Path),
  438. meta.WithJobID(jobID),
  439. meta.WithGenerationID(generationID),
  440. meta.WithResponseID(responseID),
  441. )
  442. return meta.NewMeta(
  443. channel,
  444. mode,
  445. modelName,
  446. modelConfig,
  447. opts...,
  448. )
  449. }
  450. // https://platform.openai.com/docs/api-reference/chat
  451. func getRequestModel(c *gin.Context, m mode.Mode, group string, tokenID int) (string, error) {
  452. path := c.Request.URL.Path
  453. switch {
  454. case m == mode.ParsePdf:
  455. query := c.Request.URL.Query()
  456. model := query.Get("model")
  457. if model != "" {
  458. return model, nil
  459. }
  460. fallthrough
  461. case m == mode.AudioTranscription,
  462. m == mode.AudioTranslation,
  463. m == mode.ImagesEdits:
  464. return c.Request.FormValue("model"), nil
  465. case strings.HasPrefix(path, "/v1/engines") && strings.HasSuffix(path, "/embeddings"):
  466. // /engines/:model/embeddings
  467. return c.Param("model"), nil
  468. case m == mode.VideoGenerationsGetJobs:
  469. jobID := c.Param("id")
  470. store, err := model.CacheGetStore(group, tokenID, jobID)
  471. if err != nil {
  472. return "", fmt.Errorf("get request model failed: %w", err)
  473. }
  474. c.Set(JobID, store.ID)
  475. c.Set(ChannelID, store.ChannelID)
  476. return store.Model, nil
  477. case m == mode.VideoGenerationsContent:
  478. generationID := c.Param("id")
  479. store, err := model.CacheGetStore(group, tokenID, generationID)
  480. if err != nil {
  481. return "", fmt.Errorf("get request model failed: %w", err)
  482. }
  483. c.Set(GenerationID, store.ID)
  484. c.Set(ChannelID, store.ChannelID)
  485. return store.Model, nil
  486. case m == mode.ResponsesGet || m == mode.ResponsesDelete ||
  487. m == mode.ResponsesCancel || m == mode.ResponsesInputItems:
  488. responseID := c.Param("response_id")
  489. store, err := model.CacheGetStore(group, tokenID, responseID)
  490. if err != nil {
  491. return "", fmt.Errorf("get request model failed: %w", err)
  492. }
  493. c.Set(ResponseID, store.ID)
  494. c.Set(ChannelID, store.ChannelID)
  495. return store.Model, nil
  496. case m == mode.Responses:
  497. body, err := common.GetRequestBodyReusable(c.Request)
  498. if err != nil {
  499. return "", fmt.Errorf("get request model failed: %w", err)
  500. }
  501. responseID, err := GetPreviousResponseIDFromJSON(body)
  502. if err != nil {
  503. return "", fmt.Errorf("get request previous response id failed: %w", err)
  504. }
  505. modelName, err := GetModelFromJSON(body)
  506. if err != nil {
  507. return "", err
  508. }
  509. if responseID != "" {
  510. store, err := model.CacheGetStore(group, tokenID, responseID)
  511. if err != nil {
  512. return "", fmt.Errorf("get request model failed: %w", err)
  513. }
  514. c.Set(ResponseID, store.ID)
  515. c.Set(ChannelID, store.ChannelID)
  516. }
  517. return modelName, nil
  518. default:
  519. body, err := common.GetRequestBodyReusable(c.Request)
  520. if err != nil {
  521. return "", fmt.Errorf("get request model failed: %w", err)
  522. }
  523. return GetModelFromJSON(body)
  524. }
  525. }
  526. func GetModelFromJSON(body []byte) (string, error) {
  527. node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "model")
  528. if err != nil {
  529. if errors.Is(err, ast.ErrNotExist) {
  530. return "", nil
  531. }
  532. return "", fmt.Errorf("get request model failed: %w", err)
  533. }
  534. return node.String()
  535. }
  536. func GetPreviousResponseIDFromJSON(body []byte) (string, error) {
  537. node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "previous_response_id")
  538. if err != nil {
  539. if errors.Is(err, ast.ErrNotExist) {
  540. return "", nil
  541. }
  542. return "", fmt.Errorf("get request model failed: %w", err)
  543. }
  544. return node.String()
  545. }
  546. // https://platform.openai.com/docs/api-reference/chat
  547. func getRequestUser(c *gin.Context, m mode.Mode) (string, error) {
  548. switch m {
  549. case mode.ChatCompletions,
  550. mode.Completions,
  551. mode.Embeddings,
  552. mode.ImagesGenerations,
  553. mode.AudioSpeech,
  554. mode.Rerank,
  555. mode.Anthropic:
  556. body, err := common.GetRequestBodyReusable(c.Request)
  557. if err != nil {
  558. return "", fmt.Errorf("get request model failed: %w", err)
  559. }
  560. return GetRequestUserFromJSON(body)
  561. default:
  562. return "", nil
  563. }
  564. }
  565. func GetRequestUserFromJSON(body []byte) (string, error) {
  566. node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "user")
  567. if err != nil {
  568. if errors.Is(err, ast.ErrNotExist) {
  569. return "", nil
  570. }
  571. return "", fmt.Errorf("get request user failed: %w", err)
  572. }
  573. if node.Exists() {
  574. return node.String()
  575. }
  576. return "", nil
  577. }
  578. func getRequestMetadata(c *gin.Context, m mode.Mode) (map[string]string, error) {
  579. switch m {
  580. case mode.ChatCompletions,
  581. mode.Completions,
  582. mode.Embeddings,
  583. mode.ImagesGenerations,
  584. mode.AudioSpeech,
  585. mode.Rerank,
  586. mode.Anthropic:
  587. body, err := common.GetRequestBodyReusable(c.Request)
  588. if err != nil {
  589. return nil, fmt.Errorf("get request metadata failed: %w", err)
  590. }
  591. return GetRequestMetadataFromJSON(body)
  592. default:
  593. return nil, nil
  594. }
  595. }
  // RequestWithMetadata is the minimal view of a request body that
  // GetRequestMetadataFromJSON decodes: only the "metadata" object.
  type RequestWithMetadata struct {
  	// Metadata holds the string-to-string pairs found under "metadata".
  	Metadata map[string]string `json:"metadata,omitempty"`
  }
  599. func GetRequestMetadataFromJSON(body []byte) (map[string]string, error) {
  600. var requestWithMetadata RequestWithMetadata
  601. if err := sonic.Unmarshal(body, &requestWithMetadata); err != nil {
  602. return nil, fmt.Errorf("get request metadata failed: %w", err)
  603. }
  604. return requestWithMetadata.Metadata, nil
  605. }