distributor.go 17 KB


  1. package middleware
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "github.com/bytedance/sonic"
  10. "github.com/bytedance/sonic/ast"
  11. "github.com/gin-gonic/gin"
  12. "github.com/labring/aiproxy/core/common"
  13. "github.com/labring/aiproxy/core/common/balance"
  14. "github.com/labring/aiproxy/core/common/config"
  15. "github.com/labring/aiproxy/core/common/consume"
  16. "github.com/labring/aiproxy/core/common/notify"
  17. "github.com/labring/aiproxy/core/common/reqlimit"
  18. "github.com/labring/aiproxy/core/model"
  19. "github.com/labring/aiproxy/core/relay/meta"
  20. "github.com/labring/aiproxy/core/relay/mode"
  21. relaymodel "github.com/labring/aiproxy/core/relay/model"
  22. monitorplugin "github.com/labring/aiproxy/core/relay/plugin/monitor"
  23. )
  24. func calculateGroupConsumeLevelRatio(usedAmount float64) float64 {
  25. v := config.GetGroupConsumeLevelRatio()
  26. if len(v) == 0 {
  27. return 1
  28. }
  29. var (
  30. maxConsumeLevel float64 = -1
  31. groupConsumeLevelRatio float64
  32. )
  33. for consumeLevel, ratio := range v {
  34. if usedAmount < consumeLevel {
  35. continue
  36. }
  37. if consumeLevel > maxConsumeLevel {
  38. maxConsumeLevel = consumeLevel
  39. groupConsumeLevelRatio = ratio
  40. }
  41. }
  42. if groupConsumeLevelRatio <= 0 {
  43. groupConsumeLevelRatio = 1
  44. }
  45. return groupConsumeLevelRatio
  46. }
  47. func getGroupPMRatio(group model.GroupCache) (float64, float64) {
  48. groupRPMRatio := group.RPMRatio
  49. if groupRPMRatio <= 0 {
  50. groupRPMRatio = 1
  51. }
  52. groupTPMRatio := group.TPMRatio
  53. if groupTPMRatio <= 0 {
  54. groupTPMRatio = 1
  55. }
  56. return groupRPMRatio, groupTPMRatio
  57. }
  58. func GetGroupAdjustedModelConfig(group model.GroupCache, mc model.ModelConfig) model.ModelConfig {
  59. if groupModelConfig, ok := group.ModelConfigs[mc.Model]; ok {
  60. mc = mc.LoadFromGroupModelConfig(groupModelConfig)
  61. }
  62. rpmRatio, tpmRatio := getGroupPMRatio(group)
  63. groupConsumeLevelRatio := calculateGroupConsumeLevelRatio(group.UsedAmount)
  64. mc.RPM = int64(float64(mc.RPM) * rpmRatio * groupConsumeLevelRatio)
  65. mc.TPM = int64(float64(mc.TPM) * tpmRatio * groupConsumeLevelRatio)
  66. return mc
  67. }
var (
	// ErrRequestRateLimitExceeded is returned by checkGroupModelRPMAndTPM when
	// the group's per-model requests-per-minute limit is exceeded.
	ErrRequestRateLimitExceeded = errors.New("request rate limit exceeded, please try again later")
	// ErrRequestTpmLimitExceeded is returned by checkGroupModelRPMAndTPM when
	// the group's per-model tokens-per-minute limit has been reached.
	ErrRequestTpmLimitExceeded = errors.New("request tpm limit exceeded, please try again later")
)
// Response header names used to report the group/model rate limits to
// clients; they are written by setRpmHeaders and setTpmHeaders.
const (
	// XRateLimitLimitRequests carries the configured requests-per-minute limit.
	XRateLimitLimitRequests = "X-RateLimit-Limit-Requests"
	// XRateLimitLimitTokens carries the configured tokens-per-minute limit.
	//nolint:gosec
	XRateLimitLimitTokens = "X-RateLimit-Limit-Tokens"
	// XRateLimitRemainingRequests carries the requests left in the window.
	XRateLimitRemainingRequests = "X-RateLimit-Remaining-Requests"
	// XRateLimitRemainingTokens carries the tokens left in the window.
	//nolint:gosec
	XRateLimitRemainingTokens = "X-RateLimit-Remaining-Tokens"
	// XRateLimitResetRequests carries the request-window reset hint.
	XRateLimitResetRequests = "X-RateLimit-Reset-Requests"
	// XRateLimitResetTokens carries the token-window reset hint.
	//nolint:gosec
	XRateLimitResetTokens = "X-RateLimit-Reset-Tokens"
)
  83. func setRpmHeaders(c *gin.Context, rpm, remainingRequests int64) {
  84. c.Header(XRateLimitLimitRequests, strconv.FormatInt(rpm, 10))
  85. c.Header(XRateLimitRemainingRequests, strconv.FormatInt(remainingRequests, 10))
  86. c.Header(XRateLimitResetRequests, "1m0s")
  87. }
  88. func setTpmHeaders(c *gin.Context, tpm, remainingRequests int64) {
  89. c.Header(XRateLimitLimitTokens, strconv.FormatInt(tpm, 10))
  90. c.Header(XRateLimitRemainingTokens, strconv.FormatInt(remainingRequests, 10))
  91. c.Header(XRateLimitResetTokens, "1m0s")
  92. }
// checkGroupModelRPMAndTPM updates the group/model and group/model/token-name
// request and token counters via reqlimit. It reports those counters to the
// monitor plugin and then enforces the group's per-model limits.
//
// Limits are not enforced for internal groups, or when the corresponding limit
// in mc is <= 0. When a limit is enforced, its value is added to the request
// log and the matching X-RateLimit-* headers are set.
//
// It returns ErrRequestRateLimitExceeded when the request count exceeds mc.RPM.
// It returns ErrRequestTpmLimitExceeded when the token count has reached mc.TPM.
// Otherwise it returns nil. An RPM rejection returns before the token counters
// are read or reported.
func checkGroupModelRPMAndTPM(
	c *gin.Context,
	group model.GroupCache,
	mc model.ModelConfig,
	tokenName string,
) error {
	log := common.GetLogger(c)
	// NOTE(review): presumably records this request and returns
	// (in-limit count, over-limit count, per-second count) for the window;
	// confirm against reqlimit.PushGroupModelRequest.
	groupModelCount, groupModelOverLimitCount, groupModelSecondCount := reqlimit.PushGroupModelRequest(
		c.Request.Context(),
		group.ID,
		mc.Model,
		mc.RPM,
	)
	monitorplugin.UpdateGroupModelRequest(
		c,
		group,
		groupModelCount+groupModelOverLimitCount,
		groupModelSecondCount,
	)
	groupModelTokenCount, groupModelTokenOverLimitCount, groupModelTokenSecondCount := reqlimit.PushGroupModelTokennameRequest(
		c.Request.Context(),
		group.ID,
		mc.Model,
		tokenName,
	)
	monitorplugin.UpdateGroupModelTokennameRequest(
		c,
		groupModelTokenCount+groupModelTokenOverLimitCount,
		groupModelTokenSecondCount,
	)
	if group.Status != model.GroupStatusInternal &&
		mc.RPM > 0 {
		log.Data["group_rpm_limit"] = strconv.FormatInt(mc.RPM, 10)
		// Strict ">" (unlike the TPM ">=" below): presumably the pushed count
		// already includes the current request — TODO confirm.
		if groupModelCount > mc.RPM {
			setRpmHeaders(c, mc.RPM, 0)
			return ErrRequestRateLimitExceeded
		}
		setRpmHeaders(c, mc.RPM, mc.RPM-groupModelCount)
	}
	// Token usage is read here (Get*), not pushed; presumably it is recorded
	// elsewhere once the response's usage is known — verify.
	groupModelCountTPM, groupModelCountTPS := reqlimit.GetGroupModelTokensRequest(
		c.Request.Context(),
		group.ID,
		mc.Model,
	)
	monitorplugin.UpdateGroupModelTokensRequest(c, group, groupModelCountTPM, groupModelCountTPS)
	groupModelTokenCountTPM, groupModelTokenCountTPS := reqlimit.GetGroupModelTokennameTokensRequest(
		c.Request.Context(),
		group.ID,
		mc.Model,
		tokenName,
	)
	monitorplugin.UpdateGroupModelTokennameTokensRequest(
		c,
		groupModelTokenCountTPM,
		groupModelTokenCountTPS,
	)
	if group.Status != model.GroupStatusInternal &&
		mc.TPM > 0 {
		log.Data["group_tpm_limit"] = strconv.FormatInt(mc.TPM, 10)
		// A window that has already reached the limit rejects new requests.
		if groupModelCountTPM >= mc.TPM {
			setTpmHeaders(c, mc.TPM, 0)
			return ErrRequestTpmLimitExceeded
		}
		setTpmHeaders(c, mc.TPM, mc.TPM-groupModelCountTPM)
	}
	return nil
}
// GroupBalanceConsumer bundles a group's balance snapshot with the callbacks
// used to check it and to post consumption against it. GetGroupBalanceConsumer
// builds it once per request and caches it on the gin context under
// GroupBalance.
type GroupBalanceConsumer struct {
	// Group is the group ID.
	Group string
	// balance is the remaining balance fetched from balance.GetGroupRemainBalance;
	// it stays zero for internal groups.
	balance float64
	// CheckBalance reports whether the balance covers amount; it always returns
	// true for internal groups.
	CheckBalance func(amount float64) bool
	// Consumer posts consumption for the group; it is nil for internal groups.
	Consumer balance.PostGroupConsumer
}
  166. func GetGroupBalanceConsumerFromContext(c *gin.Context) *GroupBalanceConsumer {
  167. gbcI, ok := c.Get(GroupBalance)
  168. if ok {
  169. groupBalanceConsumer, ok := gbcI.(*GroupBalanceConsumer)
  170. if !ok {
  171. panic("internal error: group balance consumer unavailable")
  172. }
  173. return groupBalanceConsumer
  174. }
  175. return nil
  176. }
  177. func GetGroupBalanceConsumer(
  178. c *gin.Context,
  179. group model.GroupCache,
  180. ) (*GroupBalanceConsumer, error) {
  181. gbc := GetGroupBalanceConsumerFromContext(c)
  182. if gbc != nil {
  183. return gbc, nil
  184. }
  185. if group.Status == model.GroupStatusInternal {
  186. gbc = &GroupBalanceConsumer{
  187. Group: group.ID,
  188. CheckBalance: func(_ float64) bool {
  189. return true
  190. },
  191. Consumer: nil,
  192. }
  193. } else {
  194. log := common.GetLogger(c)
  195. groupBalance, consumer, err := balance.GetGroupRemainBalance(c.Request.Context(), group)
  196. if err != nil {
  197. return nil, err
  198. }
  199. log.Data["balance"] = strconv.FormatFloat(groupBalance, 'f', -1, 64)
  200. gbc = &GroupBalanceConsumer{
  201. Group: group.ID,
  202. balance: groupBalance,
  203. CheckBalance: func(amount float64) bool {
  204. return groupBalance >= amount
  205. },
  206. Consumer: consumer,
  207. }
  208. }
  209. c.Set(GroupBalance, gbc)
  210. return gbc, nil
  211. }
const (
	// GroupBalanceNotEnough is the error type that checkGroupBalance attaches
	// (via relaymodel.WithType) to its 403 response when the group's balance
	// check for a zero amount fails.
	GroupBalanceNotEnough = "group_balance_not_enough"
)
  215. func checkGroupBalance(c *gin.Context, group model.GroupCache) bool {
  216. gbc, err := GetGroupBalanceConsumer(c, group)
  217. if err != nil {
  218. if errors.Is(err, balance.ErrNoRealNameUsedAmountLimit) {
  219. AbortLogWithMessage(
  220. c,
  221. http.StatusForbidden,
  222. err.Error(),
  223. )
  224. return false
  225. }
  226. notify.ErrorThrottle(
  227. "getGroupBalanceError",
  228. time.Minute*3,
  229. fmt.Sprintf("Get group `%s` balance error", group.ID),
  230. err.Error(),
  231. )
  232. AbortWithMessage(
  233. c,
  234. http.StatusInternalServerError,
  235. fmt.Sprintf("get group `%s` balance error", group.ID),
  236. )
  237. return false
  238. }
  239. if group.Status != model.GroupStatusInternal &&
  240. group.BalanceAlertEnabled &&
  241. !gbc.CheckBalance(group.BalanceAlertThreshold) {
  242. notify.ErrorThrottle(
  243. "groupBalanceAlert:"+group.ID,
  244. time.Minute*30,
  245. fmt.Sprintf("Group `%s` balance below threshold", group.ID),
  246. fmt.Sprintf(
  247. "Group `%s` balance has fallen below the threshold\nCurrent balance: %.2f",
  248. group.ID,
  249. gbc.balance,
  250. ),
  251. )
  252. }
  253. if !gbc.CheckBalance(0) {
  254. AbortLogWithMessage(
  255. c,
  256. http.StatusForbidden,
  257. fmt.Sprintf("group `%s` balance not enough", group.ID),
  258. relaymodel.WithType(GroupBalanceNotEnough),
  259. )
  260. return false
  261. }
  262. return true
  263. }
  264. func NewDistribute(mode mode.Mode) gin.HandlerFunc {
  265. return func(c *gin.Context) {
  266. distribute(c, mode)
  267. }
  268. }
  269. func CheckRelayMode(requestMode, modelMode mode.Mode) bool {
  270. if modelMode == mode.Unknown {
  271. return true
  272. }
  273. switch requestMode {
  274. case mode.ChatCompletions, mode.Completions, mode.Anthropic, mode.Gemini,
  275. mode.Responses, mode.ResponsesGet, mode.ResponsesDelete, mode.ResponsesCancel, mode.ResponsesInputItems:
  276. return modelMode == mode.ChatCompletions ||
  277. modelMode == mode.Completions ||
  278. modelMode == mode.Anthropic ||
  279. modelMode == mode.Gemini ||
  280. modelMode == mode.Responses ||
  281. modelMode == mode.ResponsesGet ||
  282. modelMode == mode.ResponsesDelete ||
  283. modelMode == mode.ResponsesCancel ||
  284. modelMode == mode.ResponsesInputItems
  285. case mode.ImagesGenerations, mode.ImagesEdits:
  286. return modelMode == mode.ImagesGenerations ||
  287. modelMode == mode.ImagesEdits
  288. case mode.VideoGenerationsJobs, mode.VideoGenerationsGetJobs, mode.VideoGenerationsContent:
  289. return modelMode == mode.VideoGenerationsJobs ||
  290. modelMode == mode.VideoGenerationsGetJobs ||
  291. modelMode == mode.VideoGenerationsContent
  292. default:
  293. return requestMode == modelMode
  294. }
  295. }
// distribute is the body of the request-routing middleware. It stores the
// relay mode on the context, then runs the steps below in order, aborting the
// request at the first failure:
//   - maintenance switch (503)
//   - group balance (see checkGroupBalance)
//   - model-name extraction (500 on error, 400 if empty)
//   - token model access and model-config lookup (404)
//   - relay-mode compatibility between endpoint and model type (404)
//   - request "user" and "metadata" extraction (500 on error)
//   - per-group/model RPM and TPM limits (429, also recorded via consume.Summary)
//
// Along the way it stores RequestModel, ModelConfig (group-adjusted),
// RequestUser and RequestMetadata on the context. On success it calls c.Next().
//
// NOTE: the parameter `mode` shadows the imported `mode` package inside this
// function; it is only used as a value here.
func distribute(c *gin.Context, mode mode.Mode) {
	c.Set(Mode, mode)
	if config.GetDisableServe() {
		AbortLogWithMessage(c, http.StatusServiceUnavailable, "service is under maintenance")
		return
	}
	log := common.GetLogger(c)
	group := GetGroup(c)
	token := GetToken(c)
	if !checkGroupBalance(c, group) {
		return
	}
	requestModel, err := getRequestModel(c, mode, group.ID, token.ID)
	if err != nil {
		AbortLogWithMessage(
			c,
			http.StatusInternalServerError,
			err.Error(),
		)
		return
	}
	if requestModel == "" {
		AbortLogWithMessage(c, http.StatusBadRequest, "no model provided")
		return
	}
	// Resolve the requested name against the token's permitted models.
	findModel := token.FindModel(requestModel)
	if findModel == "" {
		AbortLogWithMessage(
			c,
			http.StatusNotFound,
			fmt.Sprintf(
				"The model `%s` does not exist or you do not have access to it.",
				requestModel,
			),
		)
		return
	}
	SetLogModelFields(log.Data, findModel)
	mc, ok := GetModelCaches(c).ModelConfig.GetModelConfig(findModel)
	if !ok {
		AbortLogWithMessage(
			c,
			http.StatusNotFound,
			fmt.Sprintf(
				"The model `%s` does not exist or you do not have access to it.",
				findModel,
			),
		)
		return
	}
	// Apply group overrides and RPM/TPM scaling before storing the config.
	mc = GetGroupAdjustedModelConfig(group, mc)
	c.Set(RequestModel, findModel)
	c.Set(ModelConfig, mc)
	if !CheckRelayMode(mode, mc.Type) {
		AbortLogWithMessage(
			c,
			http.StatusNotFound,
			fmt.Sprintf(
				"The model `%s` does not exist on this endpoint.",
				findModel,
			),
		)
		return
	}
	user, err := getRequestUser(c, mode)
	if err != nil {
		AbortLogWithMessage(
			c,
			http.StatusInternalServerError,
			err.Error(),
		)
		return
	}
	c.Set(RequestUser, user)
	metadata, err := getRequestMetadata(c, mode)
	if err != nil {
		AbortLogWithMessage(
			c,
			http.StatusInternalServerError,
			err.Error(),
		)
		return
	}
	c.Set(RequestMetadata, metadata)
	if err := checkGroupModelRPMAndTPM(c, group, mc, token.Name); err != nil {
		errMsg := err.Error()
		// Record the rejected request (429, zero usage/price) before aborting.
		consume.Summary(
			http.StatusTooManyRequests,
			time.Time{},
			NewMetaByContext(c, nil, mode),
			model.Usage{},
			model.Price{},
			true,
		)
		AbortLogWithMessage(c, http.StatusTooManyRequests, errMsg)
		return
	}
	c.Next()
}
  395. func GetRequestModel(c *gin.Context) string {
  396. return c.GetString(RequestModel)
  397. }
  398. func GetRequestUser(c *gin.Context) string {
  399. return c.GetString(RequestUser)
  400. }
  401. func GetChannelID(c *gin.Context) int {
  402. return c.GetInt(ChannelID)
  403. }
  404. func GetJobID(c *gin.Context) string {
  405. return c.GetString(JobID)
  406. }
  407. func GetGenerationID(c *gin.Context) string {
  408. return c.GetString(GenerationID)
  409. }
  410. func GetResponseID(c *gin.Context) string {
  411. return c.GetString(ResponseID)
  412. }
  413. func GetRequestMetadata(c *gin.Context) map[string]string {
  414. return c.GetStringMapString(RequestMetadata)
  415. }
  416. func GetModelConfig(c *gin.Context) model.ModelConfig {
  417. v, ok := c.MustGet(ModelConfig).(model.ModelConfig)
  418. if !ok {
  419. panic(fmt.Sprintf("model config type error: %T, %v", v, v))
  420. }
  421. return v
  422. }
  423. func NewMetaByContext(c *gin.Context,
  424. channel *model.Channel,
  425. mode mode.Mode,
  426. opts ...meta.Option,
  427. ) *meta.Meta {
  428. requestID := GetRequestID(c)
  429. group := GetGroup(c)
  430. token := GetToken(c)
  431. modelName := GetRequestModel(c)
  432. modelConfig := GetModelConfig(c)
  433. requestAt := GetRequestAt(c)
  434. jobID := GetJobID(c)
  435. generationID := GetGenerationID(c)
  436. responseID := GetResponseID(c)
  437. opts = append(
  438. opts,
  439. meta.WithRequestAt(requestAt),
  440. meta.WithRequestID(requestID),
  441. meta.WithGroup(group),
  442. meta.WithToken(token),
  443. meta.WithEndpoint(c.Request.URL.Path),
  444. meta.WithJobID(jobID),
  445. meta.WithGenerationID(generationID),
  446. meta.WithResponseID(responseID),
  447. )
  448. return meta.NewMeta(
  449. channel,
  450. mode,
  451. modelName,
  452. modelConfig,
  453. opts...,
  454. )
  455. }
  456. // https://platform.openai.com/docs/api-reference/chat
  457. func getRequestModel(c *gin.Context, m mode.Mode, group string, tokenID int) (string, error) {
  458. path := c.Request.URL.Path
  459. switch {
  460. case m == mode.ParsePdf:
  461. query := c.Request.URL.Query()
  462. model := query.Get("model")
  463. if model != "" {
  464. return model, nil
  465. }
  466. fallthrough
  467. case m == mode.AudioTranscription,
  468. m == mode.AudioTranslation,
  469. m == mode.ImagesEdits:
  470. return c.Request.FormValue("model"), nil
  471. case strings.HasPrefix(path, "/v1/engines") && strings.HasSuffix(path, "/embeddings"):
  472. // /engines/:model/embeddings
  473. return c.Param("model"), nil
  474. case m == mode.VideoGenerationsGetJobs:
  475. jobID := c.Param("id")
  476. store, err := model.CacheGetStore(group, tokenID, jobID)
  477. if err != nil {
  478. return "", fmt.Errorf("get request model failed: %w", err)
  479. }
  480. c.Set(JobID, store.ID)
  481. c.Set(ChannelID, store.ChannelID)
  482. return store.Model, nil
  483. case m == mode.VideoGenerationsContent:
  484. generationID := c.Param("id")
  485. store, err := model.CacheGetStore(group, tokenID, generationID)
  486. if err != nil {
  487. return "", fmt.Errorf("get request model failed: %w", err)
  488. }
  489. c.Set(GenerationID, store.ID)
  490. c.Set(ChannelID, store.ChannelID)
  491. return store.Model, nil
  492. case m == mode.ResponsesGet || m == mode.ResponsesDelete ||
  493. m == mode.ResponsesCancel || m == mode.ResponsesInputItems:
  494. responseID := c.Param("response_id")
  495. store, err := model.CacheGetStore(group, tokenID, responseID)
  496. if err != nil {
  497. return "", fmt.Errorf("get request model failed: %w", err)
  498. }
  499. c.Set(ResponseID, store.ID)
  500. c.Set(ChannelID, store.ChannelID)
  501. return store.Model, nil
  502. case m == mode.Responses:
  503. body, err := common.GetRequestBodyReusable(c.Request)
  504. if err != nil {
  505. return "", fmt.Errorf("get request model failed: %w", err)
  506. }
  507. responseID, err := GetPreviousResponseIDFromJSON(body)
  508. if err != nil {
  509. return "", fmt.Errorf("get request previous response id failed: %w", err)
  510. }
  511. modelName, err := GetModelFromJSON(body)
  512. if err != nil {
  513. return "", err
  514. }
  515. if responseID != "" {
  516. store, err := model.CacheGetStore(group, tokenID, responseID)
  517. if err != nil {
  518. return "", fmt.Errorf("get request model failed: %w", err)
  519. }
  520. c.Set(ResponseID, store.ID)
  521. c.Set(ChannelID, store.ChannelID)
  522. }
  523. return modelName, nil
  524. case m == mode.Gemini:
  525. modelName := strings.TrimPrefix(c.Param("model"), "/")
  526. modelName, _, _ = strings.Cut(modelName, ":")
  527. return modelName, nil
  528. default:
  529. body, err := common.GetRequestBodyReusable(c.Request)
  530. if err != nil {
  531. return "", fmt.Errorf("get request model failed: %w", err)
  532. }
  533. return GetModelFromJSON(body)
  534. }
  535. }
  536. func GetModelFromJSON(body []byte) (string, error) {
  537. node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "model")
  538. if err != nil {
  539. if errors.Is(err, ast.ErrNotExist) {
  540. return "", nil
  541. }
  542. return "", fmt.Errorf("get request model failed: %w", err)
  543. }
  544. return node.String()
  545. }
  546. func GetPreviousResponseIDFromJSON(body []byte) (string, error) {
  547. node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "previous_response_id")
  548. if err != nil {
  549. if errors.Is(err, ast.ErrNotExist) {
  550. return "", nil
  551. }
  552. return "", fmt.Errorf("get request model failed: %w", err)
  553. }
  554. return node.String()
  555. }
  556. // https://platform.openai.com/docs/api-reference/chat
  557. func getRequestUser(c *gin.Context, m mode.Mode) (string, error) {
  558. switch m {
  559. case mode.ChatCompletions,
  560. mode.Completions,
  561. mode.Embeddings,
  562. mode.ImagesGenerations,
  563. mode.AudioSpeech,
  564. mode.Rerank,
  565. mode.Anthropic,
  566. mode.Gemini:
  567. body, err := common.GetRequestBodyReusable(c.Request)
  568. if err != nil {
  569. return "", fmt.Errorf("get request model failed: %w", err)
  570. }
  571. return GetRequestUserFromJSON(body)
  572. default:
  573. return "", nil
  574. }
  575. }
  576. func GetRequestUserFromJSON(body []byte) (string, error) {
  577. node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "user")
  578. if err != nil {
  579. if errors.Is(err, ast.ErrNotExist) {
  580. return "", nil
  581. }
  582. return "", fmt.Errorf("get request user failed: %w", err)
  583. }
  584. if node.Exists() {
  585. return node.String()
  586. }
  587. return "", nil
  588. }
  589. func getRequestMetadata(c *gin.Context, m mode.Mode) (map[string]string, error) {
  590. switch m {
  591. case mode.ChatCompletions,
  592. mode.Completions,
  593. mode.Embeddings,
  594. mode.ImagesGenerations,
  595. mode.AudioSpeech,
  596. mode.Rerank,
  597. mode.Anthropic,
  598. mode.Gemini:
  599. body, err := common.GetRequestBodyReusable(c.Request)
  600. if err != nil {
  601. return nil, fmt.Errorf("get request metadata failed: %w", err)
  602. }
  603. return GetRequestMetadataFromJSON(body)
  604. default:
  605. return nil, nil
  606. }
  607. }
// RequestWithMetadata is the minimal request shape used to decode only the
// optional top-level "metadata" object (string keys to string values) from a
// JSON body.
type RequestWithMetadata struct {
	Metadata map[string]string `json:"metadata,omitempty"`
}
  611. func GetRequestMetadataFromJSON(body []byte) (map[string]string, error) {
  612. var requestWithMetadata RequestWithMetadata
  613. if err := sonic.Unmarshal(body, &requestWithMetadata); err != nil {
  614. return nil, fmt.Errorf("get request metadata failed: %w", err)
  615. }
  616. return requestWithMetadata.Metadata, nil
  617. }