conn.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. package v2raywebsocket
  2. import (
  3. "encoding/base64"
  4. "io"
  5. "net"
  6. "net/http"
  7. "os"
  8. "time"
  9. E "github.com/sagernet/sing/common/exceptions"
  10. "github.com/gorilla/websocket"
  11. )
  12. type WebsocketConn struct {
  13. *websocket.Conn
  14. remoteAddr net.Addr
  15. reader io.Reader
  16. }
  17. func (c *WebsocketConn) Read(b []byte) (n int, err error) {
  18. for {
  19. if c.reader == nil {
  20. _, c.reader, err = c.NextReader()
  21. if err != nil {
  22. return
  23. }
  24. }
  25. n, err = c.reader.Read(b)
  26. if E.IsMulti(err, io.EOF) {
  27. c.reader = nil
  28. continue
  29. }
  30. return
  31. }
  32. }
  33. func (c *WebsocketConn) Write(b []byte) (n int, err error) {
  34. err = c.WriteMessage(websocket.BinaryMessage, b)
  35. if err != nil {
  36. return
  37. }
  38. return len(b), nil
  39. }
  40. func (c *WebsocketConn) RemoteAddr() net.Addr {
  41. if c.remoteAddr != nil {
  42. return c.remoteAddr
  43. }
  44. return c.Conn.RemoteAddr()
  45. }
  46. func (c *WebsocketConn) SetDeadline(t time.Time) error {
  47. return os.ErrInvalid
  48. }
  49. type EarlyWebsocketConn struct {
  50. *Client
  51. conn *WebsocketConn
  52. create chan struct{}
  53. }
  54. func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
  55. if c.conn == nil {
  56. <-c.create
  57. }
  58. return c.conn.Read(b)
  59. }
  60. func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
  61. if c.conn != nil {
  62. return c.conn.Write(b)
  63. }
  64. var (
  65. earlyData []byte
  66. lateData []byte
  67. conn *websocket.Conn
  68. response *http.Response
  69. )
  70. if len(earlyData) > int(c.maxEarlyData) {
  71. earlyData = earlyData[:c.maxEarlyData]
  72. lateData = lateData[c.maxEarlyData:]
  73. } else {
  74. earlyData = b
  75. }
  76. if len(earlyData) > 0 {
  77. earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
  78. if c.earlyDataHeaderName == "" {
  79. conn, response, err = c.dialer.Dial(c.uri+earlyDataString, c.headers)
  80. } else {
  81. headers := c.headers.Clone()
  82. headers.Set(c.earlyDataHeaderName, earlyDataString)
  83. conn, response, err = c.dialer.Dial(c.uri, headers)
  84. }
  85. } else {
  86. conn, response, err = c.dialer.Dial(c.uri, c.headers)
  87. }
  88. if err != nil {
  89. return 0, wrapDialError(response, err)
  90. }
  91. c.conn = &WebsocketConn{Conn: conn}
  92. close(c.create)
  93. if len(lateData) > 0 {
  94. _, err = c.conn.Write(lateData)
  95. }
  96. if err != nil {
  97. return
  98. }
  99. return len(b), nil
  100. }
  101. func (c *EarlyWebsocketConn) Close() error {
  102. if c.conn == nil {
  103. return nil
  104. }
  105. return c.conn.Close()
  106. }
  107. func (c *EarlyWebsocketConn) LocalAddr() net.Addr {
  108. if c.conn == nil {
  109. return nil
  110. }
  111. return c.conn.LocalAddr()
  112. }
  113. func (c *EarlyWebsocketConn) RemoteAddr() net.Addr {
  114. if c.conn == nil {
  115. return nil
  116. }
  117. return c.conn.RemoteAddr()
  118. }
  119. func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error {
  120. if c.conn == nil {
  121. return os.ErrInvalid
  122. }
  123. return c.conn.SetDeadline(t)
  124. }
  125. func (c *EarlyWebsocketConn) SetReadDeadline(t time.Time) error {
  126. if c.conn == nil {
  127. return os.ErrInvalid
  128. }
  129. return c.conn.SetReadDeadline(t)
  130. }
  131. func (c *EarlyWebsocketConn) SetWriteDeadline(t time.Time) error {
  132. if c.conn == nil {
  133. return os.ErrInvalid
  134. }
  135. return c.conn.SetWriteDeadline(t)
  136. }