billing_session.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. package service
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strings"
  6. "sync"
  7. "github.com/QuantumNous/new-api/common"
  8. "github.com/QuantumNous/new-api/logger"
  9. "github.com/QuantumNous/new-api/model"
  10. relaycommon "github.com/QuantumNous/new-api/relay/common"
  11. "github.com/QuantumNous/new-api/types"
  12. "github.com/bytedance/gopkg/util/gopool"
  13. "github.com/gin-gonic/gin"
  14. )
  15. // ---------------------------------------------------------------------------
  16. // BillingSession — 统一计费会话
  17. // ---------------------------------------------------------------------------
  18. // BillingSession 封装单次请求的预扣费/结算/退款生命周期。
  19. // 实现 relaycommon.BillingSettler 接口。
  20. type BillingSession struct {
  21. relayInfo *relaycommon.RelayInfo
  22. funding FundingSource
  23. preConsumedQuota int // 实际预扣额度(信任用户可能为 0)
  24. tokenConsumed int // 令牌额度实际扣减量
  25. extraReserved int // 发送前补充预扣的额度(订阅退款时需要单独回滚)
  26. trusted bool // 是否命中信任额度旁路
  27. fundingSettled bool // funding.Settle 已成功,资金来源已提交
  28. settled bool // Settle 全部完成(资金 + 令牌)
  29. refunded bool // Refund 已调用
  30. mu sync.Mutex
  31. }
  32. // Settle 根据实际消耗额度进行结算。
  33. // 资金来源和令牌额度分两步提交:若资金来源已提交但令牌调整失败,
  34. // 会标记 fundingSettled 防止 Refund 对已提交的资金来源执行退款。
  35. func (s *BillingSession) Settle(actualQuota int) error {
  36. s.mu.Lock()
  37. defer s.mu.Unlock()
  38. if s.settled {
  39. return nil
  40. }
  41. delta := actualQuota - s.preConsumedQuota
  42. if delta == 0 {
  43. s.settled = true
  44. return nil
  45. }
  46. // 1) 调整资金来源(仅在尚未提交时执行,防止重复调用)
  47. if !s.fundingSettled {
  48. if err := s.funding.Settle(delta); err != nil {
  49. return err
  50. }
  51. s.fundingSettled = true
  52. }
  53. // 2) 调整令牌额度
  54. var tokenErr error
  55. if !s.relayInfo.IsPlayground {
  56. if delta > 0 {
  57. tokenErr = model.DecreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, delta)
  58. } else {
  59. tokenErr = model.IncreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, -delta)
  60. }
  61. if tokenErr != nil {
  62. // 资金来源已提交,令牌调整失败只能记录日志;标记 settled 防止 Refund 误退资金
  63. common.SysLog(fmt.Sprintf("error adjusting token quota after funding settled (userId=%d, tokenId=%d, delta=%d): %s",
  64. s.relayInfo.UserId, s.relayInfo.TokenId, delta, tokenErr.Error()))
  65. }
  66. }
  67. // 3) 更新 relayInfo 上的订阅 PostDelta(用于日志)
  68. if s.funding.Source() == BillingSourceSubscription {
  69. s.relayInfo.SubscriptionPostDelta += int64(delta)
  70. }
  71. s.settled = true
  72. return tokenErr
  73. }
  74. // Refund 退还所有预扣费,幂等安全,异步执行。
  75. func (s *BillingSession) Refund(c *gin.Context) {
  76. s.mu.Lock()
  77. if s.settled || s.refunded || !s.needsRefundLocked() {
  78. s.mu.Unlock()
  79. return
  80. }
  81. s.refunded = true
  82. s.mu.Unlock()
  83. logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费(token_quota=%s, funding=%s)",
  84. s.relayInfo.UserId,
  85. logger.FormatQuota(s.tokenConsumed),
  86. s.funding.Source(),
  87. ))
  88. // 复制需要的值到闭包中
  89. tokenId := s.relayInfo.TokenId
  90. tokenKey := s.relayInfo.TokenKey
  91. isPlayground := s.relayInfo.IsPlayground
  92. tokenConsumed := s.tokenConsumed
  93. extraReserved := s.extraReserved
  94. subscriptionId := s.relayInfo.SubscriptionId
  95. funding := s.funding
  96. gopool.Go(func() {
  97. // 1) 退还资金来源
  98. if err := funding.Refund(); err != nil {
  99. common.SysLog("error refunding billing source: " + err.Error())
  100. }
  101. if extraReserved > 0 && funding.Source() == BillingSourceSubscription && subscriptionId > 0 {
  102. if err := model.PostConsumeUserSubscriptionDelta(subscriptionId, -int64(extraReserved)); err != nil {
  103. common.SysLog("error refunding subscription extra reserved quota: " + err.Error())
  104. }
  105. }
  106. // 2) 退还令牌额度
  107. if tokenConsumed > 0 && !isPlayground {
  108. if err := model.IncreaseTokenQuota(tokenId, tokenKey, tokenConsumed); err != nil {
  109. common.SysLog("error refunding token quota: " + err.Error())
  110. }
  111. }
  112. })
  113. }
  114. // NeedsRefund 返回是否存在需要退还的预扣状态。
  115. func (s *BillingSession) NeedsRefund() bool {
  116. s.mu.Lock()
  117. defer s.mu.Unlock()
  118. return s.needsRefundLocked()
  119. }
  120. func (s *BillingSession) needsRefundLocked() bool {
  121. if s.settled || s.refunded || s.fundingSettled {
  122. // fundingSettled 时资金来源已提交结算,不能再退预扣费
  123. return false
  124. }
  125. if s.tokenConsumed > 0 {
  126. return true
  127. }
  128. // 订阅可能在 tokenConsumed=0 时仍预扣了额度
  129. if sub, ok := s.funding.(*SubscriptionFunding); ok && sub.preConsumed > 0 {
  130. return true
  131. }
  132. return false
  133. }
  134. // GetPreConsumedQuota 返回实际预扣的额度。
  135. func (s *BillingSession) GetPreConsumedQuota() int {
  136. return s.preConsumedQuota
  137. }
  138. func (s *BillingSession) Reserve(targetQuota int) error {
  139. s.mu.Lock()
  140. defer s.mu.Unlock()
  141. if s.settled || s.refunded || s.trusted || targetQuota <= s.preConsumedQuota {
  142. return nil
  143. }
  144. delta := targetQuota - s.preConsumedQuota
  145. if delta <= 0 {
  146. return nil
  147. }
  148. if err := s.reserveFunding(delta); err != nil {
  149. return err
  150. }
  151. if err := s.reserveToken(delta); err != nil {
  152. s.rollbackFundingReserve(delta)
  153. return err
  154. }
  155. s.preConsumedQuota += delta
  156. s.tokenConsumed += delta
  157. s.extraReserved += delta
  158. s.syncRelayInfo()
  159. return nil
  160. }
  161. // ---------------------------------------------------------------------------
  162. // PreConsume — 统一预扣费入口(含信任额度旁路)
  163. // ---------------------------------------------------------------------------
  164. // preConsume 执行预扣费:信任检查 -> 令牌预扣 -> 资金来源预扣。
  165. // 任一步骤失败时原子回滚已完成的步骤。
  166. func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIError {
  167. effectiveQuota := quota
  168. // ---- 信任额度旁路 ----
  169. if s.shouldTrust(c) {
  170. s.trusted = true
  171. effectiveQuota = 0
  172. logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足, 信任且不需要预扣费 (funding=%s)", s.relayInfo.UserId, s.funding.Source()))
  173. } else if effectiveQuota > 0 {
  174. logger.LogInfo(c, fmt.Sprintf("用户 %d 需要预扣费 %s (funding=%s)", s.relayInfo.UserId, logger.FormatQuota(effectiveQuota), s.funding.Source()))
  175. }
  176. // ---- 1) 预扣令牌额度 ----
  177. if effectiveQuota > 0 {
  178. if err := PreConsumeTokenQuota(s.relayInfo, effectiveQuota); err != nil {
  179. return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
  180. }
  181. s.tokenConsumed = effectiveQuota
  182. }
  183. // ---- 2) 预扣资金来源 ----
  184. if err := s.funding.PreConsume(effectiveQuota); err != nil {
  185. // 预扣费失败,回滚令牌额度
  186. if s.tokenConsumed > 0 && !s.relayInfo.IsPlayground {
  187. if rollbackErr := model.IncreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, s.tokenConsumed); rollbackErr != nil {
  188. common.SysLog(fmt.Sprintf("error rolling back token quota (userId=%d, tokenId=%d, amount=%d, fundingErr=%s): %s",
  189. s.relayInfo.UserId, s.relayInfo.TokenId, s.tokenConsumed, err.Error(), rollbackErr.Error()))
  190. }
  191. s.tokenConsumed = 0
  192. }
  193. // TODO: model 层应定义哨兵错误(如 ErrNoActiveSubscription),用 errors.Is 替代字符串匹配
  194. errMsg := err.Error()
  195. if strings.Contains(errMsg, "no active subscription") || strings.Contains(errMsg, "subscription quota insufficient") {
  196. return types.NewErrorWithStatusCode(fmt.Errorf("订阅额度不足或未配置订阅: %s", errMsg), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
  197. }
  198. return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
  199. }
  200. s.preConsumedQuota = effectiveQuota
  201. // ---- 同步 RelayInfo 兼容字段 ----
  202. s.syncRelayInfo()
  203. return nil
  204. }
  205. func (s *BillingSession) reserveFunding(delta int) error {
  206. switch funding := s.funding.(type) {
  207. case *WalletFunding:
  208. if err := model.DecreaseUserQuota(funding.userId, delta); err != nil {
  209. return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
  210. }
  211. funding.consumed += delta
  212. return nil
  213. case *SubscriptionFunding:
  214. if err := model.PostConsumeUserSubscriptionDelta(funding.subscriptionId, int64(delta)); err != nil {
  215. return types.NewErrorWithStatusCode(
  216. fmt.Errorf("订阅额度不足或未配置订阅: %s", err.Error()),
  217. types.ErrorCodeInsufficientUserQuota,
  218. http.StatusForbidden,
  219. types.ErrOptionWithSkipRetry(),
  220. types.ErrOptionWithNoRecordErrorLog(),
  221. )
  222. }
  223. return nil
  224. default:
  225. return types.NewError(fmt.Errorf("unsupported funding source: %s", s.funding.Source()), types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
  226. }
  227. }
  228. func (s *BillingSession) rollbackFundingReserve(delta int) {
  229. switch funding := s.funding.(type) {
  230. case *WalletFunding:
  231. if err := model.IncreaseUserQuota(funding.userId, delta, false); err != nil {
  232. common.SysLog("error rolling back wallet funding reserve: " + err.Error())
  233. } else {
  234. funding.consumed -= delta
  235. }
  236. case *SubscriptionFunding:
  237. if err := model.PostConsumeUserSubscriptionDelta(funding.subscriptionId, -int64(delta)); err != nil {
  238. common.SysLog("error rolling back subscription funding reserve: " + err.Error())
  239. }
  240. }
  241. }
  242. func (s *BillingSession) reserveToken(delta int) error {
  243. if delta <= 0 || s.relayInfo.IsPlayground {
  244. return nil
  245. }
  246. if err := PreConsumeTokenQuota(s.relayInfo, delta); err != nil {
  247. return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
  248. }
  249. return nil
  250. }
  251. // shouldTrust 统一信任额度检查,适用于钱包和订阅。
  252. func (s *BillingSession) shouldTrust(c *gin.Context) bool {
  253. // 异步任务(ForcePreConsume=true)必须预扣全额,不允许信任旁路
  254. if s.relayInfo.ForcePreConsume {
  255. return false
  256. }
  257. trustQuota := common.GetTrustQuota()
  258. if trustQuota <= 0 {
  259. return false
  260. }
  261. // 检查令牌是否充足
  262. tokenTrusted := s.relayInfo.TokenUnlimited
  263. if !tokenTrusted {
  264. tokenQuota := c.GetInt("token_quota")
  265. tokenTrusted = tokenQuota > trustQuota
  266. }
  267. if !tokenTrusted {
  268. return false
  269. }
  270. switch s.funding.Source() {
  271. case BillingSourceWallet:
  272. return s.relayInfo.UserQuota > trustQuota
  273. case BillingSourceSubscription:
  274. // 订阅不能启用信任旁路。原因:
  275. // 1. PreConsumeUserSubscription 要求 amount>0 来创建预扣记录并锁定订阅
  276. // 2. SubscriptionFunding.PreConsume 忽略参数,始终用 s.amount 预扣
  277. // 3. 若信任旁路将 effectiveQuota 设为 0,会导致 preConsumedQuota 与实际订阅预扣不一致
  278. return false
  279. default:
  280. return false
  281. }
  282. }
  283. // syncRelayInfo 将 BillingSession 的状态同步到 RelayInfo 的兼容字段上。
  284. func (s *BillingSession) syncRelayInfo() {
  285. info := s.relayInfo
  286. info.FinalPreConsumedQuota = s.preConsumedQuota
  287. info.BillingSource = s.funding.Source()
  288. if sub, ok := s.funding.(*SubscriptionFunding); ok {
  289. info.SubscriptionId = sub.subscriptionId
  290. info.SubscriptionPreConsumed = sub.preConsumed + int64(s.extraReserved)
  291. info.SubscriptionPostDelta = 0
  292. info.SubscriptionAmountTotal = sub.AmountTotal
  293. info.SubscriptionAmountUsedAfterPreConsume = sub.AmountUsedAfter + int64(s.extraReserved)
  294. info.SubscriptionPlanId = sub.PlanId
  295. info.SubscriptionPlanTitle = sub.PlanTitle
  296. } else {
  297. info.SubscriptionId = 0
  298. info.SubscriptionPreConsumed = 0
  299. }
  300. }
  301. // ---------------------------------------------------------------------------
  302. // NewBillingSession 工厂 — 根据计费偏好创建会话并处理回退
  303. // ---------------------------------------------------------------------------
  304. // NewBillingSession 根据用户计费偏好创建 BillingSession,处理 subscription_first / wallet_first 的回退。
  305. func NewBillingSession(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) (*BillingSession, *types.NewAPIError) {
  306. if relayInfo == nil {
  307. return nil, types.NewError(fmt.Errorf("relayInfo is nil"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
  308. }
  309. pref := common.NormalizeBillingPreference(relayInfo.UserSetting.BillingPreference)
  310. // 钱包路径需要先检查用户额度
  311. tryWallet := func() (*BillingSession, *types.NewAPIError) {
  312. userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
  313. if err != nil {
  314. return nil, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
  315. }
  316. if userQuota <= 0 {
  317. return nil, types.NewErrorWithStatusCode(
  318. fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)),
  319. types.ErrorCodeInsufficientUserQuota, http.StatusForbidden,
  320. types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
  321. }
  322. if userQuota-preConsumedQuota < 0 {
  323. return nil, types.NewErrorWithStatusCode(
  324. fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)),
  325. types.ErrorCodeInsufficientUserQuota, http.StatusForbidden,
  326. types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
  327. }
  328. relayInfo.UserQuota = userQuota
  329. session := &BillingSession{
  330. relayInfo: relayInfo,
  331. funding: &WalletFunding{userId: relayInfo.UserId},
  332. }
  333. if apiErr := session.preConsume(c, preConsumedQuota); apiErr != nil {
  334. return nil, apiErr
  335. }
  336. return session, nil
  337. }
  338. trySubscription := func() (*BillingSession, *types.NewAPIError) {
  339. subConsume := int64(preConsumedQuota)
  340. if subConsume <= 0 {
  341. subConsume = 1
  342. }
  343. session := &BillingSession{
  344. relayInfo: relayInfo,
  345. funding: &SubscriptionFunding{
  346. requestId: relayInfo.RequestId,
  347. userId: relayInfo.UserId,
  348. modelName: relayInfo.OriginModelName,
  349. amount: subConsume,
  350. },
  351. }
  352. // 必须传 subConsume 而非 preConsumedQuota,保证 SubscriptionFunding.amount、
  353. // preConsume 参数和 FinalPreConsumedQuota 三者一致,避免订阅多扣费。
  354. if apiErr := session.preConsume(c, int(subConsume)); apiErr != nil {
  355. return nil, apiErr
  356. }
  357. return session, nil
  358. }
  359. switch pref {
  360. case "subscription_only":
  361. return trySubscription()
  362. case "wallet_only":
  363. return tryWallet()
  364. case "wallet_first":
  365. session, err := tryWallet()
  366. if err != nil {
  367. if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota {
  368. return trySubscription()
  369. }
  370. return nil, err
  371. }
  372. return session, nil
  373. case "subscription_first":
  374. fallthrough
  375. default:
  376. hasSub, subCheckErr := model.HasActiveUserSubscription(relayInfo.UserId)
  377. if subCheckErr != nil {
  378. return nil, types.NewError(subCheckErr, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
  379. }
  380. if !hasSub {
  381. return tryWallet()
  382. }
  383. session, apiErr := trySubscription()
  384. if apiErr != nil {
  385. if apiErr.GetErrorCode() == types.ErrorCodeInsufficientUserQuota {
  386. return tryWallet()
  387. }
  388. return nil, apiErr
  389. }
  390. return session, nil
  391. }
  392. }