Selaa lähdekoodia

feat: doubao tts add is stream check

feitianbubu 2 kuukautta sitten
vanhempi
sitoutus
431b3a84f6
2 muutettua tiedostoa jossa 36 lisäystä ja 33 poistoa
  1. 35 24
      relay/channel/volcengine/adaptor.go
  2. 1 9
      relay/channel/volcengine/tts.go

+ 35 - 24
relay/channel/volcengine/adaptor.go

@@ -23,6 +23,12 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
+const (
+	// Context keys for passing data between methods
+	contextKeyTTSRequest     = "volcengine_tts_request"
+	contextKeyResponseFormat = "response_format"
+)
+
 type Adaptor struct {
 }
 
@@ -50,7 +56,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 	speedRatio := request.Speed
 	encoding := mapEncoding(request.ResponseFormat)
 
-	c.Set("response_format", encoding)
+	c.Set(contextKeyResponseFormat, encoding)
 
 	volcRequest := VolcengineTTSRequest{
 		App: VolcengineTTSApp{
@@ -70,7 +76,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 		Request: VolcengineTTSReqInfo{
 			ReqID:     generateRequestID(),
 			Text:      request.Input,
-			Operation: "submit", // WebSocket uses "submit"
+			Operation: "submit", // default WebSocket uses "submit"
 			Model:     info.OriginModelName,
 		},
 	}
@@ -83,10 +89,20 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 	}
 
 	// Store the request in context for WebSocket handler
-	c.Set("volcengine_tts_request", volcRequest)
+	c.Set(contextKeyTTSRequest, volcRequest)
+	// https://www.volcengine.com/docs/6561/1257584
+	// operation must be set to "submit" for the response to be streamed
+	if volcRequest.Request.Operation == "submit" {
+		info.IsStream = true
+	}
 
 	// Also return the JSON-encoded request as the body; DoRequest sends it over HTTP when the WebSocket path is not taken
-	return nil, nil
+	jsonData, err := json.Marshal(volcRequest)
+	if err != nil {
+		return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
+	}
+
+	return bytes.NewReader(jsonData), nil
 }
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
@@ -327,7 +343,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 		}
 		// Only use WebSocket for official Volcengine endpoint
 		if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
-			return nil, nil // WebSocket handling will be done in DoResponse
+			if info.IsStream {
+				return nil, nil
+			}
 		}
 	}
 	return channel.DoApiRequest(a, c, info, requestBody)
@@ -335,22 +353,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
 	if info.RelayMode == constant.RelayModeAudioSpeech {
-		encoding := mapEncoding(c.GetString("response_format"))
-
-		// Check if this is WebSocket mode (resp will be nil for WebSocket)
-		if resp == nil {
-			// Get the WebSocket URL
-			requestURL, urlErr := a.GetRequestURL(info)
-			if urlErr != nil {
-				return nil, types.NewErrorWithStatusCode(
-					urlErr,
-					types.ErrorCodeBadRequestBody,
-					http.StatusInternalServerError,
-				)
-			}
-
-			// Retrieve the volcengine request from context
-			volcRequestInterface, exists := c.Get("volcengine_tts_request")
+		encoding := mapEncoding(c.GetString(contextKeyResponseFormat))
+		if info.IsStream {
+			volcRequestInterface, exists := c.Get(contextKeyTTSRequest)
 			if !exists {
 				return nil, types.NewErrorWithStatusCode(
 					errors.New("volcengine TTS request not found in context"),
@@ -368,11 +373,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 				)
 			}
 
-			// Handle WebSocket streaming
+			// Get the WebSocket URL
+			requestURL, urlErr := a.GetRequestURL(info)
+			if urlErr != nil {
+				return nil, types.NewErrorWithStatusCode(
+					urlErr,
+					types.ErrorCodeBadRequestBody,
+					http.StatusInternalServerError,
+				)
+			}
 			return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding)
 		}
-
-		// Handle traditional HTTP response
 		return handleTTSResponse(c, resp, info, encoding)
 	}
 

+ 1 - 9
relay/channel/volcengine/tts.go

@@ -230,10 +230,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
 	}
 	defer conn.Close()
 
-	// Update request operation to "submit" for WebSocket
-	volcRequest.Request.Operation = "submit"
-
-	// Marshal request payload
 	payload, marshalErr := json.Marshal(volcRequest)
 	if marshalErr != nil {
 		return nil, types.NewErrorWithStatusCode(
@@ -280,10 +276,8 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
 				http.StatusBadRequest,
 			)
 		case MsgTypeFrontEndResultServer:
-			// Metadata response, can be logged or processed
 			continue
 		case MsgTypeAudioOnlyServer:
-			// Stream audio chunk to client
 			if len(msg.Payload) > 0 {
 				audioBuffer = append(audioBuffer, msg.Payload...)
 				if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
@@ -293,10 +287,10 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
 						http.StatusInternalServerError,
 					)
 				}
+				//logger.Infof("write audio chunk size: %d", len(msg.Payload))
 				c.Writer.Flush()
 			}
 
-			// Check if this is the last packet (negative sequence)
 			if msg.Sequence < 0 {
 				c.Status(http.StatusOK)
 				usage = &dto.Usage{
@@ -307,12 +301,10 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
 				return usage, nil
 			}
 		default:
-			// Unknown message type, log and continue
 			continue
 		}
 	}
 
-	// If we reach here, connection closed without final packet
 	c.Status(http.StatusOK)
 	usage = &dto.Usage{
 		PromptTokens:     info.PromptTokens,