conn.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. package v2raywebsocket
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "io"
  6. "net"
  7. "net/http"
  8. "os"
  9. "sync"
  10. "time"
  11. C "github.com/sagernet/sing-box/constant"
  12. "github.com/sagernet/sing/common"
  13. "github.com/sagernet/sing/common/buf"
  14. E "github.com/sagernet/sing/common/exceptions"
  15. "github.com/sagernet/websocket"
  16. )
  17. type WebsocketConn struct {
  18. *websocket.Conn
  19. *Writer
  20. remoteAddr net.Addr
  21. reader io.Reader
  22. }
  23. func NewServerConn(wsConn *websocket.Conn, remoteAddr net.Addr) *WebsocketConn {
  24. return &WebsocketConn{
  25. Conn: wsConn,
  26. remoteAddr: remoteAddr,
  27. Writer: NewWriter(wsConn, true),
  28. }
  29. }
  30. func (c *WebsocketConn) Close() error {
  31. err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(C.TCPTimeout))
  32. if err != nil {
  33. return c.Conn.Close()
  34. }
  35. return nil
  36. }
  37. func (c *WebsocketConn) Read(b []byte) (n int, err error) {
  38. for {
  39. if c.reader == nil {
  40. _, c.reader, err = c.NextReader()
  41. if err != nil {
  42. err = wrapError(err)
  43. return
  44. }
  45. }
  46. n, err = c.reader.Read(b)
  47. if E.IsMulti(err, io.EOF) {
  48. c.reader = nil
  49. continue
  50. }
  51. err = wrapError(err)
  52. return
  53. }
  54. }
  55. func (c *WebsocketConn) RemoteAddr() net.Addr {
  56. if c.remoteAddr != nil {
  57. return c.remoteAddr
  58. }
  59. return c.Conn.RemoteAddr()
  60. }
  61. func (c *WebsocketConn) SetDeadline(t time.Time) error {
  62. return os.ErrInvalid
  63. }
  64. func (c *WebsocketConn) SetReadDeadline(t time.Time) error {
  65. return os.ErrInvalid
  66. }
  67. func (c *WebsocketConn) SetWriteDeadline(t time.Time) error {
  68. return os.ErrInvalid
  69. }
  70. func (c *WebsocketConn) NeedAdditionalReadDeadline() bool {
  71. return true
  72. }
  73. func (c *WebsocketConn) Upstream() any {
  74. return c.Conn.NetConn()
  75. }
  76. func (c *WebsocketConn) UpstreamWriter() any {
  77. return c.Writer
  78. }
  79. type EarlyWebsocketConn struct {
  80. *Client
  81. ctx context.Context
  82. conn *WebsocketConn
  83. access sync.Mutex
  84. create chan struct{}
  85. err error
  86. }
  87. func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
  88. if c.conn == nil {
  89. <-c.create
  90. if c.err != nil {
  91. return 0, c.err
  92. }
  93. }
  94. return c.conn.Read(b)
  95. }
  96. func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
  97. var (
  98. earlyData []byte
  99. lateData []byte
  100. conn *websocket.Conn
  101. response *http.Response
  102. err error
  103. )
  104. if len(content) > int(c.maxEarlyData) {
  105. earlyData = content[:c.maxEarlyData]
  106. lateData = content[c.maxEarlyData:]
  107. } else {
  108. earlyData = content
  109. }
  110. if len(earlyData) > 0 {
  111. earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
  112. if c.earlyDataHeaderName == "" {
  113. requestURL := c.requestURL
  114. requestURL.Path += earlyDataString
  115. conn, response, err = c.dialer.DialContext(c.ctx, requestURL.String(), c.headers)
  116. } else {
  117. headers := c.headers.Clone()
  118. headers.Set(c.earlyDataHeaderName, earlyDataString)
  119. conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, headers)
  120. }
  121. } else {
  122. conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, c.headers)
  123. }
  124. if err != nil {
  125. return wrapDialError(response, err)
  126. }
  127. c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
  128. if len(lateData) > 0 {
  129. _, err = c.conn.Write(lateData)
  130. }
  131. return err
  132. }
  133. func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
  134. if c.conn != nil {
  135. return c.conn.Write(b)
  136. }
  137. c.access.Lock()
  138. defer c.access.Unlock()
  139. if c.conn != nil {
  140. return c.conn.Write(b)
  141. }
  142. err = c.writeRequest(b)
  143. c.err = err
  144. close(c.create)
  145. if err != nil {
  146. return
  147. }
  148. return len(b), nil
  149. }
  150. func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
  151. if c.conn != nil {
  152. return c.conn.WriteBuffer(buffer)
  153. }
  154. c.access.Lock()
  155. defer c.access.Unlock()
  156. if c.conn != nil {
  157. return c.conn.WriteBuffer(buffer)
  158. }
  159. err := c.writeRequest(buffer.Bytes())
  160. c.err = err
  161. close(c.create)
  162. return err
  163. }
  164. func (c *EarlyWebsocketConn) Close() error {
  165. if c.conn == nil {
  166. return nil
  167. }
  168. return c.conn.Close()
  169. }
  170. func (c *EarlyWebsocketConn) LocalAddr() net.Addr {
  171. if c.conn == nil {
  172. return nil
  173. }
  174. return c.conn.LocalAddr()
  175. }
  176. func (c *EarlyWebsocketConn) RemoteAddr() net.Addr {
  177. if c.conn == nil {
  178. return nil
  179. }
  180. return c.conn.RemoteAddr()
  181. }
  182. func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error {
  183. return os.ErrInvalid
  184. }
  185. func (c *EarlyWebsocketConn) SetReadDeadline(t time.Time) error {
  186. return os.ErrInvalid
  187. }
  188. func (c *EarlyWebsocketConn) SetWriteDeadline(t time.Time) error {
  189. return os.ErrInvalid
  190. }
  191. func (c *EarlyWebsocketConn) NeedAdditionalReadDeadline() bool {
  192. return true
  193. }
  194. func (c *EarlyWebsocketConn) Upstream() any {
  195. return common.PtrOrNil(c.conn)
  196. }
  197. func (c *EarlyWebsocketConn) LazyHeadroom() bool {
  198. return c.conn == nil
  199. }
  200. func wrapError(err error) error {
  201. if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
  202. return io.EOF
  203. }
  204. if websocket.IsCloseError(err, websocket.CloseAbnormalClosure) {
  205. return net.ErrClosed
  206. }
  207. return err
  208. }