client.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. package v2raywebsocket
  2. import (
  3. "context"
  4. "net"
  5. "net/http"
  6. "net/url"
  7. "time"
  8. "github.com/sagernet/sing-box/adapter"
  9. "github.com/sagernet/sing-box/common/tls"
  10. "github.com/sagernet/sing-box/option"
  11. E "github.com/sagernet/sing/common/exceptions"
  12. M "github.com/sagernet/sing/common/metadata"
  13. N "github.com/sagernet/sing/common/network"
  14. sHTTP "github.com/sagernet/sing/protocol/http"
  15. "github.com/sagernet/websocket"
  16. )
  17. var _ adapter.V2RayClientTransport = (*Client)(nil)
  18. type Client struct {
  19. dialer *websocket.Dialer
  20. requestURL url.URL
  21. requestURLString string
  22. headers http.Header
  23. maxEarlyData uint32
  24. earlyDataHeaderName string
  25. }
  26. func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig tls.Config) adapter.V2RayClientTransport {
  27. wsDialer := &websocket.Dialer{
  28. ReadBufferSize: 4 * 1024,
  29. WriteBufferSize: 4 * 1024,
  30. HandshakeTimeout: time.Second * 8,
  31. }
  32. if tlsConfig != nil {
  33. if len(tlsConfig.NextProtos()) == 0 {
  34. tlsConfig.SetNextProtos([]string{"http/1.1"})
  35. }
  36. wsDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  37. conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
  38. if err != nil {
  39. return nil, err
  40. }
  41. tlsConn, err := tls.ClientHandshake(ctx, conn, tlsConfig)
  42. if err != nil {
  43. return nil, err
  44. }
  45. return &deadConn{tlsConn}, nil
  46. }
  47. } else {
  48. wsDialer.NetDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  49. conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
  50. if err != nil {
  51. return nil, err
  52. }
  53. return &deadConn{conn}, nil
  54. }
  55. }
  56. var requestURL url.URL
  57. if tlsConfig == nil {
  58. requestURL.Scheme = "ws"
  59. } else {
  60. requestURL.Scheme = "wss"
  61. }
  62. requestURL.Host = serverAddr.String()
  63. requestURL.Path = options.Path
  64. err := sHTTP.URLSetPath(&requestURL, options.Path)
  65. if err != nil {
  66. return nil
  67. }
  68. headers := make(http.Header)
  69. for key, value := range options.Headers {
  70. headers[key] = value
  71. }
  72. return &Client{
  73. wsDialer,
  74. requestURL,
  75. requestURL.String(),
  76. headers,
  77. options.MaxEarlyData,
  78. options.EarlyDataHeaderName,
  79. }
  80. }
  81. func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
  82. if c.maxEarlyData <= 0 {
  83. conn, response, err := c.dialer.DialContext(ctx, c.requestURLString, c.headers)
  84. if err == nil {
  85. return &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}, nil
  86. }
  87. return nil, wrapDialError(response, err)
  88. } else {
  89. return &EarlyWebsocketConn{Client: c, ctx: ctx, create: make(chan struct{})}, nil
  90. }
  91. }
  92. func wrapDialError(response *http.Response, err error) error {
  93. if response == nil {
  94. return err
  95. }
  96. return E.Extend(err, "HTTP ", response.StatusCode, " ", response.Status)
  97. }