task_billing.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. package service
import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/model"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/setting/ratio_setting"
	"github.com/gin-gonic/gin"
)
  14. // LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
  15. // 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
  16. func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
  17. tokenName := c.GetString("token_name")
  18. logContent := fmt.Sprintf("操作 %s", info.Action)
  19. // 支持任务仅按次计费
  20. if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) {
  21. logContent = fmt.Sprintf("%s,按次计费", logContent)
  22. } else {
  23. if len(info.PriceData.OtherRatios) > 0 {
  24. var contents []string
  25. for key, ra := range info.PriceData.OtherRatios {
  26. if 1.0 != ra {
  27. contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
  28. }
  29. }
  30. if len(contents) > 0 {
  31. logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
  32. }
  33. }
  34. }
  35. other := make(map[string]interface{})
  36. other["is_task"] = true
  37. other["request_path"] = c.Request.URL.Path
  38. other["model_price"] = info.PriceData.ModelPrice
  39. if info.PriceData.ModelRatio > 0 {
  40. other["model_ratio"] = info.PriceData.ModelRatio
  41. }
  42. other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio
  43. if info.PriceData.GroupRatioInfo.HasSpecialRatio {
  44. other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
  45. }
  46. if info.IsModelMapped {
  47. other["is_model_mapped"] = true
  48. other["upstream_model_name"] = info.UpstreamModelName
  49. }
  50. model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
  51. ChannelId: info.ChannelId,
  52. ModelName: info.OriginModelName,
  53. TokenName: tokenName,
  54. Quota: info.PriceData.Quota,
  55. Content: logContent,
  56. TokenId: info.TokenId,
  57. Group: info.UsingGroup,
  58. Other: other,
  59. })
  60. model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota)
  61. model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota)
  62. }
  63. // ---------------------------------------------------------------------------
  64. // 异步任务计费辅助函数
  65. // ---------------------------------------------------------------------------
  66. // resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。
  67. // 如果令牌已被删除或查询失败,返回空字符串。
  68. func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string {
  69. token, err := model.GetTokenById(tokenId)
  70. if err != nil {
  71. logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error()))
  72. return ""
  73. }
  74. return token.Key
  75. }
  76. // taskIsSubscription 判断任务是否通过订阅计费。
  77. func taskIsSubscription(task *model.Task) bool {
  78. return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0
  79. }
  80. // taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。
  81. func taskAdjustFunding(task *model.Task, delta int) error {
  82. if taskIsSubscription(task) {
  83. return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta))
  84. }
  85. if delta > 0 {
  86. return model.DecreaseUserQuota(task.UserId, delta, false)
  87. }
  88. return model.IncreaseUserQuota(task.UserId, -delta, false)
  89. }
  90. // taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。
  91. // 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。
  92. func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) {
  93. if task.PrivateData.TokenId <= 0 || delta == 0 {
  94. return
  95. }
  96. tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID)
  97. if tokenKey == "" {
  98. return
  99. }
  100. var err error
  101. if delta > 0 {
  102. err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta)
  103. } else {
  104. err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta)
  105. }
  106. if err != nil {
  107. logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error()))
  108. }
  109. }
  110. // taskBillingOther 从 task 的 BillingContext 构建日志 Other 字段。
  111. func taskBillingOther(task *model.Task) map[string]interface{} {
  112. other := make(map[string]interface{})
  113. if bc := task.PrivateData.BillingContext; bc != nil {
  114. other["model_price"] = bc.ModelPrice
  115. if bc.ModelRatio > 0 {
  116. other["model_ratio"] = bc.ModelRatio
  117. }
  118. other["group_ratio"] = bc.GroupRatio
  119. if len(bc.OtherRatios) > 0 {
  120. for k, v := range bc.OtherRatios {
  121. other[k] = v
  122. }
  123. }
  124. }
  125. props := task.Properties
  126. if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName {
  127. other["is_model_mapped"] = true
  128. other["upstream_model_name"] = props.UpstreamModelName
  129. }
  130. return other
  131. }
  132. // taskModelName 从 BillingContext 或 Properties 中获取模型名称。
  133. func taskModelName(task *model.Task) string {
  134. if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" {
  135. return bc.OriginModelName
  136. }
  137. return task.Properties.OriginModelName
  138. }
  139. // RefundTaskQuota 统一的任务失败退款逻辑。
  140. // 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。
  141. func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
  142. quota := task.Quota
  143. if quota == 0 {
  144. return
  145. }
  146. // 1. 退还资金来源(钱包或订阅)
  147. if err := taskAdjustFunding(task, -quota); err != nil {
  148. logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error()))
  149. return
  150. }
  151. // 2. 退还令牌额度
  152. taskAdjustTokenQuota(ctx, task, -quota)
  153. // 3. 记录日志
  154. other := taskBillingOther(task)
  155. other["task_id"] = task.TaskID
  156. other["reason"] = reason
  157. model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
  158. UserId: task.UserId,
  159. LogType: model.LogTypeRefund,
  160. Content: "",
  161. ChannelId: task.ChannelId,
  162. ModelName: taskModelName(task),
  163. Quota: quota,
  164. TokenId: task.PrivateData.TokenId,
  165. Group: task.Group,
  166. Other: other,
  167. })
  168. }
  169. // RecalculateTaskQuota 通用的异步差额结算。
  170. // actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。
  171. // reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。
  172. func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) {
  173. if actualQuota <= 0 {
  174. return
  175. }
  176. preConsumedQuota := task.Quota
  177. quotaDelta := actualQuota - preConsumedQuota
  178. if quotaDelta == 0 {
  179. logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)",
  180. task.TaskID, logger.LogQuota(actualQuota), reason))
  181. return
  182. }
  183. logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)",
  184. task.TaskID,
  185. logger.LogQuota(quotaDelta),
  186. logger.LogQuota(actualQuota),
  187. logger.LogQuota(preConsumedQuota),
  188. reason,
  189. ))
  190. // 调整资金来源
  191. if err := taskAdjustFunding(task, quotaDelta); err != nil {
  192. logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
  193. return
  194. }
  195. // 调整令牌额度
  196. taskAdjustTokenQuota(ctx, task, quotaDelta)
  197. task.Quota = actualQuota
  198. var logType int
  199. var logQuota int
  200. if quotaDelta > 0 {
  201. logType = model.LogTypeConsume
  202. logQuota = quotaDelta
  203. model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
  204. model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
  205. } else {
  206. logType = model.LogTypeRefund
  207. logQuota = -quotaDelta
  208. }
  209. other := taskBillingOther(task)
  210. other["task_id"] = task.TaskID
  211. other["pre_consumed_quota"] = preConsumedQuota
  212. other["actual_quota"] = actualQuota
  213. model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
  214. UserId: task.UserId,
  215. LogType: logType,
  216. Content: reason,
  217. ChannelId: task.ChannelId,
  218. ModelName: taskModelName(task),
  219. Quota: logQuota,
  220. TokenId: task.PrivateData.TokenId,
  221. Group: task.Group,
  222. Other: other,
  223. })
  224. }
  225. // RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
  226. // 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
  227. // 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
  228. func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) {
  229. if totalTokens <= 0 {
  230. return
  231. }
  232. modelName := taskModelName(task)
  233. // 获取模型价格和倍率
  234. modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
  235. // 只有配置了倍率(非固定价格)时才按 token 重新计费
  236. if !hasRatioSetting || modelRatio <= 0 {
  237. return
  238. }
  239. // 获取用户和组的倍率信息
  240. group := task.Group
  241. if group == "" {
  242. user, err := model.GetUserById(task.UserId, false)
  243. if err == nil {
  244. group = user.Group
  245. }
  246. }
  247. if group == "" {
  248. return
  249. }
  250. groupRatio := ratio_setting.GetGroupRatio(group)
  251. userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
  252. var finalGroupRatio float64
  253. if hasUserGroupRatio {
  254. finalGroupRatio = userGroupRatio
  255. } else {
  256. finalGroupRatio = groupRatio
  257. }
  258. // 计算 OtherRatios 乘积(视频折扣、时长等)
  259. otherMultiplier := 1.0
  260. if bc := task.PrivateData.BillingContext; bc != nil {
  261. for _, r := range bc.OtherRatios {
  262. if r != 1.0 && r > 0 {
  263. otherMultiplier *= r
  264. }
  265. }
  266. }
  267. // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio * otherMultiplier
  268. actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio * otherMultiplier)
  269. reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f, otherMultiplier=%.4f", totalTokens, modelRatio, finalGroupRatio, otherMultiplier)
  270. RecalculateTaskQuota(ctx, task, actualQuota, reason)
  271. }