Просмотр исходного кода

Merge pull request #2067 from feitianbubu/pr/add-doubao-audio

新增支持豆包语音合成2.0功能
IcedTangerine 2 месяцев назад
Родитель
Сommit
1ec664a348

+ 65 - 3
relay/channel/volcengine/adaptor.go

@@ -37,8 +37,50 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
 }
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	if info.RelayMode != constant.RelayModeAudioSpeech {
+		return nil, errors.New("unsupported audio relay mode")
+	}
+
+	appID, token, err := parseVolcengineAuth(info.ApiKey)
+	if err != nil {
+		return nil, err
+	}
+
+	voiceType := mapVoiceType(request.Voice)
+	speedRatio := mapSpeedRatio(request.Speed)
+	encoding := mapEncoding(request.ResponseFormat)
+
+	c.Set("response_format", encoding)
+
+	volcRequest := VolcengineTTSRequest{
+		App: VolcengineTTSApp{
+			AppID:   appID,
+			Token:   token,
+			Cluster: "volcano_tts",
+		},
+		User: VolcengineTTSUser{
+			UID: "openai_relay_user",
+		},
+		Audio: VolcengineTTSAudio{
+			VoiceType:  voiceType,
+			Encoding:   encoding,
+			SpeedRatio: speedRatio,
+			Rate:       24000,
+		},
+		Request: VolcengineTTSReqInfo{
+			ReqID:     generateRequestID(),
+			Text:      request.Input,
+			Operation: "query",
+			Model:     info.OriginModelName,
+		},
+	}
+
+	jsonData, err := json.Marshal(volcRequest)
+	if err != nil {
+		return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
+	}
+
+	return bytes.NewReader(jsonData), nil
 }
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
@@ -190,7 +232,6 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	// 支持自定义域名,如果未设置则使用默认域名
 	baseUrl := info.ChannelBaseUrl
 	if baseUrl == "" {
 		baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
@@ -217,6 +258,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 			return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
 		case constant.RelayModeRerank:
 			return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
+		case constant.RelayModeAudioSpeech:
+			// 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口,否则走透传的New接口
+			if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
+				return "https://openspeech.bytedance.com/api/v1/tts", nil
+			}
+			return fmt.Sprintf("%s/v1/audio/speech", baseUrl), nil
 		default:
 		}
 	}
@@ -225,6 +272,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
+
+	if info.RelayMode == constant.RelayModeAudioSpeech {
+		parts := strings.Split(info.ApiKey, "|")
+		if len(parts) == 2 {
+			req.Set("Authorization", "Bearer;"+parts[1])
+		}
+		req.Set("Content-Type", "application/json")
+		return nil
+	}
+
 	req.Set("Authorization", "Bearer "+info.ApiKey)
 	return nil
 }
@@ -260,6 +317,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+	if info.RelayMode == constant.RelayModeAudioSpeech {
+		encoding := mapEncoding(c.GetString("response_format"))
+		return handleTTSResponse(c, resp, info, encoding)
+	}
+
 	adaptor := openai.Adaptor{}
 	usage, err = adaptor.DoResponse(c, resp, info)
 	return

+ 208 - 0
relay/channel/volcengine/tts.go

@@ -0,0 +1,208 @@
+package volcengine
+
+import (
+	"encoding/base64"
+	"encoding/json"
+	"errors"
+	"io"
+	"net/http"
+	"strings"
+
+	"github.com/QuantumNous/new-api/dto"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/types"
+	"github.com/gin-gonic/gin"
+	"github.com/google/uuid"
+)
+
+type VolcengineTTSRequest struct {
+	App     VolcengineTTSApp     `json:"app"`
+	User    VolcengineTTSUser    `json:"user"`
+	Audio   VolcengineTTSAudio   `json:"audio"`
+	Request VolcengineTTSReqInfo `json:"request"`
+}
+
+type VolcengineTTSApp struct {
+	AppID   string `json:"appid"`
+	Token   string `json:"token"`
+	Cluster string `json:"cluster"`
+}
+
+type VolcengineTTSUser struct {
+	UID string `json:"uid"`
+}
+
+type VolcengineTTSAudio struct {
+	VoiceType        string  `json:"voice_type"`
+	Encoding         string  `json:"encoding"`
+	SpeedRatio       float64 `json:"speed_ratio"`
+	Rate             int     `json:"rate"`
+	Bitrate          int     `json:"bitrate,omitempty"`
+	LoudnessRatio    float64 `json:"loudness_ratio,omitempty"`
+	EnableEmotion    bool    `json:"enable_emotion,omitempty"`
+	Emotion          string  `json:"emotion,omitempty"`
+	EmotionScale     float64 `json:"emotion_scale,omitempty"`
+	ExplicitLanguage string  `json:"explicit_language,omitempty"`
+	ContextLanguage  string  `json:"context_language,omitempty"`
+}
+
+type VolcengineTTSReqInfo struct {
+	ReqID           string                   `json:"reqid"`
+	Text            string                   `json:"text"`
+	Operation       string                   `json:"operation"`
+	Model           string                   `json:"model,omitempty"`
+	TextType        string                   `json:"text_type,omitempty"`
+	SilenceDuration float64                  `json:"silence_duration,omitempty"`
+	WithTimestamp   interface{}              `json:"with_timestamp,omitempty"`
+	ExtraParam      *VolcengineTTSExtraParam `json:"extra_param,omitempty"`
+}
+
+type VolcengineTTSExtraParam struct {
+	DisableMarkdownFilter      bool                      `json:"disable_markdown_filter,omitempty"`
+	EnableLatexTn              bool                      `json:"enable_latex_tn,omitempty"`
+	MuteCutThreshold           string                    `json:"mute_cut_threshold,omitempty"`
+	MuteCutRemainMs            string                    `json:"mute_cut_remain_ms,omitempty"`
+	DisableEmojiFilter         bool                      `json:"disable_emoji_filter,omitempty"`
+	UnsupportedCharRatioThresh float64                   `json:"unsupported_char_ratio_thresh,omitempty"`
+	AigcWatermark              bool                      `json:"aigc_watermark,omitempty"`
+	CacheConfig                *VolcengineTTSCacheConfig `json:"cache_config,omitempty"`
+}
+
+type VolcengineTTSCacheConfig struct {
+	TextType int  `json:"text_type,omitempty"`
+	UseCache bool `json:"use_cache,omitempty"`
+}
+
+type VolcengineTTSResponse struct {
+	ReqID    string                     `json:"reqid"`
+	Code     int                        `json:"code"`
+	Message  string                     `json:"message"`
+	Sequence int                        `json:"sequence"`
+	Data     string                     `json:"data"`
+	Addition *VolcengineTTSAdditionInfo `json:"addition,omitempty"`
+}
+
+type VolcengineTTSAdditionInfo struct {
+	Duration string `json:"duration"`
+}
+
+var openAIToVolcengineVoiceMap = map[string]string{
+	"alloy":   "zh_male_M392_conversation_wvae_bigtts",
+	"echo":    "zh_male_wenhao_mars_bigtts",
+	"fable":   "zh_female_tianmei_mars_bigtts",
+	"onyx":    "zh_male_zhibei_mars_bigtts",
+	"nova":    "zh_female_shuangkuaisisi_mars_bigtts",
+	"shimmer": "zh_female_cancan_mars_bigtts",
+}
+
+var responseFormatToEncodingMap = map[string]string{
+	"mp3":  "mp3",
+	"opus": "ogg_opus",
+	"aac":  "mp3",
+	"flac": "mp3",
+	"wav":  "wav",
+	"pcm":  "pcm",
+}
+
+func parseVolcengineAuth(apiKey string) (appID, token string, err error) {
+	parts := strings.Split(apiKey, "|")
+	if len(parts) != 2 {
+		return "", "", errors.New("invalid api key format, expected: appid|access_token")
+	}
+	return parts[0], parts[1], nil
+}
+
+func mapVoiceType(openAIVoice string) string {
+	if voice, ok := openAIToVolcengineVoiceMap[openAIVoice]; ok {
+		return voice
+	}
+	return openAIVoice
+}
+
+// [0.1,2],默认为 1,通常保留一位小数即可
+func mapSpeedRatio(speed float64) float64 {
+	if speed == 0 {
+		return 1.0
+	}
+	if speed < 0.1 {
+		return 0.1
+	}
+	if speed > 2.0 {
+		return 2.0
+	}
+	return speed
+}
+
+func mapEncoding(responseFormat string) string {
+	if encoding, ok := responseFormatToEncodingMap[responseFormat]; ok {
+		return encoding
+	}
+	return "mp3"
+}
+
+func getContentTypeByEncoding(encoding string) string {
+	contentTypeMap := map[string]string{
+		"mp3":      "audio/mpeg",
+		"ogg_opus": "audio/ogg",
+		"wav":      "audio/wav",
+		"pcm":      "audio/pcm",
+	}
+	if ct, ok := contentTypeMap[encoding]; ok {
+		return ct
+	}
+	return "application/octet-stream"
+}
+
+func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
+	body, readErr := io.ReadAll(resp.Body)
+	if readErr != nil {
+		return nil, types.NewErrorWithStatusCode(
+			errors.New("failed to read volcengine response"),
+			types.ErrorCodeReadResponseBodyFailed,
+			http.StatusInternalServerError,
+		)
+	}
+	defer resp.Body.Close()
+
+	var volcResp VolcengineTTSResponse
+	if unmarshalErr := json.Unmarshal(body, &volcResp); unmarshalErr != nil {
+		return nil, types.NewErrorWithStatusCode(
+			errors.New("failed to parse volcengine response"),
+			types.ErrorCodeBadResponseBody,
+			http.StatusInternalServerError,
+		)
+	}
+
+	if volcResp.Code != 3000 {
+		return nil, types.NewErrorWithStatusCode(
+			errors.New(volcResp.Message),
+			types.ErrorCodeBadResponse,
+			http.StatusBadRequest,
+		)
+	}
+
+	audioData, decodeErr := base64.StdEncoding.DecodeString(volcResp.Data)
+	if decodeErr != nil {
+		return nil, types.NewErrorWithStatusCode(
+			errors.New("failed to decode audio data"),
+			types.ErrorCodeBadResponseBody,
+			http.StatusInternalServerError,
+		)
+	}
+
+	contentType := getContentTypeByEncoding(encoding)
+	c.Header("Content-Type", contentType)
+	c.Data(http.StatusOK, contentType, audioData)
+
+	usage = &dto.Usage{
+		PromptTokens:     info.PromptTokens,
+		CompletionTokens: 0,
+		TotalTokens:      info.PromptTokens,
+	}
+
+	return usage, nil
+}
+
+func generateRequestID() string {
+	return uuid.New().String()
+}

+ 2 - 0
web/src/components/table/channels/modals/EditChannelModal.jsx

@@ -107,6 +107,8 @@ function type2secretPrompt(type) {
       return '按照如下格式输入:AppId|SecretId|SecretKey';
     case 33:
       return '按照如下格式输入:Ak|Sk|Region';
+    case 45:
+        return '请输入渠道对应的鉴权密钥, 豆包语音输入:AppId|AccessToken';
     case 50:
       return '按照如下格式输入: AccessKey|SecretKey, 如果上游是New API,则直接输ApiKey';
     case 51: