|
|
@@ -5,7 +5,10 @@ import (
|
|
|
"bytes"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
+ "github.com/pkg/errors"
|
|
|
"io"
|
|
|
+ "math"
|
|
|
+ "mime/multipart"
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
"one-api/constant"
|
|
|
@@ -13,6 +16,7 @@ import (
|
|
|
relaycommon "one-api/relay/common"
|
|
|
relayconstant "one-api/relay/constant"
|
|
|
"one-api/service"
|
|
|
+ "os"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
@@ -316,6 +320,11 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|
|
}
|
|
|
|
|
|
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
+ // count tokens by audio file duration
|
|
|
+ audioTokens, err := countAudioTokens(c)
|
|
|
+ if err != nil {
|
|
|
+ return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
|
if err != nil {
|
|
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
|
@@ -340,70 +349,52 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|
|
}
|
|
|
resp.Body.Close()
|
|
|
|
|
|
- var text string
|
|
|
- switch responseFormat {
|
|
|
- case "json":
|
|
|
- text, err = getTextFromJSON(responseBody)
|
|
|
- case "text":
|
|
|
- text, err = getTextFromText(responseBody)
|
|
|
- case "srt":
|
|
|
- text, err = getTextFromSRT(responseBody)
|
|
|
- case "verbose_json":
|
|
|
- text, err = getTextFromVerboseJSON(responseBody)
|
|
|
- case "vtt":
|
|
|
- text, err = getTextFromVTT(responseBody)
|
|
|
- }
|
|
|
-
|
|
|
usage := &dto.Usage{}
|
|
|
- usage.PromptTokens = info.PromptTokens
|
|
|
- usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
|
|
|
+ usage.PromptTokens = audioTokens
|
|
|
+ usage.CompletionTokens = 0
|
|
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
|
return nil, usage
|
|
|
}
|
|
|
|
|
|
-func getTextFromVTT(body []byte) (string, error) {
|
|
|
- return getTextFromSRT(body)
|
|
|
-}
|
|
|
+func countAudioTokens(c *gin.Context) (int, error) {
|
|
|
+ body, err := common.GetRequestBody(c)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.WithStack(err)
|
|
|
+ }
|
|
|
|
|
|
-func getTextFromVerboseJSON(body []byte) (string, error) {
|
|
|
- var whisperResponse dto.WhisperVerboseJSONResponse
|
|
|
- if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
|
|
- return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
|
|
+ var reqBody struct {
|
|
|
+ File *multipart.FileHeader `form:"file" binding:"required"`
|
|
|
+ }
|
|
|
+ c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
|
|
+ if err = c.ShouldBind(&reqBody); err != nil {
|
|
|
+ return 0, errors.WithStack(err)
|
|
|
}
|
|
|
- return whisperResponse.Text, nil
|
|
|
-}
|
|
|
|
|
|
-func getTextFromSRT(body []byte) (string, error) {
|
|
|
- scanner := bufio.NewScanner(strings.NewReader(string(body)))
|
|
|
- var builder strings.Builder
|
|
|
- var textLine bool
|
|
|
- for scanner.Scan() {
|
|
|
- line := scanner.Text()
|
|
|
- if textLine {
|
|
|
- builder.WriteString(line)
|
|
|
- textLine = false
|
|
|
- continue
|
|
|
- } else if strings.Contains(line, "-->") {
|
|
|
- textLine = true
|
|
|
- continue
|
|
|
- }
|
|
|
+ reqFp, err := reqBody.File.Open()
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.WithStack(err)
|
|
|
}
|
|
|
- if err := scanner.Err(); err != nil {
|
|
|
- return "", err
|
|
|
+
|
|
|
+ tmpFp, err := os.CreateTemp("", "audio-*")
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.WithStack(err)
|
|
|
}
|
|
|
- return builder.String(), nil
|
|
|
-}
|
|
|
+ defer os.Remove(tmpFp.Name())
|
|
|
|
|
|
-func getTextFromText(body []byte) (string, error) {
|
|
|
- return strings.TrimSuffix(string(body), "\n"), nil
|
|
|
-}
|
|
|
+ _, err = io.Copy(tmpFp, reqFp)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.WithStack(err)
|
|
|
+ }
|
|
|
+ if err = tmpFp.Close(); err != nil {
|
|
|
+ return 0, errors.WithStack(err)
|
|
|
+ }
|
|
|
|
|
|
-func getTextFromJSON(body []byte) (string, error) {
|
|
|
- var whisperResponse dto.AudioResponse
|
|
|
- if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
|
|
- return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
|
|
+ duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.WithStack(err)
|
|
|
}
|
|
|
- return whisperResponse.Text, nil
|
|
|
+
|
|
|
+ return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
|
|
|
}
|
|
|
|
|
|
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
|