consume.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. package consume
  2. import (
  3. "context"
  4. "net/http"
  5. "sync"
  6. "time"
  7. "github.com/labring/aiproxy/core/common/balance"
  8. "github.com/labring/aiproxy/core/common/notify"
  9. "github.com/labring/aiproxy/core/model"
  10. "github.com/labring/aiproxy/core/relay/meta"
  11. "github.com/labring/aiproxy/core/relay/mode"
  12. "github.com/shopspring/decimal"
  13. log "github.com/sirupsen/logrus"
  14. )
  15. var consumeWaitGroup sync.WaitGroup
  16. func Wait() {
  17. consumeWaitGroup.Wait()
  18. }
  19. func AsyncConsume(
  20. postGroupConsumer balance.PostGroupConsumer,
  21. code int,
  22. firstByteAt time.Time,
  23. meta *meta.Meta,
  24. usage model.Usage,
  25. modelPrice model.Price,
  26. content string,
  27. ip string,
  28. retryTimes int,
  29. requestDetail *model.RequestDetail,
  30. downstreamResult bool,
  31. user string,
  32. metadata map[string]string,
  33. ) {
  34. if !checkNeedRecordConsume(code, meta) {
  35. return
  36. }
  37. consumeWaitGroup.Add(1)
  38. defer func() {
  39. consumeWaitGroup.Done()
  40. if r := recover(); r != nil {
  41. log.Errorf("panic in consume: %v", r)
  42. }
  43. }()
  44. go Consume(
  45. context.Background(),
  46. time.Now(),
  47. postGroupConsumer,
  48. firstByteAt,
  49. code,
  50. meta,
  51. usage,
  52. modelPrice,
  53. content,
  54. ip,
  55. retryTimes,
  56. requestDetail,
  57. downstreamResult,
  58. user,
  59. metadata,
  60. )
  61. }
  62. func Consume(
  63. ctx context.Context,
  64. now time.Time,
  65. postGroupConsumer balance.PostGroupConsumer,
  66. firstByteAt time.Time,
  67. code int,
  68. meta *meta.Meta,
  69. usage model.Usage,
  70. modelPrice model.Price,
  71. content string,
  72. ip string,
  73. retryTimes int,
  74. requestDetail *model.RequestDetail,
  75. downstreamResult bool,
  76. user string,
  77. metadata map[string]string,
  78. ) {
  79. if !checkNeedRecordConsume(code, meta) {
  80. return
  81. }
  82. amount := CalculateAmount(code, usage, modelPrice)
  83. amount = consumeAmount(ctx, amount, postGroupConsumer, meta)
  84. selectedModelPrice := modelPrice.SelectConditionalPrice(usage)
  85. selectedModelPrice.ConditionalPrices = nil
  86. err := recordConsume(
  87. now,
  88. meta,
  89. code,
  90. firstByteAt,
  91. usage,
  92. selectedModelPrice,
  93. content,
  94. ip,
  95. requestDetail,
  96. amount,
  97. retryTimes,
  98. downstreamResult,
  99. user,
  100. metadata,
  101. )
  102. if err != nil {
  103. log.Error("error batch record consume: " + err.Error())
  104. notify.ErrorThrottle("recordConsume", time.Minute, "record consume failed", err.Error())
  105. }
  106. }
  107. func checkNeedRecordConsume(code int, meta *meta.Meta) bool {
  108. switch meta.Mode {
  109. case mode.VideoGenerationsGetJobs,
  110. mode.VideoGenerationsContent,
  111. mode.ResponsesGet,
  112. mode.ResponsesDelete,
  113. mode.ResponsesCancel,
  114. mode.ResponsesInputItems:
  115. return code != http.StatusOK
  116. default:
  117. return true
  118. }
  119. }
  120. func consumeAmount(
  121. ctx context.Context,
  122. amount float64,
  123. postGroupConsumer balance.PostGroupConsumer,
  124. meta *meta.Meta,
  125. ) float64 {
  126. if amount > 0 && postGroupConsumer != nil {
  127. return processGroupConsume(ctx, amount, postGroupConsumer, meta)
  128. }
  129. return amount
  130. }
  131. func CalculateAmount(
  132. code int,
  133. usage model.Usage,
  134. modelPrice model.Price,
  135. ) float64 {
  136. if modelPrice.PerRequestPrice != 0 {
  137. if code != http.StatusOK {
  138. return 0
  139. }
  140. return float64(modelPrice.PerRequestPrice)
  141. }
  142. modelPrice = modelPrice.SelectConditionalPrice(usage)
  143. inputTokens := usage.InputTokens
  144. if modelPrice.ImageInputPrice > 0 {
  145. inputTokens -= usage.ImageInputTokens
  146. }
  147. if modelPrice.AudioInputPrice > 0 {
  148. inputTokens -= usage.AudioInputTokens
  149. }
  150. if modelPrice.CachedPrice > 0 {
  151. inputTokens -= usage.CachedTokens
  152. }
  153. if modelPrice.CacheCreationPrice > 0 {
  154. inputTokens -= usage.CacheCreationTokens
  155. }
  156. outputTokens := usage.OutputTokens
  157. outputPrice := float64(modelPrice.OutputPrice)
  158. outputPriceUnit := modelPrice.GetOutputPriceUnit()
  159. if usage.ReasoningTokens != 0 && modelPrice.ThinkingModeOutputPrice != 0 {
  160. outputPrice = float64(modelPrice.ThinkingModeOutputPrice)
  161. if modelPrice.ThinkingModeOutputPriceUnit != 0 {
  162. outputPriceUnit = int64(modelPrice.ThinkingModeOutputPriceUnit)
  163. }
  164. }
  165. inputAmount := decimal.NewFromInt(int64(inputTokens)).
  166. Mul(decimal.NewFromFloat(float64(modelPrice.InputPrice))).
  167. Div(decimal.NewFromInt(modelPrice.GetInputPriceUnit()))
  168. imageInputAmount := decimal.NewFromInt(int64(usage.ImageInputTokens)).
  169. Mul(decimal.NewFromFloat(float64(modelPrice.ImageInputPrice))).
  170. Div(decimal.NewFromInt(modelPrice.GetImageInputPriceUnit()))
  171. audioInputAmount := decimal.NewFromInt(int64(usage.AudioInputTokens)).
  172. Mul(decimal.NewFromFloat(float64(modelPrice.AudioInputPrice))).
  173. Div(decimal.NewFromInt(modelPrice.GetAudioInputPriceUnit()))
  174. cachedAmount := decimal.NewFromInt(int64(usage.CachedTokens)).
  175. Mul(decimal.NewFromFloat(float64(modelPrice.CachedPrice))).
  176. Div(decimal.NewFromInt(modelPrice.GetCachedPriceUnit()))
  177. cacheCreationAmount := decimal.NewFromInt(int64(usage.CacheCreationTokens)).
  178. Mul(decimal.NewFromFloat(float64(modelPrice.CacheCreationPrice))).
  179. Div(decimal.NewFromInt(modelPrice.GetCacheCreationPriceUnit()))
  180. webSearchAmount := decimal.NewFromInt(int64(usage.WebSearchCount)).
  181. Mul(decimal.NewFromFloat(float64(modelPrice.WebSearchPrice))).
  182. Div(decimal.NewFromInt(modelPrice.GetWebSearchPriceUnit()))
  183. outputAmount := decimal.NewFromInt(int64(outputTokens)).
  184. Mul(decimal.NewFromFloat(outputPrice)).
  185. Div(decimal.NewFromInt(outputPriceUnit))
  186. return inputAmount.
  187. Add(imageInputAmount).
  188. Add(audioInputAmount).
  189. Add(cachedAmount).
  190. Add(cacheCreationAmount).
  191. Add(webSearchAmount).
  192. Add(outputAmount).
  193. InexactFloat64()
  194. }
  195. func processGroupConsume(
  196. ctx context.Context,
  197. amount float64,
  198. postGroupConsumer balance.PostGroupConsumer,
  199. meta *meta.Meta,
  200. ) float64 {
  201. consumedAmount, err := postGroupConsumer.PostGroupConsume(ctx, meta.Token.Name, amount)
  202. if err != nil {
  203. log.Error("error consuming token remain amount: " + err.Error())
  204. if err := model.CreateConsumeError(
  205. meta.RequestID,
  206. meta.RequestAt,
  207. meta.Group.ID,
  208. meta.Token.Name,
  209. meta.OriginModel,
  210. err.Error(),
  211. amount,
  212. meta.Token.ID,
  213. ); err != nil {
  214. log.Error("failed to create consume error: " + err.Error())
  215. }
  216. return amount
  217. }
  218. return consumedAmount
  219. }