protocol.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. package hysteria
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "io"
  7. "github.com/apernet/quic-go/quicvarint"
  8. "github.com/xtls/xray-core/common/errors"
  9. )
  10. const (
  11. // Max length values are for preventing DoS attacks
  12. MaxAddressLength = 2048
  13. MaxMessageLength = 2048
  14. MaxPaddingLength = 4096
  15. MaxUDPSize = 4096
  16. maxVarInt1 = 63
  17. maxVarInt2 = 16383
  18. maxVarInt4 = 1073741823
  19. maxVarInt8 = 4611686018427387903
  20. )
  21. // TCPRequest format:
  22. // Address length (QUIC varint)
  23. // Address (bytes)
  24. // Padding length (QUIC varint)
  25. // Padding (bytes)
  26. func ReadTCPRequest(r io.Reader) (string, error) {
  27. bReader := quicvarint.NewReader(r)
  28. addrLen, err := quicvarint.Read(bReader)
  29. if err != nil {
  30. return "", err
  31. }
  32. if addrLen == 0 || addrLen > MaxAddressLength {
  33. return "", errors.New("invalid address length")
  34. }
  35. addrBuf := make([]byte, addrLen)
  36. _, err = io.ReadFull(r, addrBuf)
  37. if err != nil {
  38. return "", err
  39. }
  40. paddingLen, err := quicvarint.Read(bReader)
  41. if err != nil {
  42. return "", err
  43. }
  44. if paddingLen > MaxPaddingLength {
  45. return "", errors.New("invalid padding length")
  46. }
  47. if paddingLen > 0 {
  48. _, err = io.CopyN(io.Discard, r, int64(paddingLen))
  49. if err != nil {
  50. return "", err
  51. }
  52. }
  53. return string(addrBuf), nil
  54. }
  55. func WriteTCPRequest(w io.Writer, addr string) error {
  56. padding := tcpRequestPadding.String()
  57. paddingLen := len(padding)
  58. addrLen := len(addr)
  59. sz := int(quicvarint.Len(uint64(addrLen))) + addrLen +
  60. int(quicvarint.Len(uint64(paddingLen))) + paddingLen
  61. buf := make([]byte, sz)
  62. i := varintPut(buf, uint64(addrLen))
  63. i += copy(buf[i:], addr)
  64. i += varintPut(buf[i:], uint64(paddingLen))
  65. copy(buf[i:], padding)
  66. _, err := w.Write(buf)
  67. return err
  68. }
  69. // TCPResponse format:
  70. // Status (byte, 0=ok, 1=error)
  71. // Message length (QUIC varint)
  72. // Message (bytes)
  73. // Padding length (QUIC varint)
  74. // Padding (bytes)
  75. func ReadTCPResponse(r io.Reader) (bool, string, error) {
  76. var status [1]byte
  77. if _, err := io.ReadFull(r, status[:]); err != nil {
  78. return false, "", err
  79. }
  80. bReader := quicvarint.NewReader(r)
  81. msgLen, err := quicvarint.Read(bReader)
  82. if err != nil {
  83. return false, "", err
  84. }
  85. if msgLen > MaxMessageLength {
  86. return false, "", errors.New("invalid message length")
  87. }
  88. var msgBuf []byte
  89. // No message is fine
  90. if msgLen > 0 {
  91. msgBuf = make([]byte, msgLen)
  92. _, err = io.ReadFull(r, msgBuf)
  93. if err != nil {
  94. return false, "", err
  95. }
  96. }
  97. paddingLen, err := quicvarint.Read(bReader)
  98. if err != nil {
  99. return false, "", err
  100. }
  101. if paddingLen > MaxPaddingLength {
  102. return false, "", errors.New("invalid padding length")
  103. }
  104. if paddingLen > 0 {
  105. _, err = io.CopyN(io.Discard, r, int64(paddingLen))
  106. if err != nil {
  107. return false, "", err
  108. }
  109. }
  110. return status[0] == 0, string(msgBuf), nil
  111. }
  112. func WriteTCPResponse(w io.Writer, ok bool, msg string) error {
  113. padding := tcpResponsePadding.String()
  114. paddingLen := len(padding)
  115. msgLen := len(msg)
  116. sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
  117. int(quicvarint.Len(uint64(paddingLen))) + paddingLen
  118. buf := make([]byte, sz)
  119. if ok {
  120. buf[0] = 0
  121. } else {
  122. buf[0] = 1
  123. }
  124. i := varintPut(buf[1:], uint64(msgLen))
  125. i += copy(buf[1+i:], msg)
  126. i += varintPut(buf[1+i:], uint64(paddingLen))
  127. copy(buf[1+i:], padding)
  128. _, err := w.Write(buf)
  129. return err
  130. }
// UDPMessage is a UDP datagram in the layout decoded by ParseUDPMessage and
// encoded by Serialize. UDPMessage format, in wire order:
//
//	Session ID (uint32 BE)
//	Packet ID (uint16 BE)
//	Fragment ID (uint8)
//	Fragment count (uint8)
//	Address length (QUIC varint)
//	Address (bytes)
//	Data...
type UDPMessage struct {
	SessionID uint32 // 4 bytes on the wire
	PacketID  uint16 // 2 bytes on the wire
	FragID    uint8  // 1 byte: fragment ID
	FragCount uint8  // 1 byte: fragment count
	Addr      string // varint length prefix + address bytes
	Data      []byte // payload; everything after the address
}
  147. func (m *UDPMessage) HeaderSize() int {
  148. lAddr := len(m.Addr)
  149. return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr
  150. }
  151. func (m *UDPMessage) Size() int {
  152. return m.HeaderSize() + len(m.Data)
  153. }
  154. func (m *UDPMessage) Serialize(buf []byte) int {
  155. // Make sure the buffer is big enough
  156. if len(buf) < m.Size() {
  157. return -1
  158. }
  159. // binary.BigEndian.PutUint32(buf, m.SessionID)
  160. binary.BigEndian.PutUint16(buf[4:], m.PacketID)
  161. buf[6] = m.FragID
  162. buf[7] = m.FragCount
  163. i := varintPut(buf[8:], uint64(len(m.Addr)))
  164. i += copy(buf[8+i:], m.Addr)
  165. i += copy(buf[8+i:], m.Data)
  166. return 8 + i
  167. }
  168. func ParseUDPMessage(msg []byte) (*UDPMessage, error) {
  169. m := &UDPMessage{}
  170. buf := bytes.NewBuffer(msg)
  171. if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil {
  172. return nil, err
  173. }
  174. if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil {
  175. return nil, err
  176. }
  177. if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil {
  178. return nil, err
  179. }
  180. if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil {
  181. return nil, err
  182. }
  183. lAddr, err := quicvarint.Read(buf)
  184. if err != nil {
  185. return nil, err
  186. }
  187. if lAddr == 0 || lAddr > MaxMessageLength {
  188. return nil, errors.New("invalid address length")
  189. }
  190. bs := buf.Bytes()
  191. if len(bs) <= int(lAddr) {
  192. // We use <= instead of < here as we expect at least one byte of data after the address
  193. return nil, errors.New("invalid message length")
  194. }
  195. m.Addr = string(bs[:lAddr])
  196. m.Data = bs[lAddr:]
  197. return m, nil
  198. }
  199. // varintPut is like quicvarint.Append, but instead of appending to a slice,
  200. // it writes to a fixed-size buffer. Returns the number of bytes written.
  201. func varintPut(b []byte, i uint64) int {
  202. if i <= maxVarInt1 {
  203. b[0] = uint8(i)
  204. return 1
  205. }
  206. if i <= maxVarInt2 {
  207. b[0] = uint8(i>>8) | 0x40
  208. b[1] = uint8(i)
  209. return 2
  210. }
  211. if i <= maxVarInt4 {
  212. b[0] = uint8(i>>24) | 0x80
  213. b[1] = uint8(i >> 16)
  214. b[2] = uint8(i >> 8)
  215. b[3] = uint8(i)
  216. return 4
  217. }
  218. if i <= maxVarInt8 {
  219. b[0] = uint8(i>>56) | 0xc0
  220. b[1] = uint8(i >> 48)
  221. b[2] = uint8(i >> 40)
  222. b[3] = uint8(i >> 32)
  223. b[4] = uint8(i >> 24)
  224. b[5] = uint8(i >> 16)
  225. b[6] = uint8(i >> 8)
  226. b[7] = uint8(i)
  227. return 8
  228. }
  229. panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
  230. }