// conn.go
  1. package v2raywebsocket
  import (
	"context"
	"encoding/base64"
	"io"
	"net"
	"net/http"
	"os"
	"sync"
	"time"

	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	E "github.com/sagernet/sing/common/exceptions"
	"github.com/sagernet/websocket"
)
  16. type WebsocketConn struct {
  17. *websocket.Conn
  18. *Writer
  19. remoteAddr net.Addr
  20. reader io.Reader
  21. }
  22. func NewServerConn(wsConn *websocket.Conn, remoteAddr net.Addr) *WebsocketConn {
  23. return &WebsocketConn{
  24. Conn: wsConn,
  25. remoteAddr: remoteAddr,
  26. Writer: NewWriter(wsConn, true),
  27. }
  28. }
  29. func (c *WebsocketConn) Close() error {
  30. err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(C.TCPTimeout))
  31. if err != nil {
  32. return c.Conn.Close()
  33. }
  34. return nil
  35. }
  36. func (c *WebsocketConn) Read(b []byte) (n int, err error) {
  37. for {
  38. if c.reader == nil {
  39. _, c.reader, err = c.NextReader()
  40. if err != nil {
  41. err = wrapError(err)
  42. return
  43. }
  44. }
  45. n, err = c.reader.Read(b)
  46. if E.IsMulti(err, io.EOF) {
  47. c.reader = nil
  48. continue
  49. }
  50. err = wrapError(err)
  51. return
  52. }
  53. }
  54. func (c *WebsocketConn) RemoteAddr() net.Addr {
  55. if c.remoteAddr != nil {
  56. return c.remoteAddr
  57. }
  58. return c.Conn.RemoteAddr()
  59. }
  60. func (c *WebsocketConn) SetDeadline(t time.Time) error {
  61. return os.ErrInvalid
  62. }
  63. func (c *WebsocketConn) SetReadDeadline(t time.Time) error {
  64. return os.ErrInvalid
  65. }
  66. func (c *WebsocketConn) SetWriteDeadline(t time.Time) error {
  67. return os.ErrInvalid
  68. }
  69. func (c *WebsocketConn) NeedAdditionalReadDeadline() bool {
  70. return true
  71. }
  72. func (c *WebsocketConn) Upstream() any {
  73. return c.Conn.NetConn()
  74. }
  75. func (c *WebsocketConn) UpstreamWriter() any {
  76. return c.Writer
  77. }
  78. type EarlyWebsocketConn struct {
  79. *Client
  80. ctx context.Context
  81. conn *WebsocketConn
  82. create chan struct{}
  83. }
  84. func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
  85. if c.conn == nil {
  86. <-c.create
  87. }
  88. return c.conn.Read(b)
  89. }
  90. func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
  91. if c.conn != nil {
  92. return c.conn.Write(b)
  93. }
  94. var (
  95. earlyData []byte
  96. lateData []byte
  97. conn *websocket.Conn
  98. response *http.Response
  99. )
  100. if len(b) > int(c.maxEarlyData) {
  101. earlyData = b[:c.maxEarlyData]
  102. lateData = b[c.maxEarlyData:]
  103. } else {
  104. earlyData = b
  105. }
  106. if len(earlyData) > 0 {
  107. earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
  108. if c.earlyDataHeaderName == "" {
  109. conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers)
  110. } else {
  111. headers := c.headers.Clone()
  112. headers.Set(c.earlyDataHeaderName, earlyDataString)
  113. conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers)
  114. }
  115. } else {
  116. conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers)
  117. }
  118. if err != nil {
  119. return 0, wrapDialError(response, err)
  120. }
  121. c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
  122. close(c.create)
  123. if len(lateData) > 0 {
  124. _, err = c.conn.Write(lateData)
  125. }
  126. if err != nil {
  127. return
  128. }
  129. return len(b), nil
  130. }
  131. func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
  132. if c.conn != nil {
  133. return c.conn.WriteBuffer(buffer)
  134. }
  135. var (
  136. earlyData []byte
  137. lateData []byte
  138. conn *websocket.Conn
  139. response *http.Response
  140. err error
  141. )
  142. if buffer.Len() > int(c.maxEarlyData) {
  143. earlyData = buffer.Bytes()[:c.maxEarlyData]
  144. lateData = buffer.Bytes()[c.maxEarlyData:]
  145. } else {
  146. earlyData = buffer.Bytes()
  147. }
  148. if len(earlyData) > 0 {
  149. earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
  150. if c.earlyDataHeaderName == "" {
  151. conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers)
  152. } else {
  153. headers := c.headers.Clone()
  154. headers.Set(c.earlyDataHeaderName, earlyDataString)
  155. conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers)
  156. }
  157. } else {
  158. conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers)
  159. }
  160. if err != nil {
  161. return wrapDialError(response, err)
  162. }
  163. c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
  164. close(c.create)
  165. if len(lateData) > 0 {
  166. _, err = c.conn.Write(lateData)
  167. }
  168. return err
  169. }
  170. func (c *EarlyWebsocketConn) Close() error {
  171. if c.conn == nil {
  172. return nil
  173. }
  174. return c.conn.Close()
  175. }
  176. func (c *EarlyWebsocketConn) LocalAddr() net.Addr {
  177. if c.conn == nil {
  178. return nil
  179. }
  180. return c.conn.LocalAddr()
  181. }
  182. func (c *EarlyWebsocketConn) RemoteAddr() net.Addr {
  183. if c.conn == nil {
  184. return nil
  185. }
  186. return c.conn.RemoteAddr()
  187. }
  188. func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error {
  189. return os.ErrInvalid
  190. }
  191. func (c *EarlyWebsocketConn) SetReadDeadline(t time.Time) error {
  192. return os.ErrInvalid
  193. }
  194. func (c *EarlyWebsocketConn) SetWriteDeadline(t time.Time) error {
  195. return os.ErrInvalid
  196. }
  197. func (c *EarlyWebsocketConn) NeedAdditionalReadDeadline() bool {
  198. return true
  199. }
  200. func (c *EarlyWebsocketConn) Upstream() any {
  201. return common.PtrOrNil(c.conn)
  202. }
  203. func (c *EarlyWebsocketConn) LazyHeadroom() bool {
  204. return c.conn == nil
  205. }
  206. func wrapError(err error) error {
  207. if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
  208. return io.EOF
  209. }
  210. if websocket.IsCloseError(err, websocket.CloseAbnormalClosure) {
  211. return net.ErrClosed
  212. }
  213. return err
  214. }