Quellcode durchsuchen

Merge pull request #2877 from QuantumNous/refactor/billing-session

refactor: 抽象统一计费会话 BillingSession
Calcium-Ion vor 1 Woche
Ursprung
Commit
8b8ea60b1e

+ 2 - 2
controller/relay.go

@@ -170,8 +170,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 		// Only return quota if downstream failed and quota was actually pre-consumed
 		if newAPIError != nil {
 			newAPIError = service.NormalizeViolationFeeError(newAPIError)
-			if relayInfo.FinalPreConsumedQuota != 0 {
-				service.ReturnPreConsumedQuota(c, relayInfo)
+			if relayInfo.Billing != nil {
+				relayInfo.Billing.Refund(c)
 			}
 			service.ChargeViolationFeeIfNeeded(c, relayInfo, newAPIError)
 		}

+ 21 - 0
relay/common/billing.go

@@ -0,0 +1,21 @@
+package common
+
+import "github.com/gin-gonic/gin"
+
+// BillingSettler 抽象计费会话的生命周期操作。
+// 由 service.BillingSession 实现,存储在 RelayInfo 上以避免循环引用。
+type BillingSettler interface {
+	// Settle 根据实际消耗额度进行结算,计算 delta = actualQuota - preConsumedQuota,
+	// 同时调整资金来源(钱包/订阅)和令牌额度。
+	Settle(actualQuota int) error
+
+	// Refund 退还所有预扣费额度(资金来源 + 令牌),幂等安全。
+	// 通过 gopool 异步执行。如果已经结算或退款则不做任何操作。
+	Refund(c *gin.Context)
+
+	// NeedsRefund 返回会话是否存在需要退还的预扣状态(未结算且未退款)。
+	NeedsRefund() bool
+
+	// GetPreConsumedQuota 返回实际预扣的额度值(信任用户可能为 0)。
+	GetPreConsumedQuota() int
+}

+ 3 - 0
relay/common/relay_info.go

@@ -115,6 +115,9 @@ type RelayInfo struct {
 	SendResponseCount      int
 	ReceivedResponseCount  int
 	FinalPreConsumedQuota  int // 最终预消耗的配额
+	// Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。
+	// 免费模型和按次计费(MJ/Task)时为 nil。
+	Billing BillingSettler
 	// BillingSource indicates whether this request is billed from wallet quota or subscription.
 	// "" or "wallet" => wallet; "subscription" => subscription
 	BillingSource string

+ 2 - 23
relay/compatible_handler.go

@@ -423,29 +423,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}
 
-	quotaDelta := quota - relayInfo.FinalPreConsumedQuota
-
-	//logger.LogInfo(ctx, fmt.Sprintf("request quota delta: %s", logger.FormatQuota(quotaDelta)))
-
-	if quotaDelta > 0 {
-		logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
-			logger.FormatQuota(quotaDelta),
-			logger.FormatQuota(quota),
-			logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
-		))
-	} else if quotaDelta < 0 {
-		logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
-			logger.FormatQuota(-quotaDelta),
-			logger.FormatQuota(quota),
-			logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
-		))
-	}
-
-	if quotaDelta != 0 {
-		err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
-		if err != nil {
-			logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
-		}
+	if err := service.SettleBilling(ctx, relayInfo, quota); err != nil {
+		logger.LogError(ctx, "error settling billing: "+err.Error())
 	}
 
 	logModel := modelName

+ 44 - 76
service/billing.go

@@ -2,12 +2,8 @@ package service
 
 import (
 	"fmt"
-	"net/http"
-	"strings"
 
-	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/logger"
-	"github.com/QuantumNous/new-api/model"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/types"
 	"github.com/gin-gonic/gin"
@@ -18,89 +14,61 @@ const (
 	BillingSourceSubscription = "subscription"
 )
 
-// PreConsumeBilling decides whether to pre-consume from subscription or wallet based on user preference.
-// It also always pre-consumes token quota in quota units (same as legacy flow).
+// PreConsumeBilling 根据用户计费偏好创建 BillingSession 并执行预扣费。
+// 会话存储在 relayInfo.Billing 上,供后续 Settle / Refund 使用。
 func PreConsumeBilling(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
-	if relayInfo == nil {
-		return types.NewError(fmt.Errorf("relayInfo is nil"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	session, apiErr := NewBillingSession(c, relayInfo, preConsumedQuota)
+	if apiErr != nil {
+		return apiErr
 	}
+	relayInfo.Billing = session
+	return nil
+}
 
-	pref := common.NormalizeBillingPreference(relayInfo.UserSetting.BillingPreference)
-	trySubscription := func() *types.NewAPIError {
-		quotaType := 0
-		// For total quota: consume preConsumedQuota quota units.
-		subConsume := int64(preConsumedQuota)
-		if subConsume <= 0 {
-			subConsume = 1
-		}
+// ---------------------------------------------------------------------------
+// SettleBilling — 后结算辅助函数
+// ---------------------------------------------------------------------------
 
-		// Pre-consume token quota in quota units to keep token limits consistent.
-		if preConsumedQuota > 0 {
-			if err := PreConsumeTokenQuota(relayInfo, preConsumedQuota); err != nil {
-				return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
-			}
-		}
+// SettleBilling 执行计费结算。如果 RelayInfo 上有 BillingSession 则通过 session 结算,
+// 否则回退到旧的 PostConsumeQuota 路径(兼容按次计费等场景)。
+func SettleBilling(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, actualQuota int) error {
+	if relayInfo.Billing != nil {
+		preConsumed := relayInfo.Billing.GetPreConsumedQuota()
+		delta := actualQuota - preConsumed
 
-		res, err := model.PreConsumeUserSubscription(relayInfo.RequestId, relayInfo.UserId, relayInfo.OriginModelName, quotaType, subConsume)
-		if err != nil {
-			// revert token pre-consume when subscription fails
-			if preConsumedQuota > 0 && !relayInfo.IsPlayground {
-				_ = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, preConsumedQuota)
-			}
-			errMsg := err.Error()
-			if strings.Contains(errMsg, "no active subscription") || strings.Contains(errMsg, "subscription quota insufficient") {
-				return types.NewErrorWithStatusCode(fmt.Errorf("订阅额度不足或未配置订阅: %s", errMsg), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
-			}
-			return types.NewErrorWithStatusCode(fmt.Errorf("订阅预扣失败: %s", errMsg), types.ErrorCodeQueryDataError, http.StatusInternalServerError)
+		if delta > 0 {
+			logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
+				logger.FormatQuota(delta),
+				logger.FormatQuota(actualQuota),
+				logger.FormatQuota(preConsumed),
+			))
+		} else if delta < 0 {
+			logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
+				logger.FormatQuota(-delta),
+				logger.FormatQuota(actualQuota),
+				logger.FormatQuota(preConsumed),
+			))
+		} else {
+			logger.LogInfo(ctx, fmt.Sprintf("预扣费与实际消耗一致,无需调整:%s(按次计费)",
+				logger.FormatQuota(actualQuota),
+			))
 		}
 
-		relayInfo.BillingSource = BillingSourceSubscription
-		relayInfo.SubscriptionId = res.UserSubscriptionId
-		relayInfo.SubscriptionPreConsumed = res.PreConsumed
-		relayInfo.SubscriptionPostDelta = 0
-		relayInfo.SubscriptionAmountTotal = res.AmountTotal
-		relayInfo.SubscriptionAmountUsedAfterPreConsume = res.AmountUsedAfter
-		if planInfo, err := model.GetSubscriptionPlanInfoByUserSubscriptionId(res.UserSubscriptionId); err == nil && planInfo != nil {
-			relayInfo.SubscriptionPlanId = planInfo.PlanId
-			relayInfo.SubscriptionPlanTitle = planInfo.PlanTitle
+		if err := relayInfo.Billing.Settle(actualQuota); err != nil {
+			return err
 		}
-		relayInfo.FinalPreConsumedQuota = preConsumedQuota
 
-		logger.LogInfo(c, fmt.Sprintf("用户 %d 使用订阅计费预扣:订阅=%d,token_quota=%d", relayInfo.UserId, res.PreConsumed, preConsumedQuota))
+		// 发送额度通知
+		if actualQuota != 0 {
+			checkAndSendQuotaNotify(relayInfo, actualQuota-preConsumed, preConsumed)
+		}
 		return nil
 	}
 
-	tryWallet := func() *types.NewAPIError {
-		relayInfo.BillingSource = BillingSourceWallet
-		relayInfo.SubscriptionId = 0
-		relayInfo.SubscriptionPreConsumed = 0
-		return PreConsumeQuota(c, preConsumedQuota, relayInfo)
-	}
-
-	switch pref {
-	case "subscription_only":
-		return trySubscription()
-	case "wallet_only":
-		return tryWallet()
-	case "wallet_first":
-		if err := tryWallet(); err != nil {
-			// only fallback for insufficient wallet quota
-			if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota {
-				return trySubscription()
-			}
-			return err
-		}
-		return nil
-	case "subscription_first":
-		fallthrough
-	default:
-		if err := trySubscription(); err != nil {
-			// fallback only when subscription not available/insufficient
-			if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota {
-				return tryWallet()
-			}
-			return err
-		}
-		return nil
+	// 回退:无 BillingSession 时使用旧路径
+	quotaDelta := actualQuota - relayInfo.FinalPreConsumedQuota
+	if quotaDelta != 0 {
+		return PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
 	}
+	return nil
 }

+ 335 - 0
service/billing_session.go

@@ -0,0 +1,335 @@
+package service
+
+import (
+	"fmt"
+	"net/http"
+	"strings"
+	"sync"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/logger"
+	"github.com/QuantumNous/new-api/model"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/types"
+
+	"github.com/bytedance/gopkg/util/gopool"
+	"github.com/gin-gonic/gin"
+)
+
+// ---------------------------------------------------------------------------
+// BillingSession — 统一计费会话
+// ---------------------------------------------------------------------------
+
+// BillingSession 封装单次请求的预扣费/结算/退款生命周期。
+// 实现 relaycommon.BillingSettler 接口。
+type BillingSession struct {
+	relayInfo        *relaycommon.RelayInfo
+	funding          FundingSource
+	preConsumedQuota int  // 实际预扣额度(信任用户可能为 0)
+	tokenConsumed    int  // 令牌额度实际扣减量
+	fundingSettled   bool // funding.Settle 已成功,资金来源已提交
+	settled          bool // Settle 全部完成(资金 + 令牌)
+	refunded         bool // Refund 已调用
+	mu               sync.Mutex
+}
+
+// Settle 根据实际消耗额度进行结算。
+// 资金来源和令牌额度分两步提交:若资金来源已提交但令牌调整失败,
+// 会标记 fundingSettled 防止 Refund 对已提交的资金来源执行退款。
+func (s *BillingSession) Settle(actualQuota int) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.settled {
+		return nil
+	}
+	delta := actualQuota - s.preConsumedQuota
+	if delta == 0 {
+		s.settled = true
+		return nil
+	}
+	// 1) 调整资金来源(仅在尚未提交时执行,防止重复调用)
+	if !s.fundingSettled {
+		if err := s.funding.Settle(delta); err != nil {
+			return err
+		}
+		s.fundingSettled = true
+	}
+	// 2) 调整令牌额度
+	var tokenErr error
+	if !s.relayInfo.IsPlayground {
+		if delta > 0 {
+			tokenErr = model.DecreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, delta)
+		} else {
+			tokenErr = model.IncreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, -delta)
+		}
+		if tokenErr != nil {
+			// 资金来源已提交,令牌调整失败只能记录日志;标记 settled 防止 Refund 误退资金
+			common.SysLog(fmt.Sprintf("error adjusting token quota after funding settled (userId=%d, tokenId=%d, delta=%d): %s",
+				s.relayInfo.UserId, s.relayInfo.TokenId, delta, tokenErr.Error()))
+		}
+	}
+	// 3) 更新 relayInfo 上的订阅 PostDelta(用于日志)
+	if s.funding.Source() == BillingSourceSubscription {
+		s.relayInfo.SubscriptionPostDelta += int64(delta)
+	}
+	s.settled = true
+	return tokenErr
+}
+
+// Refund 退还所有预扣费,幂等安全,异步执行。
+func (s *BillingSession) Refund(c *gin.Context) {
+	s.mu.Lock()
+	if s.settled || s.refunded || !s.needsRefundLocked() {
+		s.mu.Unlock()
+		return
+	}
+	s.refunded = true
+	s.mu.Unlock()
+
+	logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费(token_quota=%s, funding=%s)",
+		s.relayInfo.UserId,
+		logger.FormatQuota(s.tokenConsumed),
+		s.funding.Source(),
+	))
+
+	// 复制需要的值到闭包中
+	tokenId := s.relayInfo.TokenId
+	tokenKey := s.relayInfo.TokenKey
+	isPlayground := s.relayInfo.IsPlayground
+	tokenConsumed := s.tokenConsumed
+	funding := s.funding
+
+	gopool.Go(func() {
+		// 1) 退还资金来源
+		if err := funding.Refund(); err != nil {
+			common.SysLog("error refunding billing source: " + err.Error())
+		}
+		// 2) 退还令牌额度
+		if tokenConsumed > 0 && !isPlayground {
+			if err := model.IncreaseTokenQuota(tokenId, tokenKey, tokenConsumed); err != nil {
+				common.SysLog("error refunding token quota: " + err.Error())
+			}
+		}
+	})
+}
+
+// NeedsRefund 返回是否存在需要退还的预扣状态。
+func (s *BillingSession) NeedsRefund() bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.needsRefundLocked()
+}
+
+func (s *BillingSession) needsRefundLocked() bool {
+	if s.settled || s.refunded || s.fundingSettled {
+		// fundingSettled 时资金来源已提交结算,不能再退预扣费
+		return false
+	}
+	if s.tokenConsumed > 0 {
+		return true
+	}
+	// 订阅可能在 tokenConsumed=0 时仍预扣了额度
+	if sub, ok := s.funding.(*SubscriptionFunding); ok && sub.preConsumed > 0 {
+		return true
+	}
+	return false
+}
+
+// GetPreConsumedQuota 返回实际预扣的额度。
+func (s *BillingSession) GetPreConsumedQuota() int {
+	return s.preConsumedQuota
+}
+
+// ---------------------------------------------------------------------------
+// PreConsume — 统一预扣费入口(含信任额度旁路)
+// ---------------------------------------------------------------------------
+
+// preConsume 执行预扣费:信任检查 -> 令牌预扣 -> 资金来源预扣。
+// 任一步骤失败时原子回滚已完成的步骤。
+func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIError {
+	effectiveQuota := quota
+
+	// ---- 信任额度旁路 ----
+	if s.shouldTrust(c) {
+		effectiveQuota = 0
+		logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足, 信任且不需要预扣费 (funding=%s)", s.relayInfo.UserId, s.funding.Source()))
+	} else if effectiveQuota > 0 {
+		logger.LogInfo(c, fmt.Sprintf("用户 %d 需要预扣费 %s (funding=%s)", s.relayInfo.UserId, logger.FormatQuota(effectiveQuota), s.funding.Source()))
+	}
+
+	// ---- 1) 预扣令牌额度 ----
+	if effectiveQuota > 0 {
+		if err := PreConsumeTokenQuota(s.relayInfo, effectiveQuota); err != nil {
+			return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+		}
+		s.tokenConsumed = effectiveQuota
+	}
+
+	// ---- 2) 预扣资金来源 ----
+	if err := s.funding.PreConsume(effectiveQuota); err != nil {
+		// 预扣费失败,回滚令牌额度
+		if s.tokenConsumed > 0 && !s.relayInfo.IsPlayground {
+			if rollbackErr := model.IncreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, s.tokenConsumed); rollbackErr != nil {
+				common.SysLog(fmt.Sprintf("error rolling back token quota (userId=%d, tokenId=%d, amount=%d, fundingErr=%s): %s",
+					s.relayInfo.UserId, s.relayInfo.TokenId, s.tokenConsumed, err.Error(), rollbackErr.Error()))
+			}
+			s.tokenConsumed = 0
+		}
+		// TODO: model 层应定义哨兵错误(如 ErrNoActiveSubscription),用 errors.Is 替代字符串匹配
+		errMsg := err.Error()
+		if strings.Contains(errMsg, "no active subscription") || strings.Contains(errMsg, "subscription quota insufficient") {
+			return types.NewErrorWithStatusCode(fmt.Errorf("订阅额度不足或未配置订阅: %s", errMsg), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+		}
+		return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
+	}
+
+	s.preConsumedQuota = effectiveQuota
+
+	// ---- 同步 RelayInfo 兼容字段 ----
+	s.syncRelayInfo()
+
+	return nil
+}
+
+// shouldTrust 统一信任额度检查,适用于钱包和订阅。
+func (s *BillingSession) shouldTrust(c *gin.Context) bool {
+	trustQuota := common.GetTrustQuota()
+	if trustQuota <= 0 {
+		return false
+	}
+
+	// 检查令牌是否充足
+	tokenTrusted := s.relayInfo.TokenUnlimited
+	if !tokenTrusted {
+		tokenQuota := c.GetInt("token_quota")
+		tokenTrusted = tokenQuota > trustQuota
+	}
+	if !tokenTrusted {
+		return false
+	}
+
+	switch s.funding.Source() {
+	case BillingSourceWallet:
+		return s.relayInfo.UserQuota > trustQuota
+	case BillingSourceSubscription:
+		// 订阅不能启用信任旁路。原因:
+		// 1. PreConsumeUserSubscription 要求 amount>0 来创建预扣记录并锁定订阅
+		// 2. SubscriptionFunding.PreConsume 忽略参数,始终用 s.amount 预扣
+		// 3. 若信任旁路将 effectiveQuota 设为 0,会导致 preConsumedQuota 与实际订阅预扣不一致
+		return false
+	default:
+		return false
+	}
+}
+
+// syncRelayInfo 将 BillingSession 的状态同步到 RelayInfo 的兼容字段上。
+func (s *BillingSession) syncRelayInfo() {
+	info := s.relayInfo
+	info.FinalPreConsumedQuota = s.preConsumedQuota
+	info.BillingSource = s.funding.Source()
+
+	if sub, ok := s.funding.(*SubscriptionFunding); ok {
+		info.SubscriptionId = sub.subscriptionId
+		info.SubscriptionPreConsumed = sub.preConsumed
+		info.SubscriptionPostDelta = 0
+		info.SubscriptionAmountTotal = sub.AmountTotal
+		info.SubscriptionAmountUsedAfterPreConsume = sub.AmountUsedAfter
+		info.SubscriptionPlanId = sub.PlanId
+		info.SubscriptionPlanTitle = sub.PlanTitle
+	} else {
+		info.SubscriptionId = 0
+		info.SubscriptionPreConsumed = 0
+	}
+}
+
+// ---------------------------------------------------------------------------
+// NewBillingSession 工厂 — 根据计费偏好创建会话并处理回退
+// ---------------------------------------------------------------------------
+
+// NewBillingSession 根据用户计费偏好创建 BillingSession,处理 subscription_first / wallet_first 的回退。
+func NewBillingSession(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) (*BillingSession, *types.NewAPIError) {
+	if relayInfo == nil {
+		return nil, types.NewError(fmt.Errorf("relayInfo is nil"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	}
+
+	pref := common.NormalizeBillingPreference(relayInfo.UserSetting.BillingPreference)
+
+	// 钱包路径需要先检查用户额度
+	tryWallet := func() (*BillingSession, *types.NewAPIError) {
+		userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
+		if err != nil {
+			return nil, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
+		}
+		if userQuota <= 0 {
+			return nil, types.NewErrorWithStatusCode(
+				fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)),
+				types.ErrorCodeInsufficientUserQuota, http.StatusForbidden,
+				types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+		}
+		if userQuota-preConsumedQuota < 0 {
+			return nil, types.NewErrorWithStatusCode(
+				fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)),
+				types.ErrorCodeInsufficientUserQuota, http.StatusForbidden,
+				types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+		}
+		relayInfo.UserQuota = userQuota
+
+		session := &BillingSession{
+			relayInfo: relayInfo,
+			funding:   &WalletFunding{userId: relayInfo.UserId},
+		}
+		if apiErr := session.preConsume(c, preConsumedQuota); apiErr != nil {
+			return nil, apiErr
+		}
+		return session, nil
+	}
+
+	trySubscription := func() (*BillingSession, *types.NewAPIError) {
+		subConsume := int64(preConsumedQuota)
+		if subConsume <= 0 {
+			subConsume = 1
+		}
+		session := &BillingSession{
+			relayInfo: relayInfo,
+			funding: &SubscriptionFunding{
+				requestId: relayInfo.RequestId,
+				userId:    relayInfo.UserId,
+				modelName: relayInfo.OriginModelName,
+				amount:    subConsume,
+			},
+		}
+		// 必须传 subConsume 而非 preConsumedQuota,保证 SubscriptionFunding.amount、
+		// preConsume 参数和 FinalPreConsumedQuota 三者一致,避免订阅多扣费。
+		if apiErr := session.preConsume(c, int(subConsume)); apiErr != nil {
+			return nil, apiErr
+		}
+		return session, nil
+	}
+
+	switch pref {
+	case "subscription_only":
+		return trySubscription()
+	case "wallet_only":
+		return tryWallet()
+	case "wallet_first":
+		session, err := tryWallet()
+		if err != nil {
+			if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota {
+				return trySubscription()
+			}
+			return nil, err
+		}
+		return session, nil
+	case "subscription_first":
+		fallthrough
+	default:
+		session, err := trySubscription()
+		if err != nil {
+			if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota {
+				return tryWallet()
+			}
+			return nil, err
+		}
+		return session, nil
+	}
+}

+ 139 - 0
service/funding_source.go

@@ -0,0 +1,139 @@
+package service
+
+import (
+	"time"
+
+	"github.com/QuantumNous/new-api/model"
+)
+
+// ---------------------------------------------------------------------------
+// FundingSource — 资金来源接口(钱包 or 订阅)
+// ---------------------------------------------------------------------------
+
+// FundingSource 抽象了预扣费的资金来源。
+type FundingSource interface {
+	// Source 返回资金来源标识:"wallet" 或 "subscription"
+	Source() string
+	// PreConsume 从该资金来源预扣 amount 额度
+	PreConsume(amount int) error
+	// Settle 根据差额调整资金来源(正数补扣,负数退还)
+	Settle(delta int) error
+	// Refund 退还所有预扣费
+	Refund() error
+}
+
+// ---------------------------------------------------------------------------
+// WalletFunding — 钱包资金来源实现
+// ---------------------------------------------------------------------------
+
+type WalletFunding struct {
+	userId   int
+	consumed int // 实际预扣的用户额度
+}
+
+func (w *WalletFunding) Source() string { return BillingSourceWallet }
+
+func (w *WalletFunding) PreConsume(amount int) error {
+	if amount <= 0 {
+		return nil
+	}
+	if err := model.DecreaseUserQuota(w.userId, amount); err != nil {
+		return err
+	}
+	w.consumed = amount
+	return nil
+}
+
+func (w *WalletFunding) Settle(delta int) error {
+	if delta == 0 {
+		return nil
+	}
+	if delta > 0 {
+		return model.DecreaseUserQuota(w.userId, delta)
+	}
+	return model.IncreaseUserQuota(w.userId, -delta, false)
+}
+
+func (w *WalletFunding) Refund() error {
+	if w.consumed <= 0 {
+		return nil
+	}
+	// IncreaseUserQuota 是 quota += N 的非幂等操作,不能重试,否则会多退额度。
+	// 订阅的 RefundSubscriptionPreConsume 有 requestId 幂等保护所以可以重试。
+	return model.IncreaseUserQuota(w.userId, w.consumed, false)
+}
+
+// ---------------------------------------------------------------------------
+// SubscriptionFunding — 订阅资金来源实现
+// ---------------------------------------------------------------------------
+
+type SubscriptionFunding struct {
+	requestId      string
+	userId         int
+	modelName      string
+	amount         int64 // 预扣的订阅额度(subConsume)
+	subscriptionId int
+	preConsumed    int64
+	// 以下字段在 PreConsume 成功后填充,供 RelayInfo 同步使用
+	AmountTotal     int64
+	AmountUsedAfter int64
+	PlanId          int
+	PlanTitle       string
+}
+
+func (s *SubscriptionFunding) Source() string { return BillingSourceSubscription }
+
+func (s *SubscriptionFunding) PreConsume(_ int) error {
+	// amount 参数被忽略,使用内部 s.amount(已在构造时根据 preConsumedQuota 计算)
+	res, err := model.PreConsumeUserSubscription(s.requestId, s.userId, s.modelName, 0, s.amount)
+	if err != nil {
+		return err
+	}
+	s.subscriptionId = res.UserSubscriptionId
+	s.preConsumed = res.PreConsumed
+	s.AmountTotal = res.AmountTotal
+	s.AmountUsedAfter = res.AmountUsedAfter
+	// 获取订阅计划信息
+	if planInfo, err := model.GetSubscriptionPlanInfoByUserSubscriptionId(res.UserSubscriptionId); err == nil && planInfo != nil {
+		s.PlanId = planInfo.PlanId
+		s.PlanTitle = planInfo.PlanTitle
+	}
+	return nil
+}
+
+func (s *SubscriptionFunding) Settle(delta int) error {
+	if delta == 0 {
+		return nil
+	}
+	return model.PostConsumeUserSubscriptionDelta(s.subscriptionId, int64(delta))
+}
+
+func (s *SubscriptionFunding) Refund() error {
+	if s.preConsumed <= 0 {
+		return nil
+	}
+	return refundWithRetry(func() error {
+		return model.RefundSubscriptionPreConsume(s.requestId)
+	})
+}
+
+// refundWithRetry 尝试多次执行退款操作以提高成功率,只能用于基于事务的退款函数!!!!!!
+// try to refund with retries, only for refund functions based on transactions!!!
+func refundWithRetry(fn func() error) error {
+	if fn == nil {
+		return nil
+	}
+	const maxAttempts = 3
+	var lastErr error
+	for i := 0; i < maxAttempts; i++ {
+		if err := fn(); err == nil {
+			return nil
+		} else {
+			lastErr = err
+		}
+		if i < maxAttempts-1 {
+			time.Sleep(time.Duration(200*(i+1)) * time.Millisecond)
+		}
+	}
+	return lastErr
+}

+ 0 - 124
service/pre_consume_quota.go

@@ -1,124 +0,0 @@
-package service
-
-import (
-	"fmt"
-	"net/http"
-	"time"
-
-	"github.com/QuantumNous/new-api/common"
-	"github.com/QuantumNous/new-api/logger"
-	"github.com/QuantumNous/new-api/model"
-	relaycommon "github.com/QuantumNous/new-api/relay/common"
-	"github.com/QuantumNous/new-api/types"
-
-	"github.com/bytedance/gopkg/util/gopool"
-	"github.com/gin-gonic/gin"
-)
-
-func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
-	// Always refund subscription pre-consumed (can be non-zero even when FinalPreConsumedQuota is 0)
-	needRefundSub := relayInfo.BillingSource == BillingSourceSubscription && relayInfo.SubscriptionId != 0 && relayInfo.SubscriptionPreConsumed > 0
-	needRefundToken := relayInfo.FinalPreConsumedQuota != 0
-	if !needRefundSub && !needRefundToken {
-		return
-	}
-	logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费(token_quota=%s, subscription=%d)",
-		relayInfo.UserId,
-		logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
-		relayInfo.SubscriptionPreConsumed,
-	))
-	gopool.Go(func() {
-		relayInfoCopy := *relayInfo
-		if relayInfoCopy.BillingSource == BillingSourceSubscription {
-			if needRefundSub {
-				if err := refundWithRetry(func() error {
-					return model.RefundSubscriptionPreConsume(relayInfoCopy.RequestId)
-				}); err != nil {
-					common.SysLog("error refund subscription pre-consume: " + err.Error())
-				}
-			}
-			// refund token quota only
-			if needRefundToken && !relayInfoCopy.IsPlayground {
-				_ = model.IncreaseTokenQuota(relayInfoCopy.TokenId, relayInfoCopy.TokenKey, relayInfoCopy.FinalPreConsumedQuota)
-			}
-			return
-		}
-
-		// wallet refund uses existing path (user quota + token quota)
-		if needRefundToken {
-			err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false)
-			if err != nil {
-				common.SysLog("error return pre-consumed quota: " + err.Error())
-			}
-		}
-	})
-}
-
-func refundWithRetry(fn func() error) error {
-	if fn == nil {
-		return nil
-	}
-	const maxAttempts = 3
-	var lastErr error
-	for i := 0; i < maxAttempts; i++ {
-		if err := fn(); err == nil {
-			return nil
-		} else {
-			lastErr = err
-		}
-		if i < maxAttempts-1 {
-			time.Sleep(time.Duration(200*(i+1)) * time.Millisecond)
-		}
-	}
-	return lastErr
-}
-
-// PreConsumeQuota checks if the user has enough quota to pre-consume.
-// It returns the pre-consumed quota if successful, or an error if not.
-func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
-	userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
-	}
-	if userQuota <= 0 {
-		return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
-	}
-	if userQuota-preConsumedQuota < 0 {
-		return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
-	}
-
-	trustQuota := common.GetTrustQuota()
-
-	relayInfo.UserQuota = userQuota
-	if userQuota > trustQuota {
-		// 用户额度充足,判断令牌额度是否充足
-		if !relayInfo.TokenUnlimited {
-			// 非无限令牌,判断令牌额度是否充足
-			tokenQuota := c.GetInt("token_quota")
-			if tokenQuota > trustQuota {
-				// 令牌额度充足,信任令牌
-				preConsumedQuota = 0
-				logger.LogInfo(c, fmt.Sprintf("用户 %d 剩余额度 %s 且令牌 %d 额度 %d 充足, 信任且不需要预扣费", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
-			}
-		} else {
-			// in this case, we do not pre-consume quota
-			// because the user has enough quota
-			preConsumedQuota = 0
-			logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足且为无限额度令牌, 信任且不需要预扣费", relayInfo.UserId))
-		}
-	}
-
-	if preConsumedQuota > 0 {
-		err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
-		if err != nil {
-			return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
-		}
-		err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
-		if err != nil {
-			return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
-		}
-		logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
-	}
-	relayInfo.FinalPreConsumedQuota = preConsumedQuota
-	return nil
-}

+ 4 - 42
service/quota.go

@@ -307,27 +307,8 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}
 
-	quotaDelta := quota - relayInfo.FinalPreConsumedQuota
-
-	if quotaDelta > 0 {
-		logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
-			logger.FormatQuota(quotaDelta),
-			logger.FormatQuota(quota),
-			logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
-		))
-	} else if quotaDelta < 0 {
-		logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
-			logger.FormatQuota(-quotaDelta),
-			logger.FormatQuota(quota),
-			logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
-		))
-	}
-
-	if quotaDelta != 0 {
-		err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
-		if err != nil {
-			logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
-		}
+	if err := SettleBilling(ctx, relayInfo, quota); err != nil {
+		logger.LogError(ctx, "error settling billing: "+err.Error())
 	}
 
 	other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
@@ -432,27 +413,8 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, u
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}
 
-	quotaDelta := quota - relayInfo.FinalPreConsumedQuota
-
-	if quotaDelta > 0 {
-		logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
-			logger.FormatQuota(quotaDelta),
-			logger.FormatQuota(quota),
-			logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
-		))
-	} else if quotaDelta < 0 {
-		logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
-			logger.FormatQuota(-quotaDelta),
-			logger.FormatQuota(quota),
-			logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
-		))
-	}
-
-	if quotaDelta != 0 {
-		err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
-		if err != nil {
-			logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
-		}
+	if err := SettleBilling(ctx, relayInfo, quota); err != nil {
+		logger.LogError(ctx, "error settling billing: "+err.Error())
 	}
 
 	logModel := relayInfo.OriginModelName