distributor.go 17 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743
  1. package middleware
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "github.com/bytedance/sonic"
  10. "github.com/bytedance/sonic/ast"
  11. "github.com/gin-gonic/gin"
  12. "github.com/labring/aiproxy/core/common"
  13. "github.com/labring/aiproxy/core/common/balance"
  14. "github.com/labring/aiproxy/core/common/config"
  15. "github.com/labring/aiproxy/core/common/consume"
  16. "github.com/labring/aiproxy/core/common/notify"
  17. "github.com/labring/aiproxy/core/common/reqlimit"
  18. "github.com/labring/aiproxy/core/model"
  19. "github.com/labring/aiproxy/core/relay/meta"
  20. "github.com/labring/aiproxy/core/relay/mode"
  21. relaymodel "github.com/labring/aiproxy/core/relay/model"
  22. monitorplugin "github.com/labring/aiproxy/core/relay/plugin/monitor"
  23. )
  24. func calculateGroupConsumeLevelRatio(usedAmount float64) float64 {
  25. v := config.GetGroupConsumeLevelRatio()
  26. if len(v) == 0 {
  27. return 1
  28. }
  29. var (
  30. maxConsumeLevel float64 = -1
  31. groupConsumeLevelRatio float64
  32. )
  33. for consumeLevel, ratio := range v {
  34. if usedAmount < consumeLevel {
  35. continue
  36. }
  37. if consumeLevel > maxConsumeLevel {
  38. maxConsumeLevel = consumeLevel
  39. groupConsumeLevelRatio = ratio
  40. }
  41. }
  42. if groupConsumeLevelRatio <= 0 {
  43. groupConsumeLevelRatio = 1
  44. }
  45. return groupConsumeLevelRatio
  46. }
  47. func getGroupPMRatio(group model.GroupCache) (float64, float64) {
  48. groupRPMRatio := group.RPMRatio
  49. if groupRPMRatio <= 0 {
  50. groupRPMRatio = 1
  51. }
  52. groupTPMRatio := group.TPMRatio
  53. if groupTPMRatio <= 0 {
  54. groupTPMRatio = 1
  55. }
  56. return groupRPMRatio, groupTPMRatio
  57. }
  58. func GetGroupAdjustedModelConfig(group model.GroupCache, mc model.ModelConfig) model.ModelConfig {
  59. if groupModelConfig, ok := group.ModelConfigs[mc.Model]; ok {
  60. mc = mc.LoadFromGroupModelConfig(groupModelConfig)
  61. }
  62. rpmRatio, tpmRatio := getGroupPMRatio(group)
  63. groupConsumeLevelRatio := calculateGroupConsumeLevelRatio(group.UsedAmount)
  64. mc.RPM = int64(float64(mc.RPM) * rpmRatio * groupConsumeLevelRatio)
  65. mc.TPM = int64(float64(mc.TPM) * tpmRatio * groupConsumeLevelRatio)
  66. return mc
  67. }
  // Sentinel errors returned by checkGroupModelRPMAndTPM.
  var (
  	// ErrRequestRateLimitExceeded is returned when the group's request count
  	// for a model is above the model's RPM limit.
  	ErrRequestRateLimitExceeded = errors.New(
  		"request rate limit exceeded, please try again later",
  	)

  	// ErrRequestTpmLimitExceeded is returned when the group's token count for
  	// a model has reached the model's TPM limit.
  	ErrRequestTpmLimitExceeded = errors.New(
  		"request tpm limit exceeded, please try again later",
  	)
  )
  // Response header names used to report rate-limit state to the client.
  // The *Requests headers are written by setRpmHeaders and the *Tokens headers
  // by setTpmHeaders.
  const (
  	// XRateLimitLimitRequests carries the request limit.
  	XRateLimitLimitRequests = "X-RateLimit-Limit-Requests"
  	// XRateLimitLimitTokens carries the token limit.
  	//nolint:gosec
  	XRateLimitLimitTokens = "X-RateLimit-Limit-Tokens"

  	// XRateLimitRemainingRequests carries the remaining request allowance.
  	XRateLimitRemainingRequests = "X-RateLimit-Remaining-Requests"
  	// XRateLimitRemainingTokens carries the remaining token allowance.
  	//nolint:gosec
  	XRateLimitRemainingTokens = "X-RateLimit-Remaining-Tokens"

  	// XRateLimitResetRequests carries the reset interval for requests.
  	XRateLimitResetRequests = "X-RateLimit-Reset-Requests"
  	// XRateLimitResetTokens carries the reset interval for tokens.
  	//nolint:gosec
  	XRateLimitResetTokens = "X-RateLimit-Reset-Tokens"
  )
  83. func setRpmHeaders(c *gin.Context, rpm, remainingRequests int64) {
  84. c.Header(XRateLimitLimitRequests, strconv.FormatInt(rpm, 10))
  85. c.Header(XRateLimitRemainingRequests, strconv.FormatInt(remainingRequests, 10))
  86. c.Header(XRateLimitResetRequests, "1m0s")
  87. }
  88. func setTpmHeaders(c *gin.Context, tpm, remainingRequests int64) {
  89. c.Header(XRateLimitLimitTokens, strconv.FormatInt(tpm, 10))
  90. c.Header(XRateLimitRemainingTokens, strconv.FormatInt(remainingRequests, 10))
  91. c.Header(XRateLimitResetTokens, "1m0s")
  92. }
  93. func checkGroupModelRPMAndTPM(
  94. c *gin.Context,
  95. group model.GroupCache,
  96. mc model.ModelConfig,
  97. tokenName string,
  98. ) error {
  99. log := common.GetLogger(c)
  100. groupModelCount, groupModelOverLimitCount, groupModelSecondCount := reqlimit.PushGroupModelRequest(
  101. c.Request.Context(),
  102. group.ID,
  103. mc.Model,
  104. mc.RPM,
  105. )
  106. monitorplugin.UpdateGroupModelRequest(
  107. c,
  108. group,
  109. groupModelCount+groupModelOverLimitCount,
  110. groupModelSecondCount,
  111. )
  112. groupModelTokenCount, groupModelTokenOverLimitCount, groupModelTokenSecondCount := reqlimit.PushGroupModelTokennameRequest(
  113. c.Request.Context(),
  114. group.ID,
  115. mc.Model,
  116. tokenName,
  117. )
  118. monitorplugin.UpdateGroupModelTokennameRequest(
  119. c,
  120. groupModelTokenCount+groupModelTokenOverLimitCount,
  121. groupModelTokenSecondCount,
  122. )
  123. if group.Status != model.GroupStatusInternal &&
  124. mc.RPM > 0 {
  125. log.Data["group_rpm_limit"] = strconv.FormatInt(mc.RPM, 10)
  126. if groupModelCount > mc.RPM {
  127. setRpmHeaders(c, mc.RPM, 0)
  128. return ErrRequestRateLimitExceeded
  129. }
  130. setRpmHeaders(c, mc.RPM, mc.RPM-groupModelCount)
  131. }
  132. groupModelCountTPM, groupModelCountTPS := reqlimit.GetGroupModelTokensRequest(
  133. c.Request.Context(),
  134. group.ID,
  135. mc.Model,
  136. )
  137. monitorplugin.UpdateGroupModelTokensRequest(c, group, groupModelCountTPM, groupModelCountTPS)
  138. groupModelTokenCountTPM, groupModelTokenCountTPS := reqlimit.GetGroupModelTokennameTokensRequest(
  139. c.Request.Context(),
  140. group.ID,
  141. mc.Model,
  142. tokenName,
  143. )
  144. monitorplugin.UpdateGroupModelTokennameTokensRequest(
  145. c,
  146. groupModelTokenCountTPM,
  147. groupModelTokenCountTPS,
  148. )
  149. if group.Status != model.GroupStatusInternal &&
  150. mc.TPM > 0 {
  151. log.Data["group_tpm_limit"] = strconv.FormatInt(mc.TPM, 10)
  152. if groupModelCountTPM >= mc.TPM {
  153. setTpmHeaders(c, mc.TPM, 0)
  154. return ErrRequestTpmLimitExceeded
  155. }
  156. setTpmHeaders(c, mc.TPM, mc.TPM-groupModelCountTPM)
  157. }
  158. return nil
  159. }
  160. type GroupBalanceConsumer struct {
  161. Group string
  162. balance float64
  163. CheckBalance func(amount float64) bool
  164. Consumer balance.PostGroupConsumer
  165. }
  166. func GetGroupBalanceConsumerFromContext(c *gin.Context) *GroupBalanceConsumer {
  167. gbcI, ok := c.Get(GroupBalance)
  168. if ok {
  169. groupBalanceConsumer, ok := gbcI.(*GroupBalanceConsumer)
  170. if !ok {
  171. panic("internal error: group balance consumer unavailable")
  172. }
  173. return groupBalanceConsumer
  174. }
  175. return nil
  176. }
  177. func GetGroupBalanceConsumer(
  178. c *gin.Context,
  179. group model.GroupCache,
  180. ) (*GroupBalanceConsumer, error) {
  181. gbc := GetGroupBalanceConsumerFromContext(c)
  182. if gbc != nil {
  183. return gbc, nil
  184. }
  185. if group.Status == model.GroupStatusInternal {
  186. gbc = &GroupBalanceConsumer{
  187. Group: group.ID,
  188. CheckBalance: func(_ float64) bool {
  189. return true
  190. },
  191. Consumer: nil,
  192. }
  193. } else {
  194. log := common.GetLogger(c)
  195. groupBalance, consumer, err := balance.GetGroupRemainBalance(c.Request.Context(), group)
  196. if err != nil {
  197. return nil, err
  198. }
  199. log.Data["balance"] = strconv.FormatFloat(groupBalance, 'f', -1, 64)
  200. gbc = &GroupBalanceConsumer{
  201. Group: group.ID,
  202. balance: groupBalance,
  203. CheckBalance: func(amount float64) bool {
  204. return groupBalance >= amount
  205. },
  206. Consumer: consumer,
  207. }
  208. }
  209. c.Set(GroupBalance, gbc)
  210. return gbc, nil
  211. }
  // GroupBalanceNotEnough is the error type attached to the response when a
  // request is rejected because the group's balance is exhausted.
  const GroupBalanceNotEnough = "group_balance_not_enough"
  215. func checkGroupBalance(c *gin.Context, group model.GroupCache) bool {
  216. gbc, err := GetGroupBalanceConsumer(c, group)
  217. if err != nil {
  218. if errors.Is(err, balance.ErrNoRealNameUsedAmountLimit) {
  219. AbortLogWithMessage(
  220. c,
  221. http.StatusForbidden,
  222. err.Error(),
  223. )
  224. return false
  225. }
  226. notify.ErrorThrottle(
  227. "getGroupBalanceError",
  228. time.Minute*3,
  229. fmt.Sprintf("Get group `%s` balance error", group.ID),
  230. err.Error(),
  231. )
  232. AbortWithMessage(
  233. c,
  234. http.StatusInternalServerError,
  235. fmt.Sprintf("get group `%s` balance error", group.ID),
  236. )
  237. return false
  238. }
  239. if group.Status != model.GroupStatusInternal &&
  240. group.BalanceAlertEnabled &&
  241. !gbc.CheckBalance(group.BalanceAlertThreshold) {
  242. notify.ErrorThrottle(
  243. "groupBalanceAlert:"+group.ID,
  244. time.Minute*30,
  245. fmt.Sprintf("Group `%s` balance below threshold", group.ID),
  246. fmt.Sprintf(
  247. "Group `%s` balance has fallen below the threshold\nCurrent balance: %.2f",
  248. group.ID,
  249. gbc.balance,
  250. ),
  251. )
  252. }
  253. if !gbc.CheckBalance(0) {
  254. AbortLogWithMessage(
  255. c,
  256. http.StatusForbidden,
  257. fmt.Sprintf("group `%s` balance not enough", group.ID),
  258. relaymodel.WithType(GroupBalanceNotEnough),
  259. )
  260. return false
  261. }
  262. return true
  263. }
  264. func NewDistribute(mode mode.Mode) gin.HandlerFunc {
  265. return func(c *gin.Context) {
  266. distribute(c, mode)
  267. }
  268. }
  269. func CheckRelayMode(requestMode, modelMode mode.Mode) bool {
  270. if modelMode == mode.Unknown {
  271. return true
  272. }
  273. switch requestMode {
  274. case mode.ChatCompletions, mode.Completions, mode.Anthropic, mode.Gemini,
  275. mode.Responses, mode.ResponsesGet, mode.ResponsesDelete, mode.ResponsesCancel, mode.ResponsesInputItems:
  276. return modelMode == mode.ChatCompletions ||
  277. modelMode == mode.Completions ||
  278. modelMode == mode.Anthropic ||
  279. modelMode == mode.Gemini ||
  280. modelMode == mode.Responses ||
  281. modelMode == mode.ResponsesGet ||
  282. modelMode == mode.ResponsesDelete ||
  283. modelMode == mode.ResponsesCancel ||
  284. modelMode == mode.ResponsesInputItems
  285. case mode.ImagesGenerations, mode.ImagesEdits:
  286. return modelMode == mode.ImagesGenerations ||
  287. modelMode == mode.ImagesEdits
  288. case mode.VideoGenerationsJobs, mode.VideoGenerationsGetJobs, mode.VideoGenerationsContent:
  289. return modelMode == mode.VideoGenerationsJobs ||
  290. modelMode == mode.VideoGenerationsGetJobs ||
  291. modelMode == mode.VideoGenerationsContent
  292. default:
  293. return requestMode == modelMode
  294. }
  295. }
  296. func distribute(c *gin.Context, mode mode.Mode) {
  297. c.Set(Mode, mode)
  298. if config.GetDisableServe() {
  299. AbortLogWithMessage(c, http.StatusServiceUnavailable, "service is under maintenance")
  300. return
  301. }
  302. log := common.GetLogger(c)
  303. group := GetGroup(c)
  304. token := GetToken(c)
  305. if !checkGroupBalance(c, group) {
  306. return
  307. }
  308. requestModel, err := getRequestModel(c, mode, group.ID, token.ID)
  309. if err != nil {
  310. AbortLogWithMessage(
  311. c,
  312. http.StatusInternalServerError,
  313. err.Error(),
  314. )
  315. return
  316. }
  317. if requestModel == "" {
  318. AbortLogWithMessage(c, http.StatusBadRequest, "no model provided")
  319. return
  320. }
  321. findModel := token.FindModel(requestModel)
  322. if findModel == "" {
  323. AbortLogWithMessage(
  324. c,
  325. http.StatusNotFound,
  326. fmt.Sprintf(
  327. "The model `%s` does not exist or you do not have access to it.",
  328. requestModel,
  329. ),
  330. )
  331. return
  332. }
  333. SetLogModelFields(log.Data, findModel)
  334. mc, ok := GetModelCaches(c).ModelConfig.GetModelConfig(findModel)
  335. if !ok {
  336. AbortLogWithMessage(
  337. c,
  338. http.StatusNotFound,
  339. fmt.Sprintf(
  340. "The model `%s` does not exist or you do not have access to it.",
  341. findModel,
  342. ),
  343. )
  344. return
  345. }
  346. mc = GetGroupAdjustedModelConfig(group, mc)
  347. c.Set(RequestModel, findModel)
  348. c.Set(ModelConfig, mc)
  349. if !CheckRelayMode(mode, mc.Type) {
  350. AbortLogWithMessage(
  351. c,
  352. http.StatusNotFound,
  353. fmt.Sprintf(
  354. "The model `%s` does not exist on this endpoint.",
  355. findModel,
  356. ),
  357. )
  358. return
  359. }
  360. user, err := getRequestUser(c, mode)
  361. if err != nil {
  362. AbortLogWithMessage(
  363. c,
  364. http.StatusInternalServerError,
  365. err.Error(),
  366. )
  367. return
  368. }
  369. c.Set(RequestUser, user)
  370. metadata, err := getRequestMetadata(c, mode)
  371. if err != nil {
  372. AbortLogWithMessage(
  373. c,
  374. http.StatusInternalServerError,
  375. err.Error(),
  376. )
  377. return
  378. }
  379. c.Set(RequestMetadata, metadata)
  380. if err := checkGroupModelRPMAndTPM(c, group, mc, token.Name); err != nil {
  381. errMsg := err.Error()
  382. consume.Summary(
  383. http.StatusTooManyRequests,
  384. time.Time{},
  385. NewMetaByContext(c, nil, mode),
  386. model.Usage{},
  387. model.Price{},
  388. true,
  389. )
  390. AbortLogWithMessage(c, http.StatusTooManyRequests, errMsg)
  391. return
  392. }
  393. c.Next()
  394. }
  395. func GetRequestModel(c *gin.Context) string {
  396. return c.GetString(RequestModel)
  397. }
  398. func GetRequestUser(c *gin.Context) string {
  399. return c.GetString(RequestUser)
  400. }
  401. func GetChannelID(c *gin.Context) int {
  402. return c.GetInt(ChannelID)
  403. }
  404. func GetJobID(c *gin.Context) string {
  405. return c.GetString(JobID)
  406. }
  407. func GetGenerationID(c *gin.Context) string {
  408. return c.GetString(GenerationID)
  409. }
  410. func GetResponseID(c *gin.Context) string {
  411. return c.GetString(ResponseID)
  412. }
  413. func GetRequestMetadata(c *gin.Context) map[string]string {
  414. return c.GetStringMapString(RequestMetadata)
  415. }
  416. func GetModelConfig(c *gin.Context) model.ModelConfig {
  417. v, ok := c.MustGet(ModelConfig).(model.ModelConfig)
  418. if !ok {
  419. panic(fmt.Sprintf("model config type error: %T, %v", v, v))
  420. }
  421. return v
  422. }
  423. func NewMetaByContext(c *gin.Context,
  424. channel *model.Channel,
  425. mode mode.Mode,
  426. opts ...meta.Option,
  427. ) *meta.Meta {
  428. requestID := GetRequestID(c)
  429. group := GetGroup(c)
  430. token := GetToken(c)
  431. modelName := GetRequestModel(c)
  432. modelConfig := GetModelConfig(c)
  433. requestAt := GetRequestAt(c)
  434. jobID := GetJobID(c)
  435. generationID := GetGenerationID(c)
  436. responseID := GetResponseID(c)
  437. opts = append(
  438. opts,
  439. meta.WithRequestAt(requestAt),
  440. meta.WithRequestID(requestID),
  441. meta.WithGroup(group),
  442. meta.WithToken(token),
  443. meta.WithEndpoint(c.Request.URL.Path),
  444. meta.WithJobID(jobID),
  445. meta.WithGenerationID(generationID),
  446. meta.WithResponseID(responseID),
  447. )
  448. return meta.NewMeta(
  449. channel,
  450. mode,
  451. modelName,
  452. modelConfig,
  453. opts...,
  454. )
  455. }
  456. // https://platform.openai.com/docs/api-reference/chat
  457. func getRequestModel(c *gin.Context, m mode.Mode, group string, tokenID int) (string, error) {
  458. path := c.Request.URL.Path
  459. switch {
  460. case m == mode.ParsePdf:
  461. query := c.Request.URL.Query()
  462. model := query.Get("model")
  463. if model != "" {
  464. return model, nil
  465. }
  466. fallthrough
  467. case m == mode.AudioTranscription,
  468. m == mode.AudioTranslation,
  469. m == mode.ImagesEdits:
  470. return c.Request.FormValue("model"), nil
  471. case strings.HasPrefix(path, "/v1/engines") && strings.HasSuffix(path, "/embeddings"):
  472. // /engines/:model/embeddings
  473. return c.Param("model"), nil
  474. case m == mode.VideoGenerationsGetJobs:
  475. jobID := c.Param("id")
  476. store, err := model.CacheGetStore(group, tokenID, jobID)
  477. if err != nil {
  478. return "", fmt.Errorf("get request model failed: %w", err)
  479. }
  480. c.Set(JobID, store.ID)
  481. c.Set(ChannelID, store.ChannelID)
  482. return store.Model, nil
  483. case m == mode.VideoGenerationsContent:
  484. generationID := c.Param("id")
  485. store, err := model.CacheGetStore(group, tokenID, generationID)
  486. if err != nil {
  487. return "", fmt.Errorf("get request model failed: %w", err)
  488. }
  489. c.Set(GenerationID, store.ID)
  490. c.Set(ChannelID, store.ChannelID)
  491. return store.Model, nil
  492. case m == mode.ResponsesGet || m == mode.ResponsesDelete ||
  493. m == mode.ResponsesCancel || m == mode.ResponsesInputItems:
  494. responseID := c.Param("response_id")
  495. store, err := model.CacheGetStore(group, tokenID, responseID)
  496. if err != nil {
  497. return "", fmt.Errorf("get request model failed: %w", err)
  498. }
  499. c.Set(ResponseID, store.ID)
  500. c.Set(ChannelID, store.ChannelID)
  501. return store.Model, nil
  502. case m == mode.Responses:
  503. body, err := common.GetRequestBodyReusable(c.Request)
  504. if err != nil {
  505. return "", fmt.Errorf("get request model failed: %w", err)
  506. }
  507. responseID, err := GetPreviousResponseIDFromJSON(body)
  508. if err != nil {
  509. return "", fmt.Errorf("get request previous response id failed: %w", err)
  510. }
  511. modelName, err := GetModelFromJSON(body)
  512. if err != nil {
  513. return "", err
  514. }
  515. if responseID != "" {
  516. store, err := model.CacheGetStore(group, tokenID, responseID)
  517. if err != nil {
  518. return "", fmt.Errorf("get request model failed: %w", err)
  519. }
  520. c.Set(ResponseID, store.ID)
  521. c.Set(ChannelID, store.ChannelID)
  522. }
  523. return modelName, nil
  524. case m == mode.Gemini:
  525. modelName := strings.TrimPrefix(c.Param("model"), "/")
  526. modelName, _, _ = strings.Cut(modelName, ":")
  527. return modelName, nil
  528. default:
  529. body, err := common.GetRequestBodyReusable(c.Request)
  530. if err != nil {
  531. return "", fmt.Errorf("get request model failed: %w", err)
  532. }
  533. return GetModelFromJSON(body)
  534. }
  535. }
  536. func GetModelFromJSON(body []byte) (string, error) {
  537. node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "model")
  538. if err != nil {
  539. if errors.Is(err, ast.ErrNotExist) {
  540. return "", nil
  541. }
  542. return "", fmt.Errorf("get request model failed: %w", err)
  543. }
  544. return node.String()
  545. }
  546. func GetPreviousResponseIDFromJSON(body []byte) (string, error) {
  547. node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "previous_response_id")
  548. if err != nil {
  549. if errors.Is(err, ast.ErrNotExist) {
  550. return "", nil
  551. }
  552. return "", fmt.Errorf("get request model failed: %w", err)
  553. }
  554. return node.String()
  555. }
  556. // https://platform.openai.com/docs/api-reference/chat
  557. func getRequestUser(c *gin.Context, m mode.Mode) (string, error) {
  558. switch m {
  559. case mode.ChatCompletions,
  560. mode.Completions,
  561. mode.Embeddings,
  562. mode.ImagesGenerations,
  563. mode.AudioSpeech,
  564. mode.Rerank,
  565. mode.Anthropic,
  566. mode.Gemini:
  567. body, err := common.GetRequestBodyReusable(c.Request)
  568. if err != nil {
  569. return "", fmt.Errorf("get request model failed: %w", err)
  570. }
  571. return GetRequestUserFromJSON(body)
  572. default:
  573. return "", nil
  574. }
  575. }
  576. func GetRequestUserFromJSON(body []byte) (string, error) {
  577. node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "user")
  578. if err != nil {
  579. if errors.Is(err, ast.ErrNotExist) {
  580. return "", nil
  581. }
  582. return "", fmt.Errorf("get request user failed: %w", err)
  583. }
  584. if node.Exists() {
  585. return node.String()
  586. }
  587. return "", nil
  588. }
  589. func getRequestMetadata(c *gin.Context, m mode.Mode) (map[string]string, error) {
  590. switch m {
  591. case mode.ChatCompletions,
  592. mode.Completions,
  593. mode.Embeddings,
  594. mode.ImagesGenerations,
  595. mode.AudioSpeech,
  596. mode.Rerank,
  597. mode.Anthropic,
  598. mode.Gemini:
  599. body, err := common.GetRequestBodyReusable(c.Request)
  600. if err != nil {
  601. return nil, fmt.Errorf("get request metadata failed: %w", err)
  602. }
  603. return GetRequestMetadataFromJSON(body)
  604. default:
  605. return nil, nil
  606. }
  607. }
  // RequestWithMetadata is the decoding target used by
  // GetRequestMetadataFromJSON; it captures only the "metadata" field of a
  // request body.
  type RequestWithMetadata struct {
  	// Metadata holds the string key/value pairs found under "metadata".
  	Metadata map[string]string `json:"metadata,omitempty"`
  }
  611. func GetRequestMetadataFromJSON(body []byte) (map[string]string, error) {
  612. var requestWithMetadata RequestWithMetadata
  613. if err := sonic.Unmarshal(body, &requestWithMetadata); err != nil {
  614. return nil, fmt.Errorf("get request metadata failed: %w", err)
  615. }
  616. return requestWithMetadata.Metadata, nil
  617. }