Prechádzať zdrojové kódy

feat: doubao tts support streaming realtime audio

feitianbubu 2 mesiacov pred
rodič
commit
098e6e7f2b

+ 54 - 7
relay/channel/volcengine/adaptor.go

@@ -70,7 +70,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 		Request: VolcengineTTSReqInfo{
 			ReqID:     generateRequestID(),
 			Text:      request.Input,
-			Operation: "query",
+			Operation: "submit", // WebSocket uses "submit"
 			Model:     info.OriginModelName,
 		},
 	}
@@ -82,12 +82,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 		}
 	}
 
-	jsonData, err := json.Marshal(volcRequest)
-	if err != nil {
-		return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
-	}
+	// Store the request in context for WebSocket handler
+	c.Set("volcengine_tts_request", volcRequest)
 
-	return bytes.NewReader(jsonData), nil
+	// Return nil as WebSocket doesn't use traditional request body
+	return nil, nil
 }
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
@@ -268,7 +267,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		case constant.RelayModeAudioSpeech:
 			// 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口,否则走透传的New接口
 			if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
-				return "https://openspeech.bytedance.com/api/v1/tts", nil
+				return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil
 			}
 			return fmt.Sprintf("%s/v1/audio/speech", baseUrl), nil
 		default:
@@ -320,12 +319,60 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+	// For TTS with WebSocket, skip traditional HTTP request
+	if info.RelayMode == constant.RelayModeAudioSpeech {
+		baseUrl := info.ChannelBaseUrl
+		if baseUrl == "" {
+			baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
+		}
+		// Only use WebSocket for official Volcengine endpoint
+		if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
+			return nil, nil // WebSocket handling will be done in DoResponse
+		}
+	}
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
 	if info.RelayMode == constant.RelayModeAudioSpeech {
 		encoding := mapEncoding(c.GetString("response_format"))
+
+		// Check if this is WebSocket mode (resp will be nil for WebSocket)
+		if resp == nil {
+			// Get the WebSocket URL
+			requestURL, urlErr := a.GetRequestURL(info)
+			if urlErr != nil {
+				return nil, types.NewErrorWithStatusCode(
+					urlErr,
+					types.ErrorCodeBadRequestBody,
+					http.StatusInternalServerError,
+				)
+			}
+
+			// Retrieve the volcengine request from context
+			volcRequestInterface, exists := c.Get("volcengine_tts_request")
+			if !exists {
+				return nil, types.NewErrorWithStatusCode(
+					errors.New("volcengine TTS request not found in context"),
+					types.ErrorCodeBadRequestBody,
+					http.StatusInternalServerError,
+				)
+			}
+
+			volcRequest, ok := volcRequestInterface.(VolcengineTTSRequest)
+			if !ok {
+				return nil, types.NewErrorWithStatusCode(
+					errors.New("invalid volcengine TTS request type"),
+					types.ErrorCodeBadRequestBody,
+					http.StatusInternalServerError,
+				)
+			}
+
+			// Handle WebSocket streaming
+			return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding)
+		}
+
+		// Handle traditional HTTP response
 		return handleTTSResponse(c, resp, info, encoding)
 	}
 

+ 715 - 0
relay/channel/volcengine/protocols.go

@@ -0,0 +1,715 @@
+package volcengine
+
+import (
+	"bytes"
+	"encoding/binary"
+	"fmt"
+	"io"
+	"math"
+
+	"github.com/gorilla/websocket"
+)
+
+// EventType is the event number carried by a message; it identifies which
+// event the message represents.
+type EventType int32
+
+// MsgType is the message type, which determines how the message is
+// serialized with the protocol.
+type MsgType uint8
+
+// MsgTypeFlagBits holds the 4-bit message-type specific flags. The concrete
+// values are defined per usage scenario.
+type MsgTypeFlagBits uint8
+
+// VersionBits holds the 4-bit protocol version.
+type VersionBits uint8
+
+// HeaderSizeBits holds the 4-bit header size.
+type HeaderSizeBits uint8
+
+// SerializationBits holds the 4-bit serialization method.
+type SerializationBits uint8
+
+// CompressionBits holds the 4-bit compression method.
+type CompressionBits uint8
+
+// Message-type specific flags; they travel in the low four bits of the second
+// header byte.
+const (
+	// MsgTypeFlagNoSeq marks a non-terminal packet without a sequence number.
+	MsgTypeFlagNoSeq MsgTypeFlagBits = 0x0
+	// MsgTypeFlagPositiveSeq marks a non-terminal packet with sequence > 0.
+	MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0x1
+	// MsgTypeFlagLastNoSeq marks the last packet, without a sequence number.
+	MsgTypeFlagLastNoSeq MsgTypeFlagBits = 0x2
+	// MsgTypeFlagNegativeSeq marks the last packet, with sequence < 0.
+	MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0x3
+	// MsgTypeFlagWithEvent means the message carries an event number (int32).
+	MsgTypeFlagWithEvent MsgTypeFlagBits = 0x4
+)
+
+// Protocol versions, numbered from 1.
+const (
+	Version1 VersionBits = 1
+	Version2 VersionBits = 2
+	Version3 VersionBits = 3
+	Version4 VersionBits = 4
+)
+
+// Header sizes. The stored value counts 4-byte words, so HeaderSize4 (1) is a
+// 4-byte header and HeaderSize16 (4) is a 16-byte header.
+const (
+	HeaderSize4  HeaderSizeBits = 1
+	HeaderSize8  HeaderSizeBits = 2
+	HeaderSize12 HeaderSizeBits = 3
+	HeaderSize16 HeaderSizeBits = 4
+)
+
+// Payload serialization methods.
+const (
+	SerializationRaw    SerializationBits = 0x0
+	SerializationJSON   SerializationBits = 0x1
+	SerializationThrift SerializationBits = 0x3
+	SerializationCustom SerializationBits = 0xF
+)
+
+// Payload compression methods.
+const (
+	CompressionNone   CompressionBits = 0x0
+	CompressionGzip   CompressionBits = 0x1
+	CompressionCustom CompressionBits = 0xF
+)
+
+// Message types; they travel in the high four bits of the second header byte.
+const (
+	MsgTypeInvalid              MsgType = 0x0
+	MsgTypeFullClientRequest    MsgType = 0x1
+	MsgTypeAudioOnlyClient      MsgType = 0x2
+	MsgTypeFullServerResponse   MsgType = 0x9
+	MsgTypeAudioOnlyServer      MsgType = 0xB
+	MsgTypeFrontEndResultServer MsgType = 0xC
+	MsgTypeError                MsgType = 0xF
+
+	// MsgTypeServerACK shares its value with MsgTypeAudioOnlyServer.
+	MsgTypeServerACK = MsgTypeAudioOnlyServer
+)
+
+func (t MsgType) String() string {
+	switch t {
+	case MsgTypeFullClientRequest:
+		return "MsgType_FullClientRequest"
+	case MsgTypeAudioOnlyClient:
+		return "MsgType_AudioOnlyClient"
+	case MsgTypeFullServerResponse:
+		return "MsgType_FullServerResponse"
+	case MsgTypeAudioOnlyServer:
+		return "MsgType_AudioOnlyServer" // MsgTypeServerACK
+	case MsgTypeError:
+		return "MsgType_Error"
+	case MsgTypeFrontEndResultServer:
+		return "MsgType_FrontEndResultServer"
+	default:
+		return fmt.Sprintf("MsgType_(%d)", t)
+	}
+}
+
+// Event numbers, grouped by direction and feature. Several names are aliases
+// that share the value of another constant.
+const (
+	// EventType_None is the default event: for scenarios that do not use events
+	// or do not need to transmit one. Where events are used, non-zero values can
+	// serve to validate that an event is legitimate.
+	EventType_None EventType = 0
+
+	// 1 ~ 49: upstream connection events.
+	EventType_StartConnection  EventType = 1
+	EventType_FinishConnection EventType = 2
+
+	EventType_StartTask  = EventType_StartConnection  // alias
+	EventType_FinishTask = EventType_FinishConnection // alias
+
+	// 50 ~ 99: downstream connection events.
+	EventType_ConnectionStarted  EventType = 50 // connection established successfully
+	EventType_ConnectionFailed   EventType = 51 // connection failed (possibly an authentication failure)
+	EventType_ConnectionFinished EventType = 52 // connection ended
+
+	EventType_TaskStarted  = EventType_ConnectionStarted  // alias
+	EventType_TaskFailed   = EventType_ConnectionFailed   // alias
+	EventType_TaskFinished = EventType_ConnectionFinished // alias
+
+	// 100 ~ 149: upstream session events.
+	EventType_StartSession  EventType = 100
+	EventType_CancelSession EventType = 101
+	EventType_FinishSession EventType = 102
+
+	// 150 ~ 199: downstream session events.
+	EventType_SessionStarted  EventType = 150
+	EventType_SessionCanceled EventType = 151
+	EventType_SessionFinished EventType = 152
+	EventType_SessionFailed   EventType = 153
+
+	// Usage events.
+	EventType_UsageResponse EventType = 154
+
+	EventType_ChargeData = EventType_UsageResponse // alias
+
+	// 200 ~ 249: upstream general events.
+	EventType_TaskRequest  EventType = 200
+	EventType_UpdateConfig EventType = 201
+
+	// 250 ~ 299: downstream general events.
+	EventType_AudioMuted EventType = 250
+
+	// 300 ~ 349: upstream TTS events.
+	EventType_SayHello EventType = 300
+
+	// 350 ~ 399: downstream TTS events.
+	EventType_TTSSentenceStart     EventType = 350
+	EventType_TTSSentenceEnd       EventType = 351
+	EventType_TTSResponse          EventType = 352
+	EventType_TTSEnded             EventType = 359
+	EventType_PodcastRoundStart    EventType = 360
+	EventType_PodcastRoundResponse EventType = 361
+	EventType_PodcastRoundEnd      EventType = 362
+
+	// 450 ~ 499: downstream ASR events.
+	EventType_ASRInfo     EventType = 450
+	EventType_ASRResponse EventType = 451
+	EventType_ASREnded    EventType = 459
+
+	// 500 ~ 549: upstream dialogue events.
+	// EventType_ChatTTSText carries (ground-truth-alignment) text for speech synthesis.
+	EventType_ChatTTSText EventType = 500
+
+	// 550 ~ 599: downstream dialogue events.
+	EventType_ChatResponse EventType = 550
+	EventType_ChatEnded    EventType = 559
+
+	// 650 ~ 699: downstream dialogue events.
+	// Source (original) language subtitle events.
+	EventType_SourceSubtitleStart    EventType = 650
+	EventType_SourceSubtitleResponse EventType = 651
+	EventType_SourceSubtitleEnd      EventType = 652
+	// Target (translation) language subtitle events.
+	EventType_TranslationSubtitleStart    EventType = 653
+	EventType_TranslationSubtitleResponse EventType = 654
+	EventType_TranslationSubtitleEnd      EventType = 655
+)
+
+func (t EventType) String() string {
+	switch t {
+	case EventType_None:
+		return "EventType_None"
+	case EventType_StartConnection:
+		return "EventType_StartConnection"
+	case EventType_FinishConnection:
+		return "EventType_FinishConnection"
+	case EventType_ConnectionStarted:
+		return "EventType_ConnectionStarted"
+	case EventType_ConnectionFailed:
+		return "EventType_ConnectionFailed"
+	case EventType_ConnectionFinished:
+		return "EventType_ConnectionFinished"
+	case EventType_StartSession:
+		return "EventType_StartSession"
+	case EventType_CancelSession:
+		return "EventType_CancelSession"
+	case EventType_FinishSession:
+		return "EventType_FinishSession"
+	case EventType_SessionStarted:
+		return "EventType_SessionStarted"
+	case EventType_SessionCanceled:
+		return "EventType_SessionCanceled"
+	case EventType_SessionFinished:
+		return "EventType_SessionFinished"
+	case EventType_SessionFailed:
+		return "EventType_SessionFailed"
+	case EventType_UsageResponse:
+		return "EventType_UsageResponse"
+	case EventType_TaskRequest:
+		return "EventType_TaskRequest"
+	case EventType_UpdateConfig:
+		return "EventType_UpdateConfig"
+	case EventType_AudioMuted:
+		return "EventType_AudioMuted"
+	case EventType_SayHello:
+		return "EventType_SayHello"
+	case EventType_TTSSentenceStart:
+		return "EventType_TTSSentenceStart"
+	case EventType_TTSSentenceEnd:
+		return "EventType_TTSSentenceEnd"
+	case EventType_TTSResponse:
+		return "EventType_TTSResponse"
+	case EventType_TTSEnded:
+		return "EventType_TTSEnded"
+	case EventType_PodcastRoundStart:
+		return "EventType_PodcastRoundStart"
+	case EventType_PodcastRoundResponse:
+		return "EventType_PodcastRoundResponse"
+	case EventType_PodcastRoundEnd:
+		return "EventType_PodcastRoundEnd"
+	case EventType_ASRInfo:
+		return "EventType_ASRInfo"
+	case EventType_ASRResponse:
+		return "EventType_ASRResponse"
+	case EventType_ASREnded:
+		return "EventType_ASREnded"
+	case EventType_ChatTTSText:
+		return "EventType_ChatTTSText"
+	case EventType_ChatResponse:
+		return "EventType_ChatResponse"
+	case EventType_ChatEnded:
+		return "EventType_ChatEnded"
+	case EventType_SourceSubtitleStart:
+		return "EventType_SourceSubtitleStart"
+	case EventType_SourceSubtitleResponse:
+		return "EventType_SourceSubtitleResponse"
+	case EventType_SourceSubtitleEnd:
+		return "EventType_SourceSubtitleEnd"
+	case EventType_TranslationSubtitleStart:
+		return "EventType_TranslationSubtitleStart"
+	case EventType_TranslationSubtitleResponse:
+		return "EventType_TranslationSubtitleResponse"
+	case EventType_TranslationSubtitleEnd:
+		return "EventType_TranslationSubtitleEnd"
+	default:
+		return fmt.Sprintf("EventType_(%d)", t)
+	}
+}
+
+// 0                 1                 2                 3
+// | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// |    Version      |   Header Size   |     Msg Type    |      Flags      |
+// |   (4 bits)      |    (4 bits)     |     (4 bits)    |     (4 bits)    |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Serialization   |   Compression   |           Reserved                |
+// |   (4 bits)      |    (4 bits)     |           (8 bits)                |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// |                                                                       |
+// |                   Optional Header Extensions                          |
+// |                     (if Header Size > 1)                              |
+// |                                                                       |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// |                                                                       |
+// |                           Payload                                     |
+// |                      (variable length)                                |
+// |                                                                       |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+
+// Message is one protocol frame: the fixed header fields, the optional fields
+// selected by MsgType and MsgTypeFlag, and the payload.
+type Message struct {
+	// Fixed header fields; each one occupies four bits on the wire.
+	Version       VersionBits
+	HeaderSize    HeaderSizeBits
+	MsgType       MsgType
+	MsgTypeFlag   MsgTypeFlagBits
+	Serialization SerializationBits
+	Compression   CompressionBits
+
+	// Optional fields; which ones are encoded depends on MsgType, MsgTypeFlag
+	// and, for the two IDs, on EventType.
+	EventType            EventType
+	SessionID, ConnectID string
+	Sequence             int32
+	ErrorCode            uint32
+
+	// Payload is the message body, written with a uint32 length prefix.
+	Payload []byte
+}
+
+// NewMessageFromBytes decodes a complete frame into a Message. The message
+// type and flags are taken from the second byte of data; the remaining fields
+// are filled in by Unmarshal. Frames shorter than three bytes are rejected.
+func NewMessageFromBytes(data []byte) (*Message, error) {
+	if got := len(data); got < 3 {
+		return nil, fmt.Errorf("data too short: expected at least 3 bytes, got %d", got)
+	}
+
+	// High nibble: message type. Low nibble: type-specific flags.
+	msgType := MsgType(data[1] >> 4)
+	flag := MsgTypeFlagBits(data[1] & 0x0F)
+
+	msg, err := NewMessage(msgType, flag)
+	if err != nil {
+		return nil, err
+	}
+	if err = msg.Unmarshal(data); err != nil {
+		return nil, err
+	}
+	return msg, nil
+}
+
+// NewMessage returns a Message of the given type and flag with the default
+// header: version 1, a 4-byte header, JSON serialization and no compression.
+// The returned error is always nil.
+func NewMessage(msgType MsgType, flag MsgTypeFlagBits) (*Message, error) {
+	msg := new(Message)
+	msg.Version = Version1
+	msg.HeaderSize = HeaderSize4
+	msg.MsgType = msgType
+	msg.MsgTypeFlag = flag
+	msg.Serialization = SerializationJSON
+	msg.Compression = CompressionNone
+	return msg, nil
+}
+
+// String renders the message for logging. Audio messages show only the
+// payload size, error messages include the error code, and the sequence
+// number is shown when the flag says the message carries one.
+func (m *Message) String() string {
+	hasSeq := m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq
+
+	switch {
+	case m.MsgType == MsgTypeError:
+		return fmt.Sprintf("%s, %s, ErrorCode: %d, Payload: %s", m.MsgType, m.EventType, m.ErrorCode, string(m.Payload))
+	case m.MsgType == MsgTypeAudioOnlyServer || m.MsgType == MsgTypeAudioOnlyClient:
+		// Audio payloads are binary, so only their size is printed.
+		if hasSeq {
+			return fmt.Sprintf("%s, %s, Sequence: %d, PayloadSize: %d", m.MsgType, m.EventType, m.Sequence, len(m.Payload))
+		}
+		return fmt.Sprintf("%s, %s, PayloadSize: %d", m.MsgType, m.EventType, len(m.Payload))
+	case hasSeq:
+		return fmt.Sprintf("%s, %s, Sequence: %d, Payload: %s",
+			m.MsgType, m.EventType, m.Sequence, string(m.Payload))
+	default:
+		return fmt.Sprintf("%s, %s, Payload: %s", m.MsgType, m.EventType, string(m.Payload))
+	}
+}
+
+// Marshal encodes the message into its wire form: the header (three packed
+// bytes, zero-padded up to 4*HeaderSize bytes) followed by the optional
+// fields and the length-prefixed payload chosen by writers.
+func (m *Message) Marshal() ([]byte, error) {
+	fieldWriters, err := m.writers()
+	if err != nil {
+		return nil, err
+	}
+
+	// The three packed bytes are always emitted, even if HeaderSize declares
+	// fewer bytes than that.
+	headerLen := 3
+	if declared := 4 * int(m.HeaderSize); declared > headerLen {
+		headerLen = declared
+	}
+	frame := make([]byte, headerLen)
+	frame[0] = uint8(m.Version)<<4 | uint8(m.HeaderSize)
+	frame[1] = uint8(m.MsgType)<<4 | uint8(m.MsgTypeFlag)
+	frame[2] = uint8(m.Serialization)<<4 | uint8(m.Compression)
+
+	out := bytes.NewBuffer(frame)
+	for _, emit := range fieldWriters {
+		if err := emit(out); err != nil {
+			return nil, err
+		}
+	}
+	return out.Bytes(), nil
+}
+
+// Unmarshal decodes a complete frame into m.
+//
+// The message type and flags byte is skipped, not decoded: MsgType and
+// MsgTypeFlag must already be set on m (NewMessageFromBytes does this), because
+// they decide which optional fields are read. An error is returned when the
+// header is incomplete, a field cannot be read, the message type is
+// unsupported, or bytes remain after the payload.
+func (m *Message) Unmarshal(data []byte) error {
+	buf := bytes.NewBuffer(data)
+
+	versionAndHeaderSize, err := buf.ReadByte()
+	if err != nil {
+		return err
+	}
+	m.Version = VersionBits(versionAndHeaderSize >> 4)
+	m.HeaderSize = HeaderSizeBits(versionAndHeaderSize & 0b00001111)
+
+	// Message type and flags: consumed here, interpreted by the caller.
+	if _, err = buf.ReadByte(); err != nil {
+		return err
+	}
+
+	serializationCompression, err := buf.ReadByte()
+	if err != nil {
+		return err
+	}
+	// Serialization lives in the high nibble, so it has to be shifted down;
+	// masking alone would yield 0x10 for JSON instead of SerializationJSON.
+	m.Serialization = SerializationBits(serializationCompression >> 4)
+	m.Compression = CompressionBits(serializationCompression & 0b00001111)
+
+	// Skip the rest of the header (reserved byte and any extensions).
+	headerSize := 4 * int(m.HeaderSize)
+	readSize := 3
+	if paddingSize := headerSize - readSize; paddingSize > 0 {
+		if n := len(buf.Next(paddingSize)); n < paddingSize {
+			return fmt.Errorf("insufficient header bytes: expected %d, got %d", paddingSize, n)
+		}
+	}
+
+	readers, err := m.readers()
+	if err != nil {
+		return err
+	}
+	for _, read := range readers {
+		if err := read(buf); err != nil {
+			return err
+		}
+	}
+
+	// Everything must have been consumed; leftovers mean a malformed frame.
+	if rest := buf.Len(); rest > 0 {
+		return fmt.Errorf("unexpected data after message: %d trailing byte(s)", rest)
+	}
+
+	return nil
+}
+
+// writers returns, in wire order, the field encoders that apply to this
+// message: event number and session ID when the flag is MsgTypeFlagWithEvent,
+// then the sequence number or error code depending on the message type, and
+// always the payload last. Unsupported message types yield an error.
+func (m *Message) writers() (writers []func(*bytes.Buffer) error, _ error) {
+	withEvent := m.MsgTypeFlag == MsgTypeFlagWithEvent
+	withSeq := m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq
+
+	if withEvent {
+		writers = append(writers, m.writeEvent, m.writeSessionID)
+	}
+
+	switch m.MsgType {
+	case MsgTypeError:
+		writers = append(writers, m.writeErrorCode)
+	case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer,
+		MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer:
+		if withSeq {
+			writers = append(writers, m.writeSequence)
+		}
+	default:
+		return nil, fmt.Errorf("unsupported message type: %d", m.MsgType)
+	}
+
+	return append(writers, m.writePayload), nil
+}
+
+// writeEvent appends the event number as a big-endian 32-bit integer.
+func (m *Message) writeEvent(buf *bytes.Buffer) error {
+	event := m.EventType
+	return binary.Write(buf, binary.BigEndian, event)
+}
+
+// writeSessionID appends the session ID as a big-endian uint32 length followed
+// by the ID bytes. Nothing is written for StartConnection, FinishConnection,
+// ConnectionStarted and ConnectionFailed events.
+//
+// NOTE(review): readSessionID additionally skips EventType_ConnectionFinished;
+// confirm whether the two lists are meant to differ.
+func (m *Message) writeSessionID(buf *bytes.Buffer) error {
+	switch m.EventType {
+	case EventType_StartConnection, EventType_FinishConnection,
+		EventType_ConnectionStarted, EventType_ConnectionFailed:
+		return nil
+	}
+
+	size := len(m.SessionID)
+	// Compare as uint64: math.MaxUint32 does not fit in an int on 32-bit
+	// platforms, where a direct int comparison does not compile.
+	if uint64(size) > math.MaxUint32 {
+		return fmt.Errorf("session ID size (%d) exceeds max(uint32)", size)
+	}
+
+	if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil {
+		return err
+	}
+
+	buf.WriteString(m.SessionID)
+	return nil
+}
+
+// writeSequence appends the sequence number as a big-endian 32-bit integer.
+func (m *Message) writeSequence(buf *bytes.Buffer) error {
+	seq := m.Sequence
+	return binary.Write(buf, binary.BigEndian, seq)
+}
+
+// writeErrorCode appends the error code as a big-endian unsigned 32-bit integer.
+func (m *Message) writeErrorCode(buf *bytes.Buffer) error {
+	code := m.ErrorCode
+	return binary.Write(buf, binary.BigEndian, code)
+}
+
+// writePayload appends the payload as a big-endian uint32 length followed by
+// the payload bytes. A payload too large for a uint32 length is rejected.
+func (m *Message) writePayload(buf *bytes.Buffer) error {
+	size := len(m.Payload)
+	// Compare as uint64: math.MaxUint32 does not fit in an int on 32-bit
+	// platforms, where a direct int comparison does not compile.
+	if uint64(size) > math.MaxUint32 {
+		return fmt.Errorf("payload size (%d) exceeds max(uint32)", size)
+	}
+
+	if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil {
+		return err
+	}
+
+	buf.Write(m.Payload)
+	return nil
+}
+
+// readers returns, in decoding order, the field decoders that apply to this
+// message: the sequence number or error code depending on the message type,
+// then event number, session ID and connect ID when the flag is
+// MsgTypeFlagWithEvent, and always the payload last. Unsupported message types
+// yield an error.
+func (m *Message) readers() (readers []func(*bytes.Buffer) error, _ error) {
+	withEvent := m.MsgTypeFlag == MsgTypeFlagWithEvent
+	withSeq := m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq
+
+	switch m.MsgType {
+	case MsgTypeError:
+		readers = append(readers, m.readErrorCode)
+	case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer,
+		MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer:
+		if withSeq {
+			readers = append(readers, m.readSequence)
+		}
+	default:
+		return nil, fmt.Errorf("unsupported message type: %d", m.MsgType)
+	}
+
+	if withEvent {
+		readers = append(readers, m.readEvent, m.readSessionID, m.readConnectID)
+	}
+
+	return append(readers, m.readPayload), nil
+}
+
+// readEvent decodes the big-endian 32-bit event number into m.EventType.
+func (m *Message) readEvent(buf *bytes.Buffer) error {
+	dst := &m.EventType
+	return binary.Read(buf, binary.BigEndian, dst)
+}
+
+// readSessionID decodes the length-prefixed session ID into m.SessionID.
+// Connection-level events (StartConnection, FinishConnection,
+// ConnectionStarted, ConnectionFailed, ConnectionFinished) carry no session
+// ID, so nothing is read for them. A declared length larger than the
+// remaining data is reported as an error.
+func (m *Message) readSessionID(buf *bytes.Buffer) error {
+	switch m.EventType {
+	case EventType_StartConnection, EventType_FinishConnection,
+		EventType_ConnectionStarted, EventType_ConnectionFailed,
+		EventType_ConnectionFinished:
+		return nil
+	}
+
+	var size uint32
+	if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
+		return err
+	}
+
+	// buf.Next would silently hand back fewer bytes on a truncated frame.
+	if uint64(size) > uint64(buf.Len()) {
+		return fmt.Errorf("session ID truncated: declared %d bytes, %d available", size, buf.Len())
+	}
+
+	if size > 0 {
+		m.SessionID = string(buf.Next(int(size)))
+	}
+
+	return nil
+}
+
+// readConnectID decodes the length-prefixed connect ID into m.ConnectID. Only
+// ConnectionStarted, ConnectionFailed and ConnectionFinished events carry one;
+// for every other event nothing is read. A declared length larger than the
+// remaining data is reported as an error.
+func (m *Message) readConnectID(buf *bytes.Buffer) error {
+	switch m.EventType {
+	case EventType_ConnectionStarted, EventType_ConnectionFailed,
+		EventType_ConnectionFinished:
+	default:
+		return nil
+	}
+
+	var size uint32
+	if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
+		return err
+	}
+
+	// buf.Next would silently hand back fewer bytes on a truncated frame.
+	if uint64(size) > uint64(buf.Len()) {
+		return fmt.Errorf("connect ID truncated: declared %d bytes, %d available", size, buf.Len())
+	}
+
+	if size > 0 {
+		m.ConnectID = string(buf.Next(int(size)))
+	}
+
+	return nil
+}
+
+// readSequence decodes the big-endian 32-bit sequence number into m.Sequence.
+func (m *Message) readSequence(buf *bytes.Buffer) error {
+	dst := &m.Sequence
+	return binary.Read(buf, binary.BigEndian, dst)
+}
+
+// readErrorCode decodes the big-endian unsigned 32-bit error code into m.ErrorCode.
+func (m *Message) readErrorCode(buf *bytes.Buffer) error {
+	dst := &m.ErrorCode
+	return binary.Read(buf, binary.BigEndian, dst)
+}
+
+// readPayload decodes the length-prefixed payload into m.Payload. The payload
+// is not copied: it shares memory with the data held by buf. A declared
+// length larger than the remaining data is reported as an error instead of
+// yielding a silently shortened payload.
+func (m *Message) readPayload(buf *bytes.Buffer) error {
+	var size uint32
+	if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
+		return err
+	}
+
+	// buf.Next would silently hand back fewer bytes on a truncated frame.
+	if uint64(size) > uint64(buf.Len()) {
+		return fmt.Errorf("payload truncated: declared %d bytes, %d available", size, buf.Len())
+	}
+
+	if size > 0 {
+		m.Payload = buf.Next(int(size))
+	}
+
+	return nil
+}
+
+// ReceiveMessage reads the next WebSocket frame from conn and decodes it as a
+// protocol Message. Binary and text frames are accepted; any other frame type
+// is an error.
+func ReceiveMessage(conn *websocket.Conn) (*Message, error) {
+	frameType, frame, err := conn.ReadMessage()
+	if err != nil {
+		return nil, err
+	}
+
+	switch frameType {
+	case websocket.BinaryMessage, websocket.TextMessage:
+		return NewMessageFromBytes(frame)
+	default:
+		return nil, fmt.Errorf("unexpected Websocket message type: %d", frameType)
+	}
+}
+
+// WaitForEvent reads exactly one message from conn and returns it if both its
+// message type and event type match the expected ones. Despite the name it
+// does not skip unrelated messages: the first message that does not match is
+// reported as an error.
+func WaitForEvent(conn *websocket.Conn, msgType MsgType, eventType EventType) (*Message, error) {
+	msg, err := ReceiveMessage(conn)
+	if err != nil {
+		return nil, err
+	}
+	if msg.MsgType == msgType && msg.EventType == eventType {
+		return msg, nil
+	}
+	return nil, fmt.Errorf("unexpected message: %s", msg)
+}
+
+// sendMessage builds a message of the given type and flag, fills in the event,
+// session ID and payload, marshals it and writes it to conn as one binary
+// WebSocket frame. It is the shared body of the exported send helpers below.
+func sendMessage(conn *websocket.Conn, msgType MsgType, flag MsgTypeFlagBits, event EventType, sessionID string, payload []byte) error {
+	msg, err := NewMessage(msgType, flag)
+	if err != nil {
+		return err
+	}
+	msg.EventType = event
+	msg.SessionID = sessionID
+	msg.Payload = payload
+
+	frame, err := msg.Marshal()
+	if err != nil {
+		return err
+	}
+	return conn.WriteMessage(websocket.BinaryMessage, frame)
+}
+
+// sendEvent sends a full client request that carries an event number.
+func sendEvent(conn *websocket.Conn, event EventType, sessionID string, payload []byte) error {
+	return sendMessage(conn, MsgTypeFullClientRequest, MsgTypeFlagWithEvent, event, sessionID, payload)
+}
+
+// FullClientRequest sends payload as a full client request without sequence
+// number or event.
+func FullClientRequest(conn *websocket.Conn, payload []byte) error {
+	return sendMessage(conn, MsgTypeFullClientRequest, MsgTypeFlagNoSeq, EventType_None, "", payload)
+}
+
+// AudioOnlyClient sends payload as an audio-only client message with the
+// given flag.
+func AudioOnlyClient(conn *websocket.Conn, payload []byte, flag MsgTypeFlagBits) error {
+	return sendMessage(conn, MsgTypeAudioOnlyClient, flag, EventType_None, "", payload)
+}
+
+// StartConnection sends the StartConnection event with an empty JSON payload.
+func StartConnection(conn *websocket.Conn) error {
+	return sendEvent(conn, EventType_StartConnection, "", []byte("{}"))
+}
+
+// FinishConnection sends the FinishConnection event with an empty JSON payload.
+func FinishConnection(conn *websocket.Conn) error {
+	return sendEvent(conn, EventType_FinishConnection, "", []byte("{}"))
+}
+
+// StartSession sends the StartSession event for sessionID with the given payload.
+func StartSession(conn *websocket.Conn, payload []byte, sessionID string) error {
+	return sendEvent(conn, EventType_StartSession, sessionID, payload)
+}
+
+// FinishSession sends the FinishSession event for sessionID with an empty
+// JSON payload.
+func FinishSession(conn *websocket.Conn, sessionID string) error {
+	return sendEvent(conn, EventType_FinishSession, sessionID, []byte("{}"))
+}
+
+// CancelSession sends the CancelSession event for sessionID with an empty
+// JSON payload.
+func CancelSession(conn *websocket.Conn, sessionID string) error {
+	return sendEvent(conn, EventType_CancelSession, sessionID, []byte("{}"))
+}
+
+// TaskRequest sends the TaskRequest event for sessionID with the given payload.
+func TaskRequest(conn *websocket.Conn, payload []byte, sessionID string) error {
+	return sendEvent(conn, EventType_TaskRequest, sessionID, payload)
+}

+ 129 - 0
relay/channel/volcengine/tts.go

@@ -1,9 +1,11 @@
 package volcengine
 
 import (
+	"context"
 	"encoding/base64"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"io"
 	"net/http"
 	"strings"
@@ -13,6 +15,7 @@ import (
 	"github.com/QuantumNous/new-api/types"
 	"github.com/gin-gonic/gin"
 	"github.com/google/uuid"
+	"github.com/gorilla/websocket"
 )
 
 type VolcengineTTSRequest struct {
@@ -192,3 +195,129 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re
 func generateRequestID() string {
 	return uuid.New().String()
 }
+
+// handleTTSWebSocketResponse handles streaming TTS response via WebSocket
+func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
+	// Parse API key for auth
+	_, token, parseErr := parseVolcengineAuth(info.ApiKey)
+	if parseErr != nil {
+		return nil, types.NewErrorWithStatusCode(
+			parseErr,
+			types.ErrorCodeChannelInvalidKey,
+			http.StatusUnauthorized,
+		)
+	}
+
+	// Setup WebSocket headers
+	header := http.Header{}
+	header.Set("Authorization", fmt.Sprintf("Bearer;%s", token))
+
+	// Dial WebSocket connection
+	conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header)
+	if dialErr != nil {
+		if resp != nil {
+			return nil, types.NewErrorWithStatusCode(
+				fmt.Errorf("failed to connect to websocket: %w, status: %d", dialErr, resp.StatusCode),
+				types.ErrorCodeBadResponseStatusCode,
+				http.StatusBadGateway,
+			)
+		}
+		return nil, types.NewErrorWithStatusCode(
+			fmt.Errorf("failed to connect to websocket: %w", dialErr),
+			types.ErrorCodeBadResponseStatusCode,
+			http.StatusBadGateway,
+		)
+	}
+	defer conn.Close()
+
+	// Update request operation to "submit" for WebSocket
+	volcRequest.Request.Operation = "submit"
+
+	// Marshal request payload
+	payload, marshalErr := json.Marshal(volcRequest)
+	if marshalErr != nil {
+		return nil, types.NewErrorWithStatusCode(
+			fmt.Errorf("failed to marshal request: %w", marshalErr),
+			types.ErrorCodeBadRequestBody,
+			http.StatusInternalServerError,
+		)
+	}
+
+	// Send full client request
+	if sendErr := FullClientRequest(conn, payload); sendErr != nil {
+		return nil, types.NewErrorWithStatusCode(
+			fmt.Errorf("failed to send request: %w", sendErr),
+			types.ErrorCodeBadRequestBody,
+			http.StatusInternalServerError,
+		)
+	}
+
+	// Set response headers
+	contentType := getContentTypeByEncoding(encoding)
+	c.Header("Content-Type", contentType)
+	c.Header("Transfer-Encoding", "chunked")
+
+	// Stream audio data
+	var audioBuffer []byte
+	for {
+		msg, recvErr := ReceiveMessage(conn)
+		if recvErr != nil {
+			if websocket.IsCloseError(recvErr, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
+				break
+			}
+			return nil, types.NewErrorWithStatusCode(
+				fmt.Errorf("failed to receive message: %w", recvErr),
+				types.ErrorCodeBadResponse,
+				http.StatusInternalServerError,
+			)
+		}
+
+		switch msg.MsgType {
+		case MsgTypeError:
+			return nil, types.NewErrorWithStatusCode(
+				fmt.Errorf("received error from server: code=%d, %s", msg.ErrorCode, string(msg.Payload)),
+				types.ErrorCodeBadResponse,
+				http.StatusBadRequest,
+			)
+		case MsgTypeFrontEndResultServer:
+			// Metadata response, can be logged or processed
+			continue
+		case MsgTypeAudioOnlyServer:
+			// Stream audio chunk to client
+			if len(msg.Payload) > 0 {
+				audioBuffer = append(audioBuffer, msg.Payload...)
+				if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
+					return nil, types.NewErrorWithStatusCode(
+						fmt.Errorf("failed to write audio data: %w", writeErr),
+						types.ErrorCodeBadResponse,
+						http.StatusInternalServerError,
+					)
+				}
+				c.Writer.Flush()
+			}
+
+			// Check if this is the last packet (negative sequence)
+			if msg.Sequence < 0 {
+				c.Status(http.StatusOK)
+				usage = &dto.Usage{
+					PromptTokens:     info.PromptTokens,
+					CompletionTokens: 0,
+					TotalTokens:      info.PromptTokens,
+				}
+				return usage, nil
+			}
+		default:
+			// Unknown message type, log and continue
+			continue
+		}
+	}
+
+	// If we reach here, connection closed without final packet
+	c.Status(http.StatusOK)
+	usage = &dto.Usage{
+		PromptTokens:     info.PromptTokens,
+		CompletionTokens: 0,
+		TotalTokens:      info.PromptTokens,
+	}
+	return usage, nil
+}