| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253 |
- package consume
- import (
- "context"
- "net/http"
- "sync"
- "time"
- "github.com/labring/aiproxy/core/common/balance"
- "github.com/labring/aiproxy/core/common/notify"
- "github.com/labring/aiproxy/core/model"
- "github.com/labring/aiproxy/core/relay/meta"
- "github.com/labring/aiproxy/core/relay/mode"
- "github.com/shopspring/decimal"
- log "github.com/sirupsen/logrus"
- )
// consumeWaitGroup counts the consume jobs registered by AsyncConsume.
var consumeWaitGroup sync.WaitGroup

// Wait blocks until every job registered on consumeWaitGroup has been marked
// done. It returns immediately when nothing is outstanding.
func Wait() {
	consumeWaitGroup.Wait()
}
- func AsyncConsume(
- postGroupConsumer balance.PostGroupConsumer,
- code int,
- firstByteAt time.Time,
- meta *meta.Meta,
- usage model.Usage,
- modelPrice model.Price,
- content string,
- ip string,
- retryTimes int,
- requestDetail *model.RequestDetail,
- downstreamResult bool,
- user string,
- metadata map[string]string,
- ) {
- if !checkNeedRecordConsume(code, meta) {
- return
- }
- consumeWaitGroup.Add(1)
- defer func() {
- consumeWaitGroup.Done()
- if r := recover(); r != nil {
- log.Errorf("panic in consume: %v", r)
- }
- }()
- go Consume(
- context.Background(),
- time.Now(),
- postGroupConsumer,
- firstByteAt,
- code,
- meta,
- usage,
- modelPrice,
- content,
- ip,
- retryTimes,
- requestDetail,
- downstreamResult,
- user,
- metadata,
- )
- }
- func Consume(
- ctx context.Context,
- now time.Time,
- postGroupConsumer balance.PostGroupConsumer,
- firstByteAt time.Time,
- code int,
- meta *meta.Meta,
- usage model.Usage,
- modelPrice model.Price,
- content string,
- ip string,
- retryTimes int,
- requestDetail *model.RequestDetail,
- downstreamResult bool,
- user string,
- metadata map[string]string,
- ) {
- if !checkNeedRecordConsume(code, meta) {
- return
- }
- amount := CalculateAmount(code, usage, modelPrice)
- amount = consumeAmount(ctx, amount, postGroupConsumer, meta)
- selectedModelPrice := modelPrice.SelectConditionalPrice(usage)
- selectedModelPrice.ConditionalPrices = nil
- err := recordConsume(
- now,
- meta,
- code,
- firstByteAt,
- usage,
- selectedModelPrice,
- content,
- ip,
- requestDetail,
- amount,
- retryTimes,
- downstreamResult,
- user,
- metadata,
- )
- if err != nil {
- log.Error("error batch record consume: " + err.Error())
- notify.ErrorThrottle("recordConsume", time.Minute, "record consume failed", err.Error())
- }
- }
- func checkNeedRecordConsume(code int, meta *meta.Meta) bool {
- switch meta.Mode {
- case mode.VideoGenerationsGetJobs,
- mode.VideoGenerationsContent,
- mode.ResponsesGet,
- mode.ResponsesDelete,
- mode.ResponsesCancel,
- mode.ResponsesInputItems:
- return code != http.StatusOK
- default:
- return true
- }
- }
- func consumeAmount(
- ctx context.Context,
- amount float64,
- postGroupConsumer balance.PostGroupConsumer,
- meta *meta.Meta,
- ) float64 {
- if amount > 0 && postGroupConsumer != nil {
- return processGroupConsume(ctx, amount, postGroupConsumer, meta)
- }
- return amount
- }
- func CalculateAmount(
- code int,
- usage model.Usage,
- modelPrice model.Price,
- ) float64 {
- if modelPrice.PerRequestPrice != 0 {
- if code != http.StatusOK {
- return 0
- }
- return float64(modelPrice.PerRequestPrice)
- }
- modelPrice = modelPrice.SelectConditionalPrice(usage)
- inputTokens := usage.InputTokens
- if modelPrice.ImageInputPrice > 0 {
- inputTokens -= usage.ImageInputTokens
- }
- if modelPrice.AudioInputPrice > 0 {
- inputTokens -= usage.AudioInputTokens
- }
- if modelPrice.CachedPrice > 0 {
- inputTokens -= usage.CachedTokens
- }
- if modelPrice.CacheCreationPrice > 0 {
- inputTokens -= usage.CacheCreationTokens
- }
- outputTokens := usage.OutputTokens
- outputPrice := float64(modelPrice.OutputPrice)
- outputPriceUnit := modelPrice.GetOutputPriceUnit()
- if usage.ReasoningTokens != 0 && modelPrice.ThinkingModeOutputPrice != 0 {
- outputPrice = float64(modelPrice.ThinkingModeOutputPrice)
- if modelPrice.ThinkingModeOutputPriceUnit != 0 {
- outputPriceUnit = int64(modelPrice.ThinkingModeOutputPriceUnit)
- }
- }
- inputAmount := decimal.NewFromInt(int64(inputTokens)).
- Mul(decimal.NewFromFloat(float64(modelPrice.InputPrice))).
- Div(decimal.NewFromInt(modelPrice.GetInputPriceUnit()))
- imageInputAmount := decimal.NewFromInt(int64(usage.ImageInputTokens)).
- Mul(decimal.NewFromFloat(float64(modelPrice.ImageInputPrice))).
- Div(decimal.NewFromInt(modelPrice.GetImageInputPriceUnit()))
- audioInputAmount := decimal.NewFromInt(int64(usage.AudioInputTokens)).
- Mul(decimal.NewFromFloat(float64(modelPrice.AudioInputPrice))).
- Div(decimal.NewFromInt(modelPrice.GetAudioInputPriceUnit()))
- cachedAmount := decimal.NewFromInt(int64(usage.CachedTokens)).
- Mul(decimal.NewFromFloat(float64(modelPrice.CachedPrice))).
- Div(decimal.NewFromInt(modelPrice.GetCachedPriceUnit()))
- cacheCreationAmount := decimal.NewFromInt(int64(usage.CacheCreationTokens)).
- Mul(decimal.NewFromFloat(float64(modelPrice.CacheCreationPrice))).
- Div(decimal.NewFromInt(modelPrice.GetCacheCreationPriceUnit()))
- webSearchAmount := decimal.NewFromInt(int64(usage.WebSearchCount)).
- Mul(decimal.NewFromFloat(float64(modelPrice.WebSearchPrice))).
- Div(decimal.NewFromInt(modelPrice.GetWebSearchPriceUnit()))
- outputAmount := decimal.NewFromInt(int64(outputTokens)).
- Mul(decimal.NewFromFloat(outputPrice)).
- Div(decimal.NewFromInt(outputPriceUnit))
- return inputAmount.
- Add(imageInputAmount).
- Add(audioInputAmount).
- Add(cachedAmount).
- Add(cacheCreationAmount).
- Add(webSearchAmount).
- Add(outputAmount).
- InexactFloat64()
- }
- func processGroupConsume(
- ctx context.Context,
- amount float64,
- postGroupConsumer balance.PostGroupConsumer,
- meta *meta.Meta,
- ) float64 {
- consumedAmount, err := postGroupConsumer.PostGroupConsume(ctx, meta.Token.Name, amount)
- if err != nil {
- log.Error("error consuming token remain amount: " + err.Error())
- if err := model.CreateConsumeError(
- meta.RequestID,
- meta.RequestAt,
- meta.Group.ID,
- meta.Token.Name,
- meta.OriginModel,
- err.Error(),
- amount,
- meta.Token.ID,
- ); err != nil {
- log.Error("failed to create consume error: " + err.Error())
- }
- return amount
- }
- return consumedAmount
- }
|