protocols.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. package volcengine
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "io"
  7. "math"
  8. "github.com/gorilla/websocket"
  9. )
  10. type (
  11. EventType int32
  12. MsgType uint8
  13. MsgTypeFlagBits uint8
  14. VersionBits uint8
  15. HeaderSizeBits uint8
  16. SerializationBits uint8
  17. CompressionBits uint8
  18. )
  19. const (
  20. MsgTypeFlagNoSeq MsgTypeFlagBits = 0
  21. MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0b1
  22. MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0b11
  23. MsgTypeFlagWithEvent MsgTypeFlagBits = 0b100
  24. )
  25. const (
  26. Version1 VersionBits = iota + 1
  27. )
  28. const (
  29. HeaderSize4 HeaderSizeBits = iota + 1
  30. )
  31. const (
  32. SerializationJSON SerializationBits = 0b1
  33. )
  34. const (
  35. CompressionNone CompressionBits = 0
  36. )
  37. const (
  38. MsgTypeFullClientRequest MsgType = 0b1
  39. MsgTypeAudioOnlyClient MsgType = 0b10
  40. MsgTypeFullServerResponse MsgType = 0b1001
  41. MsgTypeAudioOnlyServer MsgType = 0b1011
  42. MsgTypeFrontEndResultServer MsgType = 0b1100
  43. MsgTypeError MsgType = 0b1111
  44. )
  45. func (t MsgType) String() string {
  46. switch t {
  47. case MsgTypeFullClientRequest:
  48. return "MsgType_FullClientRequest"
  49. case MsgTypeAudioOnlyClient:
  50. return "MsgType_AudioOnlyClient"
  51. case MsgTypeFullServerResponse:
  52. return "MsgType_FullServerResponse"
  53. case MsgTypeAudioOnlyServer:
  54. return "MsgType_AudioOnlyServer"
  55. case MsgTypeError:
  56. return "MsgType_Error"
  57. case MsgTypeFrontEndResultServer:
  58. return "MsgType_FrontEndResultServer"
  59. default:
  60. return fmt.Sprintf("MsgType_(%d)", t)
  61. }
  62. }
  63. const (
  64. EventType_None EventType = 0
  65. EventType_StartConnection EventType = 1
  66. EventType_FinishConnection EventType = 2
  67. EventType_ConnectionStarted EventType = 50
  68. EventType_ConnectionFailed EventType = 51
  69. EventType_ConnectionFinished EventType = 52
  70. EventType_StartSession EventType = 100
  71. EventType_CancelSession EventType = 101
  72. EventType_FinishSession EventType = 102
  73. EventType_SessionStarted EventType = 150
  74. EventType_SessionCanceled EventType = 151
  75. EventType_SessionFinished EventType = 152
  76. EventType_SessionFailed EventType = 153
  77. EventType_UsageResponse EventType = 154
  78. EventType_TaskRequest EventType = 200
  79. EventType_UpdateConfig EventType = 201
  80. EventType_AudioMuted EventType = 250
  81. EventType_SayHello EventType = 300
  82. EventType_TTSSentenceStart EventType = 350
  83. EventType_TTSSentenceEnd EventType = 351
  84. EventType_TTSResponse EventType = 352
  85. EventType_TTSEnded EventType = 359
  86. EventType_PodcastRoundStart EventType = 360
  87. EventType_PodcastRoundResponse EventType = 361
  88. EventType_PodcastRoundEnd EventType = 362
  89. EventType_ASRInfo EventType = 450
  90. EventType_ASRResponse EventType = 451
  91. EventType_ASREnded EventType = 459
  92. EventType_ChatTTSText EventType = 500
  93. EventType_ChatResponse EventType = 550
  94. EventType_ChatEnded EventType = 559
  95. EventType_SourceSubtitleStart EventType = 650
  96. EventType_SourceSubtitleResponse EventType = 651
  97. EventType_SourceSubtitleEnd EventType = 652
  98. EventType_TranslationSubtitleStart EventType = 653
  99. EventType_TranslationSubtitleResponse EventType = 654
  100. EventType_TranslationSubtitleEnd EventType = 655
  101. )
  102. func (t EventType) String() string {
  103. switch t {
  104. case EventType_None:
  105. return "EventType_None"
  106. case EventType_StartConnection:
  107. return "EventType_StartConnection"
  108. case EventType_FinishConnection:
  109. return "EventType_FinishConnection"
  110. case EventType_ConnectionStarted:
  111. return "EventType_ConnectionStarted"
  112. case EventType_ConnectionFailed:
  113. return "EventType_ConnectionFailed"
  114. case EventType_ConnectionFinished:
  115. return "EventType_ConnectionFinished"
  116. case EventType_StartSession:
  117. return "EventType_StartSession"
  118. case EventType_CancelSession:
  119. return "EventType_CancelSession"
  120. case EventType_FinishSession:
  121. return "EventType_FinishSession"
  122. case EventType_SessionStarted:
  123. return "EventType_SessionStarted"
  124. case EventType_SessionCanceled:
  125. return "EventType_SessionCanceled"
  126. case EventType_SessionFinished:
  127. return "EventType_SessionFinished"
  128. case EventType_SessionFailed:
  129. return "EventType_SessionFailed"
  130. case EventType_UsageResponse:
  131. return "EventType_UsageResponse"
  132. case EventType_TaskRequest:
  133. return "EventType_TaskRequest"
  134. case EventType_UpdateConfig:
  135. return "EventType_UpdateConfig"
  136. case EventType_AudioMuted:
  137. return "EventType_AudioMuted"
  138. case EventType_SayHello:
  139. return "EventType_SayHello"
  140. case EventType_TTSSentenceStart:
  141. return "EventType_TTSSentenceStart"
  142. case EventType_TTSSentenceEnd:
  143. return "EventType_TTSSentenceEnd"
  144. case EventType_TTSResponse:
  145. return "EventType_TTSResponse"
  146. case EventType_TTSEnded:
  147. return "EventType_TTSEnded"
  148. case EventType_PodcastRoundStart:
  149. return "EventType_PodcastRoundStart"
  150. case EventType_PodcastRoundResponse:
  151. return "EventType_PodcastRoundResponse"
  152. case EventType_PodcastRoundEnd:
  153. return "EventType_PodcastRoundEnd"
  154. case EventType_ASRInfo:
  155. return "EventType_ASRInfo"
  156. case EventType_ASRResponse:
  157. return "EventType_ASRResponse"
  158. case EventType_ASREnded:
  159. return "EventType_ASREnded"
  160. case EventType_ChatTTSText:
  161. return "EventType_ChatTTSText"
  162. case EventType_ChatResponse:
  163. return "EventType_ChatResponse"
  164. case EventType_ChatEnded:
  165. return "EventType_ChatEnded"
  166. case EventType_SourceSubtitleStart:
  167. return "EventType_SourceSubtitleStart"
  168. case EventType_SourceSubtitleResponse:
  169. return "EventType_SourceSubtitleResponse"
  170. case EventType_SourceSubtitleEnd:
  171. return "EventType_SourceSubtitleEnd"
  172. case EventType_TranslationSubtitleStart:
  173. return "EventType_TranslationSubtitleStart"
  174. case EventType_TranslationSubtitleResponse:
  175. return "EventType_TranslationSubtitleResponse"
  176. case EventType_TranslationSubtitleEnd:
  177. return "EventType_TranslationSubtitleEnd"
  178. default:
  179. return fmt.Sprintf("EventType_(%d)", t)
  180. }
  181. }
  182. type Message struct {
  183. Version VersionBits
  184. HeaderSize HeaderSizeBits
  185. MsgType MsgType
  186. MsgTypeFlag MsgTypeFlagBits
  187. Serialization SerializationBits
  188. Compression CompressionBits
  189. EventType EventType
  190. SessionID string
  191. ConnectID string
  192. Sequence int32
  193. ErrorCode uint32
  194. Payload []byte
  195. }
  196. func NewMessageFromBytes(data []byte) (*Message, error) {
  197. if len(data) < 3 {
  198. return nil, fmt.Errorf("data too short: expected at least 3 bytes, got %d", len(data))
  199. }
  200. typeAndFlag := data[1]
  201. msg, err := NewMessage(MsgType(typeAndFlag>>4), MsgTypeFlagBits(typeAndFlag&0b00001111))
  202. if err != nil {
  203. return nil, err
  204. }
  205. if err := msg.Unmarshal(data); err != nil {
  206. return nil, err
  207. }
  208. return msg, nil
  209. }
  210. func NewMessage(msgType MsgType, flag MsgTypeFlagBits) (*Message, error) {
  211. return &Message{
  212. MsgType: msgType,
  213. MsgTypeFlag: flag,
  214. Version: Version1,
  215. HeaderSize: HeaderSize4,
  216. Serialization: SerializationJSON,
  217. Compression: CompressionNone,
  218. }, nil
  219. }
  220. func (m *Message) String() string {
  221. switch m.MsgType {
  222. case MsgTypeAudioOnlyServer, MsgTypeAudioOnlyClient:
  223. if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
  224. return fmt.Sprintf("%s, %s, Sequence: %d, PayloadSize: %d", m.MsgType, m.EventType, m.Sequence, len(m.Payload))
  225. }
  226. return fmt.Sprintf("%s, %s, PayloadSize: %d", m.MsgType, m.EventType, len(m.Payload))
  227. case MsgTypeError:
  228. return fmt.Sprintf("%s, %s, ErrorCode: %d, Payload: %s", m.MsgType, m.EventType, m.ErrorCode, string(m.Payload))
  229. default:
  230. if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
  231. return fmt.Sprintf("%s, %s, Sequence: %d, Payload: %s",
  232. m.MsgType, m.EventType, m.Sequence, string(m.Payload))
  233. }
  234. return fmt.Sprintf("%s, %s, Payload: %s", m.MsgType, m.EventType, string(m.Payload))
  235. }
  236. }
  237. func (m *Message) Marshal() ([]byte, error) {
  238. buf := new(bytes.Buffer)
  239. header := []uint8{
  240. uint8(m.Version)<<4 | uint8(m.HeaderSize),
  241. uint8(m.MsgType)<<4 | uint8(m.MsgTypeFlag),
  242. uint8(m.Serialization)<<4 | uint8(m.Compression),
  243. }
  244. headerSize := 4 * int(m.HeaderSize)
  245. if padding := headerSize - len(header); padding > 0 {
  246. header = append(header, make([]uint8, padding)...)
  247. }
  248. if err := binary.Write(buf, binary.BigEndian, header); err != nil {
  249. return nil, err
  250. }
  251. writers, err := m.writers()
  252. if err != nil {
  253. return nil, err
  254. }
  255. for _, write := range writers {
  256. if err := write(buf); err != nil {
  257. return nil, err
  258. }
  259. }
  260. return buf.Bytes(), nil
  261. }
  262. func (m *Message) Unmarshal(data []byte) error {
  263. buf := bytes.NewBuffer(data)
  264. versionAndHeaderSize, err := buf.ReadByte()
  265. if err != nil {
  266. return err
  267. }
  268. m.Version = VersionBits(versionAndHeaderSize >> 4)
  269. m.HeaderSize = HeaderSizeBits(versionAndHeaderSize & 0b00001111)
  270. _, err = buf.ReadByte()
  271. if err != nil {
  272. return err
  273. }
  274. serializationCompression, err := buf.ReadByte()
  275. if err != nil {
  276. return err
  277. }
  278. m.Serialization = SerializationBits(serializationCompression & 0b11110000)
  279. m.Compression = CompressionBits(serializationCompression & 0b00001111)
  280. headerSize := 4 * int(m.HeaderSize)
  281. readSize := 3
  282. if paddingSize := headerSize - readSize; paddingSize > 0 {
  283. if n, err := buf.Read(make([]byte, paddingSize)); err != nil || n < paddingSize {
  284. return fmt.Errorf("insufficient header bytes: expected %d, got %d", paddingSize, n)
  285. }
  286. }
  287. readers, err := m.readers()
  288. if err != nil {
  289. return err
  290. }
  291. for _, read := range readers {
  292. if err := read(buf); err != nil {
  293. return err
  294. }
  295. }
  296. if _, err := buf.ReadByte(); err != io.EOF {
  297. return fmt.Errorf("unexpected data after message: %v", err)
  298. }
  299. return nil
  300. }
  301. func (m *Message) writers() (writers []func(*bytes.Buffer) error, _ error) {
  302. if m.MsgTypeFlag == MsgTypeFlagWithEvent {
  303. writers = append(writers, m.writeEvent, m.writeSessionID)
  304. }
  305. switch m.MsgType {
  306. case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer:
  307. if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
  308. writers = append(writers, m.writeSequence)
  309. }
  310. case MsgTypeError:
  311. writers = append(writers, m.writeErrorCode)
  312. default:
  313. return nil, fmt.Errorf("unsupported message type: %d", m.MsgType)
  314. }
  315. writers = append(writers, m.writePayload)
  316. return writers, nil
  317. }
  318. func (m *Message) writeEvent(buf *bytes.Buffer) error {
  319. return binary.Write(buf, binary.BigEndian, m.EventType)
  320. }
  321. func (m *Message) writeSessionID(buf *bytes.Buffer) error {
  322. switch m.EventType {
  323. case EventType_StartConnection, EventType_FinishConnection,
  324. EventType_ConnectionStarted, EventType_ConnectionFailed:
  325. return nil
  326. }
  327. size := len(m.SessionID)
  328. if size > math.MaxUint32 {
  329. return fmt.Errorf("session ID size (%d) exceeds max(uint32)", size)
  330. }
  331. if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil {
  332. return err
  333. }
  334. buf.WriteString(m.SessionID)
  335. return nil
  336. }
  337. func (m *Message) writeSequence(buf *bytes.Buffer) error {
  338. return binary.Write(buf, binary.BigEndian, m.Sequence)
  339. }
  340. func (m *Message) writeErrorCode(buf *bytes.Buffer) error {
  341. return binary.Write(buf, binary.BigEndian, m.ErrorCode)
  342. }
  343. func (m *Message) writePayload(buf *bytes.Buffer) error {
  344. size := len(m.Payload)
  345. if size > math.MaxUint32 {
  346. return fmt.Errorf("payload size (%d) exceeds max(uint32)", size)
  347. }
  348. if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil {
  349. return err
  350. }
  351. buf.Write(m.Payload)
  352. return nil
  353. }
  354. func (m *Message) readers() (readers []func(*bytes.Buffer) error, _ error) {
  355. switch m.MsgType {
  356. case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer:
  357. if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
  358. readers = append(readers, m.readSequence)
  359. }
  360. case MsgTypeError:
  361. readers = append(readers, m.readErrorCode)
  362. default:
  363. return nil, fmt.Errorf("unsupported message type: %d", m.MsgType)
  364. }
  365. if m.MsgTypeFlag == MsgTypeFlagWithEvent {
  366. readers = append(readers, m.readEvent, m.readSessionID, m.readConnectID)
  367. }
  368. readers = append(readers, m.readPayload)
  369. return readers, nil
  370. }
  371. func (m *Message) readEvent(buf *bytes.Buffer) error {
  372. return binary.Read(buf, binary.BigEndian, &m.EventType)
  373. }
  374. func (m *Message) readSessionID(buf *bytes.Buffer) error {
  375. switch m.EventType {
  376. case EventType_StartConnection, EventType_FinishConnection,
  377. EventType_ConnectionStarted, EventType_ConnectionFailed,
  378. EventType_ConnectionFinished:
  379. return nil
  380. }
  381. var size uint32
  382. if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
  383. return err
  384. }
  385. if size > 0 {
  386. m.SessionID = string(buf.Next(int(size)))
  387. }
  388. return nil
  389. }
  390. func (m *Message) readConnectID(buf *bytes.Buffer) error {
  391. switch m.EventType {
  392. case EventType_ConnectionStarted, EventType_ConnectionFailed,
  393. EventType_ConnectionFinished:
  394. default:
  395. return nil
  396. }
  397. var size uint32
  398. if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
  399. return err
  400. }
  401. if size > 0 {
  402. m.ConnectID = string(buf.Next(int(size)))
  403. }
  404. return nil
  405. }
  406. func (m *Message) readSequence(buf *bytes.Buffer) error {
  407. return binary.Read(buf, binary.BigEndian, &m.Sequence)
  408. }
  409. func (m *Message) readErrorCode(buf *bytes.Buffer) error {
  410. return binary.Read(buf, binary.BigEndian, &m.ErrorCode)
  411. }
  412. func (m *Message) readPayload(buf *bytes.Buffer) error {
  413. var size uint32
  414. if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
  415. return err
  416. }
  417. if size > 0 {
  418. m.Payload = buf.Next(int(size))
  419. }
  420. return nil
  421. }
  422. func ReceiveMessage(conn *websocket.Conn) (*Message, error) {
  423. mt, frame, err := conn.ReadMessage()
  424. if err != nil {
  425. return nil, err
  426. }
  427. if mt != websocket.BinaryMessage && mt != websocket.TextMessage {
  428. return nil, fmt.Errorf("unexpected Websocket message type: %d", mt)
  429. }
  430. msg, err := NewMessageFromBytes(frame)
  431. if err != nil {
  432. return nil, err
  433. }
  434. return msg, nil
  435. }
  436. func FullClientRequest(conn *websocket.Conn, payload []byte) error {
  437. msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagNoSeq)
  438. if err != nil {
  439. return err
  440. }
  441. msg.Payload = payload
  442. frame, err := msg.Marshal()
  443. if err != nil {
  444. return err
  445. }
  446. return conn.WriteMessage(websocket.BinaryMessage, frame)
  447. }