CaIon 2 лет назад
Родитель
Сommit
63cd3f05f2

+ 3 - 1
common/model-ratio.go

@@ -37,7 +37,9 @@ var ModelRatio = map[string]float64{
 	"text-davinci-003":          10,
 	"text-davinci-edit-001":     10,
 	"code-davinci-edit-001":     10,
-	"whisper-1":                 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+	"whisper-1":                 15,  // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+	"tts-1":                     7.5, // 1k characters -> $0.015
+	"tts-1-hd":                  15,  // 1k characters -> $0.03
 	"davinci":                   10,
 	"curie":                     10,
 	"babbage":                   10,

+ 9 - 0
common/utils.go

@@ -207,3 +207,12 @@ func String2Int(str string) int {
 	}
 	return num
 }
+
+func StringsContains(strs []string, str string) bool {
+	for _, s := range strs {
+		if s == str {
+			return true
+		}
+	}
+	return false
+}

+ 47 - 9
controller/relay-audio.go

@@ -11,10 +11,19 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/model"
+	"strings"
 )
 
+var availableVoices = []string{
+	"alloy",
+	"echo",
+	"fable",
+	"onyx",
+	"nova",
+	"shimmer",
+}
+
 func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
-	audioModel := "whisper-1"
 
 	tokenId := c.GetInt("token_id")
 	channelType := c.GetInt("channel")
@@ -22,8 +31,28 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	userId := c.GetInt("id")
 	group := c.GetString("group")
 
+	var audioRequest AudioRequest
+	err := common.UnmarshalBodyReusable(c, &audioRequest)
+	if err != nil {
+		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+	}
+
+	// request validation
+	if audioRequest.Model == "" {
+		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
+	}
+
+	if strings.HasPrefix(audioRequest.Model, "tts-1") {
+		if audioRequest.Voice == "" {
+			return errorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
+		}
+		if !common.StringsContains(availableVoices, audioRequest.Voice) {
+			return errorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
+		}
+	}
+
 	preConsumedTokens := common.PreConsumedQuota
-	modelRatio := common.GetModelRatio(audioModel)
+	modelRatio := common.GetModelRatio(audioRequest.Model)
 	groupRatio := common.GetGroupRatio(group)
 	ratio := modelRatio * groupRatio
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
@@ -58,8 +87,8 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 		if err != nil {
 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 		}
-		if modelMap[audioModel] != "" {
-			audioModel = modelMap[audioModel]
+		if modelMap[audioRequest.Model] != "" {
+			audioRequest.Model = modelMap[audioRequest.Model]
 		}
 	}
 
@@ -97,7 +126,12 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 
 	defer func(ctx context.Context) {
 		go func() {
-			quota := countTokenText(audioResponse.Text, audioModel)
+			var quota int
+			if strings.HasPrefix(audioRequest.Model, "tts-1") {
+				quota = countAudioToken(audioRequest.Input, audioRequest.Model)
+			} else {
+				quota = countAudioToken(audioResponse.Text, audioRequest.Model)
+			}
 			quotaDelta := quota - preConsumedQuota
 			err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
 			if err != nil {
@@ -110,7 +144,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 			if quota != 0 {
 				tokenName := c.GetString("token_name")
 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent, tokenId)
+				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioRequest.Model, tokenName, quota, logContent, tokenId)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
 				model.UpdateChannelUsedQuota(channelId, quota)
@@ -127,9 +161,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	if err != nil {
 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 	}
-	err = json.Unmarshal(responseBody, &audioResponse)
-	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	if strings.HasPrefix(audioRequest.Model, "tts-1") {
+
+	} else {
+		err = json.Unmarshal(responseBody, &audioResponse)
+		if err != nil {
+			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+		}
 	}
 
 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))

+ 9 - 0
controller/relay-utils.go

@@ -10,6 +10,7 @@ import (
 	"one-api/common"
 	"strconv"
 	"strings"
+	"unicode/utf8"
 )
 
 var stopFinishReason = "stop"
@@ -106,6 +107,14 @@ func countTokenInput(input any, model string) int {
 	return 0
 }
 
+func countAudioToken(text string, model string) int {
+	if strings.HasPrefix(model, "tts") {
+		return utf8.RuneCountInString(text)
+	} else {
+		return countTokenText(text, model)
+	}
+}
+
 func countTokenText(text string, model string) int {
 	tokenEncoder := getTokenEncoder(model)
 	return getTokenNum(tokenEncoder, text)

+ 6 - 0
controller/relay.go

@@ -70,6 +70,12 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
 	return input
 }
 
+type AudioRequest struct {
+	Model string `json:"model"`
+	Voice string `json:"voice"`
+	Input string `json:"input"`
+}
+
 type ChatRequest struct {
 	Model     string    `json:"model"`
 	Messages  []Message `json:"messages"`

+ 6 - 3
middleware/distributor.go

@@ -46,9 +46,8 @@ func Distribute() func(c *gin.Context) {
 				if modelRequest.Model == "" {
 					modelRequest.Model = "midjourney"
 				}
-			} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
-				err = common.UnmarshalBodyReusable(c, &modelRequest)
 			}
+			err = common.UnmarshalBodyReusable(c, &modelRequest)
 			if err != nil {
 				abortWithMessage(c, http.StatusBadRequest, "无效的请求")
 				return
@@ -70,7 +69,11 @@ func Distribute() func(c *gin.Context) {
 			}
 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 				if modelRequest.Model == "" {
-					modelRequest.Model = "whisper-1"
+					if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
+						modelRequest.Model = "tts-1"
+					} else {
+						modelRequest.Model = "whisper-1"
+					}
 				}
 			}
 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)

+ 1 - 0
router/relay-router.go

@@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
 		relayV1Router.POST("/audio/transcriptions", controller.Relay)
 		relayV1Router.POST("/audio/translations", controller.Relay)
+		relayV1Router.POST("/audio/speech", controller.Relay)
 		relayV1Router.GET("/files", controller.RelayNotImplemented)
 		relayV1Router.POST("/files", controller.RelayNotImplemented)
 		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)