Prechádzať zdrojové kódy

feat(audio): enhance audio request handling with token type detection and streaming support

CaIon 2 týždňov pred
rodič
commit
e36e2e1b69

+ 5 - 1
dto/audio.go

@@ -2,6 +2,7 @@ package dto
 
 import (
 	"encoding/json"
+	"strings"
 
 	"github.com/QuantumNous/new-api/types"
 
@@ -24,11 +25,14 @@ func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {
 		CombineText: r.Input,
 		TokenType:   types.TokenTypeTextNumber,
 	}
+	if strings.Contains(r.Model, "gpt") {
+		meta.TokenType = types.TokenTypeTokenizer
+	}
 	return meta
 }
 
 func (r *AudioRequest) IsStream(c *gin.Context) bool {
-	return false
+	return r.StreamFormat == "sse"
 }
 
 func (r *AudioRequest) SetModelName(modelName string) {

+ 5 - 2
relay/audio_handler.go

@@ -67,8 +67,11 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		return newAPIError
 	}
-
-	postConsumeQuota(c, info, usage.(*dto.Usage), "")
+	if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
+		service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
+	} else {
+		postConsumeQuota(c, info, usage.(*dto.Usage), "")
+	}
 
 	return nil
 }

+ 145 - 0
relay/channel/openai/audio.go

@@ -0,0 +1,145 @@
+package openai
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"math"
+	"net/http"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/logger"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/relay/helper"
+	"github.com/QuantumNous/new-api/service"
+	"github.com/QuantumNous/new-api/types"
+	"github.com/gin-gonic/gin"
+)
+
+func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
+	// the status code has been judged before, if there is a body reading failure,
+	// it should be regarded as a non-recoverable error, so it should not return err for external retry.
+	// Analogous to nginx's load balancing, it will only retry if it can't be requested or
+	// if the upstream returns a specific status code, once the upstream has already written the header,
+	// the subsequent failure of the response body should be regarded as a non-recoverable error,
+	// and can be terminated directly.
+	defer service.CloseResponseBodyGracefully(resp)
+	usage := &dto.Usage{}
+	usage.PromptTokens = info.GetEstimatePromptTokens()
+	usage.TotalTokens = info.GetEstimatePromptTokens()
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+
+	if info.IsStream {
+		helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+			if service.SundaySearch(data, "usage") {
+				var simpleResponse dto.SimpleResponse
+				err := common.Unmarshal([]byte(data), &simpleResponse)
+				if err != nil {
+					logger.LogError(c, err.Error())
+				}
+				if simpleResponse.Usage.TotalTokens != 0 {
+					usage.PromptTokens = simpleResponse.Usage.InputTokens
+					usage.CompletionTokens = simpleResponse.OutputTokens
+					usage.TotalTokens = simpleResponse.TotalTokens
+				}
+			}
+			_ = helper.StringData(c, data)
+			return true
+		})
+	} else {
+		common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
+		// 读取响应体到缓冲区
+		bodyBytes, err := io.ReadAll(resp.Body)
+		if err != nil {
+			logger.LogError(c, fmt.Sprintf("failed to read TTS response body: %v", err))
+			c.Writer.WriteHeaderNow()
+			return usage
+		}
+
+		// 写入响应到客户端
+		c.Writer.WriteHeaderNow()
+		_, err = c.Writer.Write(bodyBytes)
+		if err != nil {
+			logger.LogError(c, fmt.Sprintf("failed to write TTS response: %v", err))
+		}
+
+		// 计算音频时长并更新 usage
+		audioFormat := "mp3" // 默认格式
+		if audioReq, ok := info.Request.(*dto.AudioRequest); ok && audioReq.ResponseFormat != "" {
+			audioFormat = audioReq.ResponseFormat
+		}
+
+		var duration float64
+		var durationErr error
+
+		if audioFormat == "pcm" {
+			// PCM 格式没有文件头,根据 OpenAI TTS 的 PCM 参数计算时长
+			// 采样率: 24000 Hz, 位深度: 16-bit (2 bytes), 声道数: 1
+			const sampleRate = 24000
+			const bytesPerSample = 2
+			const channels = 1
+			duration = float64(len(bodyBytes)) / float64(sampleRate*bytesPerSample*channels)
+		} else {
+			ext := "." + audioFormat
+			reader := bytes.NewReader(bodyBytes)
+			duration, durationErr = common.GetAudioDuration(c.Request.Context(), reader, ext)
+		}
+
+		usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+
+		if durationErr != nil {
+			logger.LogWarn(c, fmt.Sprintf("failed to get audio duration: %v", durationErr))
+			// 如果无法获取时长,则设置保底的 CompletionTokens,根据body大小计算
+			sizeInKB := float64(len(bodyBytes)) / 1000.0
+			estimatedTokens := int(math.Ceil(sizeInKB)) // 粗略估算每KB约等于1 token
+			usage.CompletionTokens = estimatedTokens
+			usage.CompletionTokenDetails.AudioTokens = estimatedTokens
+		} else if duration > 0 {
+			// 计算 token: ceil(duration) / 60.0 * 1000,即每分钟 1000 tokens
+			completionTokens := int(math.Round(math.Ceil(duration) / 60.0 * 1000))
+			usage.CompletionTokens = completionTokens
+			usage.CompletionTokenDetails.AudioTokens = completionTokens
+		}
+	}
+
+	return usage
+}
+
+func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
+	defer service.CloseResponseBodyGracefully(resp)
+
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
+	}
+	// 写入新的 response body
+	service.IOCopyBytesGracefully(c, resp, responseBody)
+
+	var responseData struct {
+		Usage *dto.Usage `json:"usage"`
+	}
+	if err := common.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
+		if responseData.Usage.TotalTokens > 0 {
+			usage := responseData.Usage
+			if usage.PromptTokens == 0 {
+				usage.PromptTokens = usage.InputTokens
+			}
+			if usage.CompletionTokens == 0 {
+				usage.CompletionTokens = usage.OutputTokens
+			}
+			return nil, usage
+		}
+	}
+
+	usage := &dto.Usage{}
+	usage.PromptTokens = info.GetEstimatePromptTokens()
+	usage.CompletionTokens = 0
+	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+	return nil, usage
+}

+ 2 - 65
relay/channel/openai/relay-openai.go

@@ -1,7 +1,6 @@
 package openai
 
 import (
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -151,7 +150,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 		var streamResp struct {
 			Usage *dto.Usage `json:"usage"`
 		}
-		err := json.Unmarshal([]byte(secondLastStreamData), &streamResp)
+		err := common.Unmarshal([]byte(secondLastStreamData), &streamResp)
 		if err == nil && streamResp.Usage != nil && service.ValidUsage(streamResp.Usage) {
 			usage = streamResp.Usage
 			containStreamUsage = true
@@ -327,68 +326,6 @@ func streamTTSResponse(c *gin.Context, resp *http.Response) {
 	}
 }
 
-func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
-	// the status code has been judged before, if there is a body reading failure,
-	// it should be regarded as a non-recoverable error, so it should not return err for external retry.
-	// Analogous to nginx's load balancing, it will only retry if it can't be requested or
-	// if the upstream returns a specific status code, once the upstream has already written the header,
-	// the subsequent failure of the response body should be regarded as a non-recoverable error,
-	// and can be terminated directly.
-	defer service.CloseResponseBodyGracefully(resp)
-	usage := &dto.Usage{}
-	usage.PromptTokens = info.GetEstimatePromptTokens()
-	usage.TotalTokens = info.GetEstimatePromptTokens()
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
-	}
-	c.Writer.WriteHeader(resp.StatusCode)
-
-	isStreaming := resp.ContentLength == -1 || resp.Header.Get("Content-Length") == ""
-	if isStreaming {
-		streamTTSResponse(c, resp)
-	} else {
-		c.Writer.WriteHeaderNow()
-		_, err := io.Copy(c.Writer, resp.Body)
-		if err != nil {
-			logger.LogError(c, err.Error())
-		}
-	}
-	return usage
-}
-
-func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
-	defer service.CloseResponseBodyGracefully(resp)
-
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
-	}
-	// 写入新的 response body
-	service.IOCopyBytesGracefully(c, resp, responseBody)
-
-	var responseData struct {
-		Usage *dto.Usage `json:"usage"`
-	}
-	if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
-		if responseData.Usage.TotalTokens > 0 {
-			usage := responseData.Usage
-			if usage.PromptTokens == 0 {
-				usage.PromptTokens = usage.InputTokens
-			}
-			if usage.CompletionTokens == 0 {
-				usage.CompletionTokens = usage.OutputTokens
-			}
-			return nil, usage
-		}
-	}
-
-	usage := &dto.Usage{}
-	usage.PromptTokens = info.GetEstimatePromptTokens()
-	usage.CompletionTokens = 0
-	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
-	return nil, usage
-}
-
 func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
 	if info == nil || info.ClientWs == nil || info.TargetWs == nil {
 		return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
@@ -687,7 +624,7 @@ func extractCachedTokensFromBody(body []byte) (int, bool) {
 		} `json:"usage"`
 	}
 
-	if err := json.Unmarshal(body, &payload); err != nil {
+	if err := common.Unmarshal(body, &payload); err != nil {
 		return 0, false
 	}
 

+ 1 - 1
relay/compatible_handler.go

@@ -181,7 +181,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
 		return newApiErr
 	}
 
-	if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
+	if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
 		service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
 	} else {
 		postConsumeQuota(c, info, usage.(*dto.Usage), "")

+ 1 - 1
setting/ratio_setting/model_ratio.go

@@ -536,7 +536,7 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
 			if name == "gpt-4o-2024-05-13" {
 				return 3, true
 			}
-			return 4, true
+			return 4, false
 		}
 		// gpt-5 匹配
 		if strings.HasPrefix(name, "gpt-5") {