Преглед изворни кода

Merge pull request #705 from maranello-o/main

fix: incorrect whisper audio usage
Calcium-Ion пре 11 месеци
родитељ
комит
7f8112a325
3 измењених фајлова са 76 додато и 52 уклоњено
  1. 1 1
      Dockerfile
  2. 33 0
      common/utils.go
  3. 42 51
      relay/channel/openai/relay-openai.go

+ 1 - 1
Dockerfile

@@ -24,7 +24,7 @@ FROM alpine
 
 RUN apk update \
     && apk upgrade \
-    && apk add --no-cache ca-certificates tzdata \
+    && apk add --no-cache ca-certificates tzdata ffmpeg\
     && update-ca-certificates 2>/dev/null || true
 
 COPY --from=builder2 /build/one-api /

+ 33 - 0
common/utils.go

@@ -1,14 +1,19 @@
 package common
 
 import (
+	"bytes"
+	"context"
 	crand "crypto/rand"
 	"encoding/base64"
 	"fmt"
+	"github.com/pkg/errors"
 	"html/template"
+	"io"
 	"log"
 	"math/big"
 	"math/rand"
 	"net"
+	"os"
 	"os/exec"
 	"runtime"
 	"strconv"
@@ -207,3 +212,31 @@ func RandomSleep() {
 	// Sleep for 0-3000 ms
 	time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
 }
+
+// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
+func SaveTmpFile(filename string, data io.Reader) (string, error) {
+	f, err := os.CreateTemp(os.TempDir(), filename)
+	if err != nil {
+		return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
+	}
+	defer f.Close()
+
+	_, err = io.Copy(f, data)
+	if err != nil {
+		return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
+	}
+
+	return f.Name(), nil
+}
+
+// GetAudioDuration returns the duration of an audio file in seconds.
+func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
+	// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
+	c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
+	output, err := c.Output()
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to get audio duration")
+	}
+
+	return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
+}

+ 42 - 51
relay/channel/openai/relay-openai.go

@@ -5,7 +5,10 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
+	"github.com/pkg/errors"
 	"io"
+	"math"
+	"mime/multipart"
 	"net/http"
 	"one-api/common"
 	"one-api/constant"
@@ -13,6 +16,7 @@ import (
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
+	"os"
 	"strings"
 	"sync"
 	"time"
@@ -316,6 +320,11 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 }
 
 func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	// count tokens by audio file duration
+	audioTokens, err := countAudioTokens(c)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
+	}
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -340,70 +349,52 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	}
 	resp.Body.Close()
 
-	var text string
-	switch responseFormat {
-	case "json":
-		text, err = getTextFromJSON(responseBody)
-	case "text":
-		text, err = getTextFromText(responseBody)
-	case "srt":
-		text, err = getTextFromSRT(responseBody)
-	case "verbose_json":
-		text, err = getTextFromVerboseJSON(responseBody)
-	case "vtt":
-		text, err = getTextFromVTT(responseBody)
-	}
-
 	usage := &dto.Usage{}
-	usage.PromptTokens = info.PromptTokens
-	usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
+	usage.PromptTokens = audioTokens
+	usage.CompletionTokens = 0
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return nil, usage
 }
 
-func getTextFromVTT(body []byte) (string, error) {
-	return getTextFromSRT(body)
-}
+func countAudioTokens(c *gin.Context) (int, error) {
+	body, err := common.GetRequestBody(c)
+	if err != nil {
+		return 0, errors.WithStack(err)
+	}
 
-func getTextFromVerboseJSON(body []byte) (string, error) {
-	var whisperResponse dto.WhisperVerboseJSONResponse
-	if err := json.Unmarshal(body, &whisperResponse); err != nil {
-		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
+	var reqBody struct {
+		File *multipart.FileHeader `form:"file" binding:"required"`
+	}
+	c.Request.Body = io.NopCloser(bytes.NewReader(body))
+	if err = c.ShouldBind(&reqBody); err != nil {
+		return 0, errors.WithStack(err)
 	}
-	return whisperResponse.Text, nil
-}
 
-func getTextFromSRT(body []byte) (string, error) {
-	scanner := bufio.NewScanner(strings.NewReader(string(body)))
-	var builder strings.Builder
-	var textLine bool
-	for scanner.Scan() {
-		line := scanner.Text()
-		if textLine {
-			builder.WriteString(line)
-			textLine = false
-			continue
-		} else if strings.Contains(line, "-->") {
-			textLine = true
-			continue
-		}
+	reqFp, err := reqBody.File.Open()
+	if err != nil {
+		return 0, errors.WithStack(err)
 	}
-	if err := scanner.Err(); err != nil {
-		return "", err
+
+	tmpFp, err := os.CreateTemp("", "audio-*")
+	if err != nil {
+		return 0, errors.WithStack(err)
 	}
-	return builder.String(), nil
-}
+	defer os.Remove(tmpFp.Name())
 
-func getTextFromText(body []byte) (string, error) {
-	return strings.TrimSuffix(string(body), "\n"), nil
-}
+	_, err = io.Copy(tmpFp, reqFp)
+	if err != nil {
+		return 0, errors.WithStack(err)
+	}
+	if err = tmpFp.Close(); err != nil {
+		return 0, errors.WithStack(err)
+	}
 
-func getTextFromJSON(body []byte) (string, error) {
-	var whisperResponse dto.AudioResponse
-	if err := json.Unmarshal(body, &whisperResponse); err != nil {
-		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
+	duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
+	if err != nil {
+		return 0, errors.WithStack(err)
 	}
-	return whisperResponse.Text, nil
+
+	return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
 }
 
 func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {