dialer.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. package websocket
  2. import (
  3. "bytes"
  4. "context"
  5. _ "embed"
  6. "encoding/base64"
  7. "io"
  8. gonet "net"
  9. "net/http"
  10. "time"
  11. "github.com/gorilla/websocket"
  12. "github.com/xtls/xray-core/common"
  13. "github.com/xtls/xray-core/common/errors"
  14. "github.com/xtls/xray-core/common/net"
  15. "github.com/xtls/xray-core/common/platform"
  16. "github.com/xtls/xray-core/common/uuid"
  17. "github.com/xtls/xray-core/transport/internet"
  18. "github.com/xtls/xray-core/transport/internet/stat"
  19. "github.com/xtls/xray-core/transport/internet/tls"
  20. )
  21. //go:embed dialer.html
  22. var webpage []byte
  23. var conns chan *websocket.Conn
  24. func init() {
  25. addr := platform.NewEnvFlag(platform.BrowserDialerAddress).GetValue(func() string { return "" })
  26. if addr != "" {
  27. token := uuid.New()
  28. csrfToken := token.String()
  29. webpage = bytes.ReplaceAll(webpage, []byte("csrfToken"), []byte(csrfToken))
  30. conns = make(chan *websocket.Conn, 256)
  31. go http.ListenAndServe(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  32. if r.URL.Path == "/websocket" {
  33. if r.URL.Query().Get("token") == csrfToken {
  34. if conn, err := upgrader.Upgrade(w, r, nil); err == nil {
  35. conns <- conn
  36. } else {
  37. errors.LogError(context.Background(), "Browser dialer http upgrade unexpected error")
  38. }
  39. }
  40. } else {
  41. w.Write(webpage)
  42. }
  43. }))
  44. }
  45. }
  46. // Dial dials a WebSocket connection to the given destination.
  47. func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
  48. errors.LogInfo(ctx, "creating connection to ", dest)
  49. var conn net.Conn
  50. if streamSettings.ProtocolSettings.(*Config).Ed > 0 {
  51. ctx, cancel := context.WithCancel(ctx)
  52. conn = &delayDialConn{
  53. dialed: make(chan bool, 1),
  54. cancel: cancel,
  55. ctx: ctx,
  56. dest: dest,
  57. streamSettings: streamSettings,
  58. }
  59. } else {
  60. var err error
  61. if conn, err = dialWebSocket(ctx, dest, streamSettings, nil); err != nil {
  62. return nil, errors.New("failed to dial WebSocket").Base(err)
  63. }
  64. }
  65. return stat.Connection(conn), nil
  66. }
  67. func init() {
  68. common.Must(internet.RegisterTransportDialer(protocolName, Dial))
  69. }
  70. func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig, ed []byte) (net.Conn, error) {
  71. wsSettings := streamSettings.ProtocolSettings.(*Config)
  72. dialer := &websocket.Dialer{
  73. NetDial: func(network, addr string) (net.Conn, error) {
  74. return internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
  75. },
  76. ReadBufferSize: 4 * 1024,
  77. WriteBufferSize: 4 * 1024,
  78. HandshakeTimeout: time.Second * 8,
  79. }
  80. protocol := "ws"
  81. if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
  82. protocol = "wss"
  83. tlsConfig := config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
  84. dialer.TLSClientConfig = tlsConfig
  85. if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil {
  86. dialer.NetDialTLSContext = func(_ context.Context, _, addr string) (gonet.Conn, error) {
  87. // Like the NetDial in the dialer
  88. pconn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
  89. if err != nil {
  90. errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
  91. return nil, err
  92. }
  93. // TLS and apply the handshake
  94. cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
  95. if err := cn.WebsocketHandshakeContext(ctx); err != nil {
  96. errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
  97. return nil, err
  98. }
  99. if !tlsConfig.InsecureSkipVerify {
  100. if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
  101. errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
  102. return nil, err
  103. }
  104. }
  105. return cn, nil
  106. }
  107. }
  108. }
  109. host := dest.NetAddr()
  110. if (protocol == "ws" && dest.Port == 80) || (protocol == "wss" && dest.Port == 443) {
  111. host = dest.Address.String()
  112. }
  113. uri := protocol + "://" + host + wsSettings.GetNormalizedPath()
  114. if conns != nil {
  115. data := []byte(uri)
  116. if ed != nil {
  117. data = append(data, " "+base64.RawURLEncoding.EncodeToString(ed)...)
  118. }
  119. var conn *websocket.Conn
  120. for {
  121. conn = <-conns
  122. if conn.WriteMessage(websocket.TextMessage, data) != nil {
  123. conn.Close()
  124. } else {
  125. break
  126. }
  127. }
  128. if _, p, err := conn.ReadMessage(); err != nil {
  129. conn.Close()
  130. return nil, err
  131. } else if s := string(p); s != "ok" {
  132. conn.Close()
  133. return nil, errors.New(s)
  134. }
  135. return newConnection(conn, conn.RemoteAddr(), nil), nil
  136. }
  137. header := wsSettings.GetRequestHeader()
  138. if ed != nil {
  139. // RawURLEncoding is support by both V2Ray/V2Fly and XRay.
  140. header.Set("Sec-WebSocket-Protocol", base64.RawURLEncoding.EncodeToString(ed))
  141. }
  142. conn, resp, err := dialer.DialContext(ctx, uri, header)
  143. if err != nil {
  144. var reason string
  145. if resp != nil {
  146. reason = resp.Status
  147. }
  148. return nil, errors.New("failed to dial to (", uri, "): ", reason).Base(err)
  149. }
  150. return newConnection(conn, conn.RemoteAddr(), nil), nil
  151. }
  152. type delayDialConn struct {
  153. net.Conn
  154. closed bool
  155. dialed chan bool
  156. cancel context.CancelFunc
  157. ctx context.Context
  158. dest net.Destination
  159. streamSettings *internet.MemoryStreamConfig
  160. }
  161. func (d *delayDialConn) Write(b []byte) (int, error) {
  162. if d.closed {
  163. return 0, io.ErrClosedPipe
  164. }
  165. if d.Conn == nil {
  166. ed := b
  167. if len(ed) > int(d.streamSettings.ProtocolSettings.(*Config).Ed) {
  168. ed = nil
  169. }
  170. var err error
  171. if d.Conn, err = dialWebSocket(d.ctx, d.dest, d.streamSettings, ed); err != nil {
  172. d.Close()
  173. return 0, errors.New("failed to dial WebSocket").Base(err)
  174. }
  175. d.dialed <- true
  176. if ed != nil {
  177. return len(ed), nil
  178. }
  179. }
  180. return d.Conn.Write(b)
  181. }
  182. func (d *delayDialConn) Read(b []byte) (int, error) {
  183. if d.closed {
  184. return 0, io.ErrClosedPipe
  185. }
  186. if d.Conn == nil {
  187. select {
  188. case <-d.ctx.Done():
  189. return 0, io.ErrUnexpectedEOF
  190. case <-d.dialed:
  191. }
  192. }
  193. return d.Conn.Read(b)
  194. }
  195. func (d *delayDialConn) Close() error {
  196. d.closed = true
  197. d.cancel()
  198. if d.Conn == nil {
  199. return nil
  200. }
  201. return d.Conn.Close()
  202. }