client.go 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. package v2raywebsocket
  2. import (
  3. "context"
  4. "net"
  5. "net/http"
  6. "net/url"
  7. "strings"
  8. "time"
  9. "github.com/sagernet/sing-box/adapter"
  10. "github.com/sagernet/sing-box/common/tls"
  11. "github.com/sagernet/sing-box/option"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. M "github.com/sagernet/sing/common/metadata"
  14. N "github.com/sagernet/sing/common/network"
  15. "github.com/sagernet/websocket"
  16. )
  17. var _ adapter.V2RayClientTransport = (*Client)(nil)
  18. type Client struct {
  19. dialer *websocket.Dialer
  20. uri string
  21. headers http.Header
  22. maxEarlyData uint32
  23. earlyDataHeaderName string
  24. }
  25. func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig tls.Config) adapter.V2RayClientTransport {
  26. wsDialer := &websocket.Dialer{
  27. ReadBufferSize: 4 * 1024,
  28. WriteBufferSize: 4 * 1024,
  29. HandshakeTimeout: time.Second * 8,
  30. }
  31. if tlsConfig != nil {
  32. if len(tlsConfig.NextProtos()) == 0 {
  33. tlsConfig.SetNextProtos([]string{"http/1.1"})
  34. }
  35. wsDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  36. conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
  37. if err != nil {
  38. return nil, err
  39. }
  40. return tls.ClientHandshake(ctx, conn, tlsConfig)
  41. }
  42. } else {
  43. wsDialer.NetDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  44. return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
  45. }
  46. }
  47. var uri url.URL
  48. if tlsConfig == nil {
  49. uri.Scheme = "ws"
  50. } else {
  51. uri.Scheme = "wss"
  52. }
  53. uri.Host = serverAddr.String()
  54. uri.Path = options.Path
  55. if !strings.HasPrefix(uri.Path, "/") {
  56. uri.Path = "/" + uri.Path
  57. }
  58. if strings.HasSuffix(uri.Path, "?") {
  59. uri.ForceQuery = true
  60. uri.Path = strings.TrimSuffix(uri.Path, "?")
  61. }
  62. headers := make(http.Header)
  63. for key, value := range options.Headers {
  64. headers[key] = value
  65. }
  66. return &Client{
  67. wsDialer,
  68. uri.String(),
  69. headers,
  70. options.MaxEarlyData,
  71. options.EarlyDataHeaderName,
  72. }
  73. }
  74. func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
  75. if c.maxEarlyData <= 0 {
  76. conn, response, err := c.dialer.DialContext(ctx, c.uri, c.headers)
  77. if err == nil {
  78. return &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}, nil
  79. }
  80. return nil, wrapDialError(response, err)
  81. } else {
  82. return &EarlyWebsocketConn{Client: c, ctx: ctx, create: make(chan struct{})}, nil
  83. }
  84. }
  85. func wrapDialError(response *http.Response, err error) error {
  86. if response == nil {
  87. return err
  88. }
  89. return E.Extend(err, "HTTP ", response.StatusCode, " ", response.Status)
  90. }