protocol.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. package hysteria
import (
	"bytes"
	"encoding/binary"
	"errors"
	"io"
	"math/rand"
	"net"
	"os"
	"time"

	"github.com/sagernet/quic-go"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
)
  16. const (
  17. MbpsToBps = 125000
  18. MinSpeedBPS = 16384
  19. DefaultStreamReceiveWindow = 15728640 // 15 MB/s
  20. DefaultConnectionReceiveWindow = 67108864 // 64 MB/s
  21. DefaultMaxIncomingStreams = 1024
  22. DefaultALPN = "hysteria"
  23. KeepAlivePeriod = 10 * time.Second
  24. )
  25. const Version = 3
  26. type ClientHello struct {
  27. SendBPS uint64
  28. RecvBPS uint64
  29. Auth []byte
  30. }
  31. func WriteClientHello(stream io.Writer, hello ClientHello) error {
  32. var requestLen int
  33. requestLen += 1 // version
  34. requestLen += 8 // sendBPS
  35. requestLen += 8 // recvBPS
  36. requestLen += 2 // auth len
  37. requestLen += len(hello.Auth)
  38. request := buf.NewSize(requestLen)
  39. defer request.Release()
  40. common.Must(
  41. request.WriteByte(Version),
  42. binary.Write(request, binary.BigEndian, hello.SendBPS),
  43. binary.Write(request, binary.BigEndian, hello.RecvBPS),
  44. binary.Write(request, binary.BigEndian, uint16(len(hello.Auth))),
  45. common.Error(request.Write(hello.Auth)),
  46. )
  47. return common.Error(stream.Write(request.Bytes()))
  48. }
  49. func ReadClientHello(reader io.Reader) (*ClientHello, error) {
  50. var version uint8
  51. err := binary.Read(reader, binary.BigEndian, &version)
  52. if err != nil {
  53. return nil, err
  54. }
  55. if version != Version {
  56. return nil, E.New("unsupported client version: ", version)
  57. }
  58. var clientHello ClientHello
  59. err = binary.Read(reader, binary.BigEndian, &clientHello.SendBPS)
  60. if err != nil {
  61. return nil, err
  62. }
  63. err = binary.Read(reader, binary.BigEndian, &clientHello.RecvBPS)
  64. if err != nil {
  65. return nil, err
  66. }
  67. var authLen uint16
  68. err = binary.Read(reader, binary.BigEndian, &authLen)
  69. if err != nil {
  70. return nil, err
  71. }
  72. clientHello.Auth = make([]byte, authLen)
  73. _, err = io.ReadFull(reader, clientHello.Auth)
  74. if err != nil {
  75. return nil, err
  76. }
  77. return &clientHello, nil
  78. }
  79. type ServerHello struct {
  80. OK bool
  81. SendBPS uint64
  82. RecvBPS uint64
  83. Message string
  84. }
  85. func ReadServerHello(stream io.Reader) (*ServerHello, error) {
  86. var responseLen int
  87. responseLen += 1 // ok
  88. responseLen += 8 // sendBPS
  89. responseLen += 8 // recvBPS
  90. responseLen += 2 // message len
  91. response := buf.NewSize(responseLen)
  92. defer response.Release()
  93. _, err := response.ReadFullFrom(stream, responseLen)
  94. if err != nil {
  95. return nil, err
  96. }
  97. var serverHello ServerHello
  98. serverHello.OK = response.Byte(0) == 1
  99. serverHello.SendBPS = binary.BigEndian.Uint64(response.Range(1, 9))
  100. serverHello.RecvBPS = binary.BigEndian.Uint64(response.Range(9, 17))
  101. messageLen := binary.BigEndian.Uint16(response.Range(17, 19))
  102. if messageLen == 0 {
  103. return &serverHello, nil
  104. }
  105. message := make([]byte, messageLen)
  106. _, err = io.ReadFull(stream, message)
  107. if err != nil {
  108. return nil, err
  109. }
  110. serverHello.Message = string(message)
  111. return &serverHello, nil
  112. }
  113. func WriteServerHello(stream io.Writer, hello ServerHello) error {
  114. var responseLen int
  115. responseLen += 1 // ok
  116. responseLen += 8 // sendBPS
  117. responseLen += 8 // recvBPS
  118. responseLen += 2 // message len
  119. responseLen += len(hello.Message)
  120. response := buf.NewSize(responseLen)
  121. defer response.Release()
  122. if hello.OK {
  123. common.Must(response.WriteByte(1))
  124. } else {
  125. common.Must(response.WriteByte(0))
  126. }
  127. common.Must(
  128. binary.Write(response, binary.BigEndian, hello.SendBPS),
  129. binary.Write(response, binary.BigEndian, hello.RecvBPS),
  130. binary.Write(response, binary.BigEndian, uint16(len(hello.Message))),
  131. common.Error(response.WriteString(hello.Message)),
  132. )
  133. return common.Error(stream.Write(response.Bytes()))
  134. }
  135. type ClientRequest struct {
  136. UDP bool
  137. Host string
  138. Port uint16
  139. }
  140. func ReadClientRequest(stream io.Reader) (*ClientRequest, error) {
  141. var clientRequest ClientRequest
  142. err := binary.Read(stream, binary.BigEndian, &clientRequest.UDP)
  143. if err != nil {
  144. return nil, err
  145. }
  146. var hostLen uint16
  147. err = binary.Read(stream, binary.BigEndian, &hostLen)
  148. if err != nil {
  149. return nil, err
  150. }
  151. host := make([]byte, hostLen)
  152. _, err = io.ReadFull(stream, host)
  153. if err != nil {
  154. return nil, err
  155. }
  156. clientRequest.Host = string(host)
  157. err = binary.Read(stream, binary.BigEndian, &clientRequest.Port)
  158. if err != nil {
  159. return nil, err
  160. }
  161. return &clientRequest, nil
  162. }
  163. func WriteClientRequest(stream io.Writer, request ClientRequest) error {
  164. var requestLen int
  165. requestLen += 1 // udp
  166. requestLen += 2 // host len
  167. requestLen += len(request.Host)
  168. requestLen += 2 // port
  169. buffer := buf.NewSize(requestLen)
  170. defer buffer.Release()
  171. if request.UDP {
  172. common.Must(buffer.WriteByte(1))
  173. } else {
  174. common.Must(buffer.WriteByte(0))
  175. }
  176. common.Must(
  177. binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))),
  178. common.Error(buffer.WriteString(request.Host)),
  179. binary.Write(buffer, binary.BigEndian, request.Port),
  180. )
  181. return common.Error(stream.Write(buffer.Bytes()))
  182. }
  183. type ServerResponse struct {
  184. OK bool
  185. UDPSessionID uint32
  186. Message string
  187. }
  188. func ReadServerResponse(stream io.Reader) (*ServerResponse, error) {
  189. var responseLen int
  190. responseLen += 1 // ok
  191. responseLen += 4 // udp session id
  192. responseLen += 2 // message len
  193. response := buf.NewSize(responseLen)
  194. defer response.Release()
  195. _, err := response.ReadFullFrom(stream, responseLen)
  196. if err != nil {
  197. return nil, err
  198. }
  199. var serverResponse ServerResponse
  200. serverResponse.OK = response.Byte(0) == 1
  201. serverResponse.UDPSessionID = binary.BigEndian.Uint32(response.Range(1, 5))
  202. messageLen := binary.BigEndian.Uint16(response.Range(5, 7))
  203. if messageLen == 0 {
  204. return &serverResponse, nil
  205. }
  206. message := make([]byte, messageLen)
  207. _, err = io.ReadFull(stream, message)
  208. if err != nil {
  209. return nil, err
  210. }
  211. serverResponse.Message = string(message)
  212. return &serverResponse, nil
  213. }
  214. func WriteServerResponse(stream io.Writer, response ServerResponse) error {
  215. var responseLen int
  216. responseLen += 1 // ok
  217. responseLen += 4 // udp session id
  218. responseLen += 2 // message len
  219. responseLen += len(response.Message)
  220. buffer := buf.NewSize(responseLen)
  221. defer buffer.Release()
  222. if response.OK {
  223. common.Must(buffer.WriteByte(1))
  224. } else {
  225. common.Must(buffer.WriteByte(0))
  226. }
  227. common.Must(
  228. binary.Write(buffer, binary.BigEndian, response.UDPSessionID),
  229. binary.Write(buffer, binary.BigEndian, uint16(len(response.Message))),
  230. common.Error(buffer.WriteString(response.Message)),
  231. )
  232. return common.Error(stream.Write(buffer.Bytes()))
  233. }
  234. type UDPMessage struct {
  235. SessionID uint32
  236. Host string
  237. Port uint16
  238. MsgID uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented
  239. FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented
  240. FragCount uint8 // must be 1 when not fragmented
  241. Data []byte
  242. }
  243. func (m UDPMessage) HeaderSize() int {
  244. return 4 + 2 + len(m.Host) + 2 + 2 + 1 + 1 + 2
  245. }
  246. func (m UDPMessage) Size() int {
  247. return m.HeaderSize() + len(m.Data)
  248. }
  249. func ParseUDPMessage(packet []byte) (message UDPMessage, err error) {
  250. reader := bytes.NewReader(packet)
  251. err = binary.Read(reader, binary.BigEndian, &message.SessionID)
  252. if err != nil {
  253. return
  254. }
  255. var hostLen uint16
  256. err = binary.Read(reader, binary.BigEndian, &hostLen)
  257. if err != nil {
  258. return
  259. }
  260. _, err = reader.Seek(int64(hostLen), io.SeekCurrent)
  261. if err != nil {
  262. return
  263. }
  264. if 6+int(hostLen) > len(packet) {
  265. err = E.New("invalid host length")
  266. return
  267. }
  268. message.Host = string(packet[6 : 6+hostLen])
  269. err = binary.Read(reader, binary.BigEndian, &message.Port)
  270. if err != nil {
  271. return
  272. }
  273. err = binary.Read(reader, binary.BigEndian, &message.MsgID)
  274. if err != nil {
  275. return
  276. }
  277. err = binary.Read(reader, binary.BigEndian, &message.FragID)
  278. if err != nil {
  279. return
  280. }
  281. err = binary.Read(reader, binary.BigEndian, &message.FragCount)
  282. if err != nil {
  283. return
  284. }
  285. var dataLen uint16
  286. err = binary.Read(reader, binary.BigEndian, &dataLen)
  287. if err != nil {
  288. return
  289. }
  290. if reader.Len() != int(dataLen) {
  291. err = E.New("invalid data length")
  292. }
  293. dataOffset := int(reader.Size()) - reader.Len()
  294. message.Data = packet[dataOffset:]
  295. return
  296. }
  297. func WriteUDPMessage(conn quic.Connection, message UDPMessage) error {
  298. var messageLen int
  299. messageLen += 4 // session id
  300. messageLen += 2 // host len
  301. messageLen += len(message.Host)
  302. messageLen += 2 // port
  303. messageLen += 2 // msg id
  304. messageLen += 1 // frag id
  305. messageLen += 1 // frag count
  306. messageLen += 2 // data len
  307. messageLen += len(message.Data)
  308. buffer := buf.NewSize(messageLen)
  309. defer buffer.Release()
  310. err := writeUDPMessage(conn, message, buffer)
  311. if errSize, ok := err.(quic.ErrMessageTooLarge); ok {
  312. // need to frag
  313. message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
  314. fragMsgs := FragUDPMessage(message, int(errSize))
  315. for _, fragMsg := range fragMsgs {
  316. buffer.FullReset()
  317. err = writeUDPMessage(conn, fragMsg, buffer)
  318. if err != nil {
  319. return err
  320. }
  321. }
  322. return nil
  323. }
  324. return err
  325. }
  326. func writeUDPMessage(conn quic.Connection, message UDPMessage, buffer *buf.Buffer) error {
  327. common.Must(
  328. binary.Write(buffer, binary.BigEndian, message.SessionID),
  329. binary.Write(buffer, binary.BigEndian, uint16(len(message.Host))),
  330. common.Error(buffer.WriteString(message.Host)),
  331. binary.Write(buffer, binary.BigEndian, message.Port),
  332. binary.Write(buffer, binary.BigEndian, message.MsgID),
  333. binary.Write(buffer, binary.BigEndian, message.FragID),
  334. binary.Write(buffer, binary.BigEndian, message.FragCount),
  335. binary.Write(buffer, binary.BigEndian, uint16(len(message.Data))),
  336. common.Error(buffer.Write(message.Data)),
  337. )
  338. return conn.SendMessage(buffer.Bytes())
  339. }
  340. var _ net.Conn = (*Conn)(nil)
  341. type Conn struct {
  342. quic.Stream
  343. destination M.Socksaddr
  344. needReadResponse bool
  345. }
  346. func NewConn(stream quic.Stream, destination M.Socksaddr, isClient bool) *Conn {
  347. return &Conn{
  348. Stream: stream,
  349. destination: destination,
  350. needReadResponse: isClient,
  351. }
  352. }
  353. func (c *Conn) Read(p []byte) (n int, err error) {
  354. if c.needReadResponse {
  355. var response *ServerResponse
  356. response, err = ReadServerResponse(c.Stream)
  357. if err != nil {
  358. c.Close()
  359. return
  360. }
  361. if !response.OK {
  362. c.Close()
  363. return 0, E.New("remote error: ", response.Message)
  364. }
  365. c.needReadResponse = false
  366. }
  367. return c.Stream.Read(p)
  368. }
  369. func (c *Conn) LocalAddr() net.Addr {
  370. return M.Socksaddr{}
  371. }
  372. func (c *Conn) RemoteAddr() net.Addr {
  373. return c.destination.TCPAddr()
  374. }
  375. func (c *Conn) ReaderReplaceable() bool {
  376. return !c.needReadResponse
  377. }
  378. func (c *Conn) WriterReplaceable() bool {
  379. return true
  380. }
  381. func (c *Conn) Upstream() any {
  382. return c.Stream
  383. }
  384. type PacketConn struct {
  385. session quic.Connection
  386. stream quic.Stream
  387. sessionId uint32
  388. destination M.Socksaddr
  389. msgCh <-chan *UDPMessage
  390. closer io.Closer
  391. }
  392. func NewPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *PacketConn {
  393. return &PacketConn{
  394. session: session,
  395. stream: stream,
  396. sessionId: sessionId,
  397. destination: destination,
  398. msgCh: msgCh,
  399. closer: closer,
  400. }
  401. }
  402. func (c *PacketConn) Hold() {
  403. // Hold the stream until it's closed
  404. buf := make([]byte, 1024)
  405. for {
  406. _, err := c.stream.Read(buf)
  407. if err != nil {
  408. break
  409. }
  410. }
  411. _ = c.Close()
  412. }
  413. func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
  414. msg := <-c.msgCh
  415. if msg == nil {
  416. err = net.ErrClosed
  417. return
  418. }
  419. err = common.Error(buffer.Write(msg.Data))
  420. destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port).Unwrap()
  421. return
  422. }
  423. func (c *PacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
  424. msg := <-c.msgCh
  425. if msg == nil {
  426. err = net.ErrClosed
  427. return
  428. }
  429. buffer = buf.As(msg.Data)
  430. destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port).Unwrap()
  431. return
  432. }
  433. func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  434. return WriteUDPMessage(c.session, UDPMessage{
  435. SessionID: c.sessionId,
  436. Host: destination.AddrString(),
  437. Port: destination.Port,
  438. FragCount: 1,
  439. Data: buffer.Bytes(),
  440. })
  441. }
  442. func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
  443. msg := <-c.msgCh
  444. if msg == nil {
  445. err = net.ErrClosed
  446. return
  447. }
  448. n = copy(p, msg.Data)
  449. destination := M.ParseSocksaddrHostPort(msg.Host, msg.Port)
  450. if destination.IsFqdn() {
  451. addr = destination
  452. } else {
  453. addr = destination.UDPAddr()
  454. }
  455. return
  456. }
  457. func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
  458. err = c.WritePacket(buf.As(p), M.SocksaddrFromNet(addr))
  459. if err == nil {
  460. n = len(p)
  461. }
  462. return
  463. }
  464. func (c *PacketConn) LocalAddr() net.Addr {
  465. return M.Socksaddr{}
  466. }
  467. func (c *PacketConn) RemoteAddr() net.Addr {
  468. return c.destination.UDPAddr()
  469. }
  470. func (c *PacketConn) SetDeadline(t time.Time) error {
  471. return os.ErrInvalid
  472. }
  473. func (c *PacketConn) SetReadDeadline(t time.Time) error {
  474. return os.ErrInvalid
  475. }
  476. func (c *PacketConn) SetWriteDeadline(t time.Time) error {
  477. return os.ErrInvalid
  478. }
  479. func (c *PacketConn) NeedAdditionalReadDeadline() bool {
  480. return true
  481. }
  482. func (c *PacketConn) Read(b []byte) (n int, err error) {
  483. n, _, err = c.ReadFrom(b)
  484. return
  485. }
  486. func (c *PacketConn) Write(b []byte) (n int, err error) {
  487. return c.WriteTo(b, c.destination)
  488. }
  489. func (c *PacketConn) Close() error {
  490. return common.Close(c.stream, c.closer)
  491. }