client.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. package v2raywebsocket
  2. import (
  3. "context"
  4. "net"
  5. "net/http"
  6. "net/url"
  7. "strings"
  8. "time"
  9. "github.com/sagernet/sing-box/adapter"
  10. "github.com/sagernet/sing-box/common/tls"
  11. "github.com/sagernet/sing-box/option"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. M "github.com/sagernet/sing/common/metadata"
  14. N "github.com/sagernet/sing/common/network"
  15. "github.com/sagernet/websocket"
  16. )
  17. var _ adapter.V2RayClientTransport = (*Client)(nil)
  18. type Client struct {
  19. dialer *websocket.Dialer
  20. uri string
  21. headers http.Header
  22. maxEarlyData uint32
  23. earlyDataHeaderName string
  24. }
  25. func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig tls.Config) adapter.V2RayClientTransport {
  26. wsDialer := &websocket.Dialer{
  27. ReadBufferSize: 4 * 1024,
  28. WriteBufferSize: 4 * 1024,
  29. HandshakeTimeout: time.Second * 8,
  30. }
  31. if tlsConfig != nil {
  32. wsDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  33. conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
  34. if err != nil {
  35. return nil, err
  36. }
  37. return tls.ClientHandshake(ctx, conn, tlsConfig)
  38. }
  39. } else {
  40. wsDialer.NetDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  41. return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
  42. }
  43. }
  44. var uri url.URL
  45. if tlsConfig == nil {
  46. uri.Scheme = "ws"
  47. } else {
  48. uri.Scheme = "wss"
  49. }
  50. uri.Host = serverAddr.String()
  51. uri.Path = options.Path
  52. if !strings.HasPrefix(uri.Path, "/") {
  53. uri.Path = "/" + uri.Path
  54. }
  55. headers := make(http.Header)
  56. for key, value := range options.Headers {
  57. headers.Set(key, value)
  58. }
  59. return &Client{
  60. wsDialer,
  61. uri.String(),
  62. headers,
  63. options.MaxEarlyData,
  64. options.EarlyDataHeaderName,
  65. }
  66. }
  67. func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
  68. if c.maxEarlyData <= 0 {
  69. conn, response, err := c.dialer.DialContext(ctx, c.uri, c.headers)
  70. if err == nil {
  71. return &WebsocketConn{Conn: conn, Writer: &Writer{conn, false}}, nil
  72. }
  73. return nil, wrapDialError(response, err)
  74. } else {
  75. return &EarlyWebsocketConn{Client: c, ctx: ctx, create: make(chan struct{})}, nil
  76. }
  77. }
  78. func wrapDialError(response *http.Response, err error) error {
  79. if response == nil {
  80. return err
  81. }
  82. return E.Extend(err, "HTTP ", response.StatusCode, " ", response.Status)
  83. }