|
|
@@ -0,0 +1,145 @@
|
|
|
+package openai
|
|
|
+
|
|
|
+import (
|
|
|
+ "bytes"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "math"
|
|
|
+ "net/http"
|
|
|
+
|
|
|
+ "github.com/QuantumNous/new-api/common"
|
|
|
+ "github.com/QuantumNous/new-api/constant"
|
|
|
+ "github.com/QuantumNous/new-api/dto"
|
|
|
+ "github.com/QuantumNous/new-api/logger"
|
|
|
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
|
+ "github.com/QuantumNous/new-api/relay/helper"
|
|
|
+ "github.com/QuantumNous/new-api/service"
|
|
|
+ "github.com/QuantumNous/new-api/types"
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
+)
|
|
|
+
|
|
|
+func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
|
|
|
+ // the status code has been judged before, if there is a body reading failure,
|
|
|
+ // it should be regarded as a non-recoverable error, so it should not return err for external retry.
|
|
|
+ // Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
|
|
+ // if the upstream returns a specific status code, once the upstream has already written the header,
|
|
|
+ // the subsequent failure of the response body should be regarded as a non-recoverable error,
|
|
|
+ // and can be terminated directly.
|
|
|
+ defer service.CloseResponseBodyGracefully(resp)
|
|
|
+ usage := &dto.Usage{}
|
|
|
+ usage.PromptTokens = info.GetEstimatePromptTokens()
|
|
|
+ usage.TotalTokens = info.GetEstimatePromptTokens()
|
|
|
+ for k, v := range resp.Header {
|
|
|
+ c.Writer.Header().Set(k, v[0])
|
|
|
+ }
|
|
|
+ c.Writer.WriteHeader(resp.StatusCode)
|
|
|
+
|
|
|
+ if info.IsStream {
|
|
|
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ if service.SundaySearch(data, "usage") {
|
|
|
+ var simpleResponse dto.SimpleResponse
|
|
|
+ err := common.Unmarshal([]byte(data), &simpleResponse)
|
|
|
+ if err != nil {
|
|
|
+ logger.LogError(c, err.Error())
|
|
|
+ }
|
|
|
+ if simpleResponse.Usage.TotalTokens != 0 {
|
|
|
+ usage.PromptTokens = simpleResponse.Usage.InputTokens
|
|
|
+ usage.CompletionTokens = simpleResponse.OutputTokens
|
|
|
+ usage.TotalTokens = simpleResponse.TotalTokens
|
|
|
+ }
|
|
|
+ }
|
|
|
+ _ = helper.StringData(c, data)
|
|
|
+ return true
|
|
|
+ })
|
|
|
+ } else {
|
|
|
+ common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
|
|
+ // 读取响应体到缓冲区
|
|
|
+ bodyBytes, err := io.ReadAll(resp.Body)
|
|
|
+ if err != nil {
|
|
|
+ logger.LogError(c, fmt.Sprintf("failed to read TTS response body: %v", err))
|
|
|
+ c.Writer.WriteHeaderNow()
|
|
|
+ return usage
|
|
|
+ }
|
|
|
+
|
|
|
+ // 写入响应到客户端
|
|
|
+ c.Writer.WriteHeaderNow()
|
|
|
+ _, err = c.Writer.Write(bodyBytes)
|
|
|
+ if err != nil {
|
|
|
+ logger.LogError(c, fmt.Sprintf("failed to write TTS response: %v", err))
|
|
|
+ }
|
|
|
+
|
|
|
+ // 计算音频时长并更新 usage
|
|
|
+ audioFormat := "mp3" // 默认格式
|
|
|
+ if audioReq, ok := info.Request.(*dto.AudioRequest); ok && audioReq.ResponseFormat != "" {
|
|
|
+ audioFormat = audioReq.ResponseFormat
|
|
|
+ }
|
|
|
+
|
|
|
+ var duration float64
|
|
|
+ var durationErr error
|
|
|
+
|
|
|
+ if audioFormat == "pcm" {
|
|
|
+ // PCM 格式没有文件头,根据 OpenAI TTS 的 PCM 参数计算时长
|
|
|
+ // 采样率: 24000 Hz, 位深度: 16-bit (2 bytes), 声道数: 1
|
|
|
+ const sampleRate = 24000
|
|
|
+ const bytesPerSample = 2
|
|
|
+ const channels = 1
|
|
|
+ duration = float64(len(bodyBytes)) / float64(sampleRate*bytesPerSample*channels)
|
|
|
+ } else {
|
|
|
+ ext := "." + audioFormat
|
|
|
+ reader := bytes.NewReader(bodyBytes)
|
|
|
+ duration, durationErr = common.GetAudioDuration(c.Request.Context(), reader, ext)
|
|
|
+ }
|
|
|
+
|
|
|
+ usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
|
|
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
|
+
|
|
|
+ if durationErr != nil {
|
|
|
+ logger.LogWarn(c, fmt.Sprintf("failed to get audio duration: %v", durationErr))
|
|
|
+ // 如果无法获取时长,则设置保底的 CompletionTokens,根据body大小计算
|
|
|
+ sizeInKB := float64(len(bodyBytes)) / 1000.0
|
|
|
+ estimatedTokens := int(math.Ceil(sizeInKB)) // 粗略估算每KB约等于1 token
|
|
|
+ usage.CompletionTokens = estimatedTokens
|
|
|
+ usage.CompletionTokenDetails.AudioTokens = estimatedTokens
|
|
|
+ } else if duration > 0 {
|
|
|
+ // 计算 token: ceil(duration) / 60.0 * 1000,即每分钟 1000 tokens
|
|
|
+ completionTokens := int(math.Round(math.Ceil(duration) / 60.0 * 1000))
|
|
|
+ usage.CompletionTokens = completionTokens
|
|
|
+ usage.CompletionTokenDetails.AudioTokens = completionTokens
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return usage
|
|
|
+}
|
|
|
+
|
|
|
+func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
|
|
|
+ defer service.CloseResponseBodyGracefully(resp)
|
|
|
+
|
|
|
+ responseBody, err := io.ReadAll(resp.Body)
|
|
|
+ if err != nil {
|
|
|
+ return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
+ // 写入新的 response body
|
|
|
+ service.IOCopyBytesGracefully(c, resp, responseBody)
|
|
|
+
|
|
|
+ var responseData struct {
|
|
|
+ Usage *dto.Usage `json:"usage"`
|
|
|
+ }
|
|
|
+ if err := common.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
|
|
|
+ if responseData.Usage.TotalTokens > 0 {
|
|
|
+ usage := responseData.Usage
|
|
|
+ if usage.PromptTokens == 0 {
|
|
|
+ usage.PromptTokens = usage.InputTokens
|
|
|
+ }
|
|
|
+ if usage.CompletionTokens == 0 {
|
|
|
+ usage.CompletionTokens = usage.OutputTokens
|
|
|
+ }
|
|
|
+ return nil, usage
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ usage := &dto.Usage{}
|
|
|
+ usage.PromptTokens = info.GetEstimatePromptTokens()
|
|
|
+ usage.CompletionTokens = 0
|
|
|
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
|
+ return nil, usage
|
|
|
+}
|