Explorar el Código

fix(channel-test): support tiered billing model tests (#4145)

Pre-fill BillingRequestInput from dto.Request before ModelPriceHelper,
so tiered_expr billing resolves param() from the structured request
instead of reading HTTP body (which is empty in channel-test context).

- attachTestBillingRequestInput: marshal dto.Request → RequestInput
- ResolveIncomingBillingExprRequestInput: early-return when pre-filled
- settleTestQuota / buildTestLogOther: align test settlement & logging
  with production TryTieredSettle / InjectTieredBillingInfo paths
yyhhyyyyyy hace 3 días
padre
commit
0220df8429

+ 56 - 12
controller/channel-test.go

@@ -20,6 +20,7 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/middleware"
 	"github.com/QuantumNous/new-api/model"
+	"github.com/QuantumNous/new-api/pkg/billingexpr"
 	"github.com/QuantumNous/new-api/relay"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	relayconstant "github.com/QuantumNous/new-api/relay/constant"
@@ -232,6 +233,15 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
 	info.IsChannelTest = true
 	info.InitChannelMeta(c)
 
+	err = attachTestBillingRequestInput(info, request)
+	if err != nil {
+		return testResult{
+			context:     c,
+			localErr:    err,
+			newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
+		}
+	}
+
 	err = helper.ModelMappedHelper(c, info, request)
 	if err != nil {
 		return testResult{
@@ -468,21 +478,11 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
 	}
 	info.SetEstimatePromptTokens(usage.PromptTokens)
 
-	quota := 0
-	if !priceData.UsePrice {
-		quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
-		quota = int(math.Round(float64(quota) * priceData.ModelRatio))
-		if priceData.ModelRatio != 0 && quota <= 0 {
-			quota = 1
-		}
-	} else {
-		quota = int(priceData.ModelPrice * common.QuotaPerUnit)
-	}
+	quota, tieredResult := settleTestQuota(info, priceData, usage)
 	tok := time.Now()
 	milliseconds := tok.Sub(tik).Milliseconds()
 	consumedTime := float64(milliseconds) / 1000.0
-	other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
-		usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+	other := buildTestLogOther(c, info, priceData, usage, tieredResult)
 	model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
 		ChannelId:        channel.Id,
 		PromptTokens:     usage.PromptTokens,
@@ -504,6 +504,50 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
 	}
 }
 
+func attachTestBillingRequestInput(info *relaycommon.RelayInfo, request dto.Request) error {
+	if info == nil {
+		return nil
+	}
+
+	input, err := helper.BuildBillingExprRequestInputFromRequest(request, info.RequestHeaders)
+	if err != nil {
+		return err
+	}
+	info.BillingRequestInput = &input
+	return nil
+}
+
+func settleTestQuota(info *relaycommon.RelayInfo, priceData types.PriceData, usage *dto.Usage) (int, *billingexpr.TieredResult) {
+	if usage != nil && info != nil && info.TieredBillingSnapshot != nil {
+		isClaudeUsageSemantic := usage.UsageSemantic == "anthropic" || info.GetFinalRequestRelayFormat() == types.RelayFormatClaude
+		usedVars := billingexpr.UsedVars(info.TieredBillingSnapshot.ExprString)
+		if ok, quota, result := service.TryTieredSettle(info, service.BuildTieredTokenParams(usage, isClaudeUsageSemantic, usedVars)); ok {
+			return quota, result
+		}
+	}
+
+	quota := 0
+	if !priceData.UsePrice {
+		quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
+		quota = int(math.Round(float64(quota) * priceData.ModelRatio))
+		if priceData.ModelRatio != 0 && quota <= 0 {
+			quota = 1
+		}
+		return quota, nil
+	}
+
+	return int(priceData.ModelPrice * common.QuotaPerUnit), nil
+}
+
+func buildTestLogOther(c *gin.Context, info *relaycommon.RelayInfo, priceData types.PriceData, usage *dto.Usage, tieredResult *billingexpr.TieredResult) map[string]interface{} {
+	other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
+		usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+	if tieredResult != nil {
+		service.InjectTieredBillingInfo(other, info, tieredResult)
+	}
+	return other
+}
+
 func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) {
 	switch u := usageAny.(type) {
 	case *dto.Usage:

+ 71 - 0
controller/channel_test_internal_test.go

@@ -0,0 +1,71 @@
+package controller
+
+import (
+	"net/http/httptest"
+	"testing"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/pkg/billingexpr"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/types"
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/require"
+)
+
+func TestSettleTestQuotaUsesTieredBilling(t *testing.T) {
+	info := &relaycommon.RelayInfo{
+		TieredBillingSnapshot: &billingexpr.BillingSnapshot{
+			BillingMode:   "tiered_expr",
+			ExprString:    `param("stream") == true ? tier("stream", p * 3) : tier("base", p * 2)`,
+			ExprHash:      billingexpr.ExprHashString(`param("stream") == true ? tier("stream", p * 3) : tier("base", p * 2)`),
+			GroupRatio:    1,
+			EstimatedTier: "stream",
+			QuotaPerUnit:  common.QuotaPerUnit,
+			ExprVersion:   1,
+		},
+		BillingRequestInput: &billingexpr.RequestInput{
+			Body: []byte(`{"stream":true}`),
+		},
+	}
+
+	quota, result := settleTestQuota(info, types.PriceData{
+		ModelRatio:      1,
+		CompletionRatio: 2,
+	}, &dto.Usage{
+		PromptTokens: 1000,
+	})
+
+	require.Equal(t, 1500, quota)
+	require.NotNil(t, result)
+	require.Equal(t, "stream", result.MatchedTier)
+}
+
+func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
+
+	info := &relaycommon.RelayInfo{
+		TieredBillingSnapshot: &billingexpr.BillingSnapshot{
+			BillingMode: "tiered_expr",
+			ExprString:  `tier("base", p * 2)`,
+		},
+		ChannelMeta: &relaycommon.ChannelMeta{},
+	}
+	priceData := types.PriceData{
+		GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
+	}
+	usage := &dto.Usage{
+		PromptTokensDetails: dto.InputTokenDetails{
+			CachedTokens: 12,
+		},
+	}
+
+	other := buildTestLogOther(ctx, info, priceData, usage, &billingexpr.TieredResult{
+		MatchedTier: "base",
+	})
+
+	require.Equal(t, "tiered_expr", other["billing_mode"])
+	require.Equal(t, "base", other["matched_tier"])
+	require.NotEmpty(t, other["expr_b64"])
+}

+ 35 - 0
relay/helper/billing_expr_request.go

@@ -4,12 +4,21 @@ import (
 	"strings"
 
 	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/pkg/billingexpr"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/gin-gonic/gin"
 )
 
 func ResolveIncomingBillingExprRequestInput(c *gin.Context, info *relaycommon.RelayInfo) (billingexpr.RequestInput, error) {
+	if info != nil && info.BillingRequestInput != nil {
+		input := cloneRequestInput(*info.BillingRequestInput)
+		if len(input.Headers) == 0 {
+			input.Headers = cloneStringMap(info.RequestHeaders)
+		}
+		return input, nil
+	}
+
 	input := billingexpr.RequestInput{}
 	if info != nil {
 		input.Headers = cloneStringMap(info.RequestHeaders)
@@ -23,6 +32,22 @@ func ResolveIncomingBillingExprRequestInput(c *gin.Context, info *relaycommon.Re
 	return input, nil
 }
 
+func BuildBillingExprRequestInputFromRequest(request dto.Request, headers map[string]string) (billingexpr.RequestInput, error) {
+	input := billingexpr.RequestInput{
+		Headers: cloneStringMap(headers),
+	}
+	if request == nil {
+		return input, nil
+	}
+
+	bodyBytes, err := common.Marshal(request)
+	if err != nil {
+		return billingexpr.RequestInput{}, err
+	}
+	input.Body = bodyBytes
+	return input, nil
+}
+
 func readIncomingBillingExprBody(c *gin.Context) ([]byte, error) {
 	if c == nil || c.Request == nil || !isJSONContentType(c.Request.Header.Get("Content-Type")) {
 		return nil, nil
@@ -34,6 +59,16 @@ func readIncomingBillingExprBody(c *gin.Context) ([]byte, error) {
 	return storage.Bytes()
 }
 
+func cloneRequestInput(src billingexpr.RequestInput) billingexpr.RequestInput {
+	input := billingexpr.RequestInput{
+		Headers: cloneStringMap(src.Headers),
+	}
+	if len(src.Body) > 0 {
+		input.Body = append([]byte(nil), src.Body...)
+	}
+	return input
+}
+
 func isJSONContentType(contentType string) bool {
 	contentType = strings.ToLower(strings.TrimSpace(contentType))
 	return strings.HasPrefix(contentType, "application/json")

+ 28 - 0
relay/helper/billing_expr_request_test.go

@@ -8,9 +8,12 @@ import (
 	"testing"
 
 	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/dto"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/gin-gonic/gin"
+	"github.com/samber/lo"
 	"github.com/stretchr/testify/require"
+	"github.com/tidwall/gjson"
 )
 
 func TestResolveIncomingBillingExprRequestInput(t *testing.T) {
@@ -33,3 +36,28 @@ func TestResolveIncomingBillingExprRequestInput(t *testing.T) {
 	require.Equal(t, body, input.Body)
 	require.Equal(t, "application/json", input.Headers["Content-Type"])
 }
+
+func TestBuildBillingExprRequestInputFromRequest(t *testing.T) {
+	request := &dto.GeneralOpenAIRequest{
+		Model:  "gemini-3.1-pro-preview",
+		Stream: lo.ToPtr(true),
+		Messages: []dto.Message{
+			{
+				Role:    "user",
+				Content: "hi",
+			},
+		},
+		MaxTokens: lo.ToPtr(uint(3000)),
+	}
+
+	input, err := BuildBillingExprRequestInputFromRequest(request, map[string]string{
+		"Content-Type": "application/json",
+		"X-Test":       "1",
+	})
+	require.NoError(t, err)
+	require.Equal(t, "application/json", input.Headers["Content-Type"])
+	require.Equal(t, "1", input.Headers["X-Test"])
+	require.True(t, gjson.GetBytes(input.Body, "stream").Bool())
+	require.Equal(t, "user", gjson.GetBytes(input.Body, "messages.0.role").String())
+	require.Equal(t, float64(3000), gjson.GetBytes(input.Body, "max_tokens").Float())
+}

+ 62 - 0
relay/helper/price_test.go

@@ -0,0 +1,62 @@
+package helper
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/pkg/billingexpr"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/setting/billing_setting"
+	"github.com/QuantumNous/new-api/setting/config"
+	"github.com/QuantumNous/new-api/types"
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/require"
+)
+
+func TestModelPriceHelperTieredUsesPreloadedRequestInput(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	saved := map[string]string{}
+	require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
+		saved[key] = value
+		return nil
+	}))
+	t.Cleanup(func() {
+		require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
+	})
+
+	require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
+		"billing_setting.billing_mode": `{"tiered-test-model":"tiered_expr"}`,
+		"billing_setting.billing_expr": `{"tiered-test-model":"param(\"stream\") == true ? tier(\"stream\", p * 3) : tier(\"base\", p * 2)"}`,
+	}))
+
+	recorder := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodPost, "/api/channel/test/1", nil)
+	req.Body = nil
+	req.ContentLength = 0
+	req.Header.Set("Content-Type", "application/json")
+	ctx.Request = req
+	ctx.Set("group", "default")
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName: "tiered-test-model",
+		UserGroup:       "default",
+		UsingGroup:      "default",
+		RequestHeaders:  map[string]string{"Content-Type": "application/json"},
+		BillingRequestInput: &billingexpr.RequestInput{
+			Headers: map[string]string{"Content-Type": "application/json"},
+			Body:    []byte(`{"stream":true}`),
+		},
+	}
+
+	priceData, err := ModelPriceHelper(ctx, info, 1000, &types.TokenCountMeta{})
+	require.NoError(t, err)
+	require.Equal(t, 1500, priceData.QuotaToPreConsume)
+	require.NotNil(t, info.TieredBillingSnapshot)
+	require.Equal(t, "stream", info.TieredBillingSnapshot.EstimatedTier)
+	require.Equal(t, billing_setting.BillingModeTieredExpr, info.TieredBillingSnapshot.BillingMode)
+	require.Equal(t, common.QuotaPerUnit, info.TieredBillingSnapshot.QuotaPerUnit)
+}