123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123 |
- package v2raywebsocket
- import (
- "context"
- "net"
- "net/http"
- "net/url"
- "strings"
- "time"
- "github.com/sagernet/sing-box/adapter"
- "github.com/sagernet/sing-box/common/tls"
- C "github.com/sagernet/sing-box/constant"
- "github.com/sagernet/sing-box/option"
- "github.com/sagernet/sing/common/buf"
- "github.com/sagernet/sing/common/bufio"
- "github.com/sagernet/sing/common/bufio/deadline"
- E "github.com/sagernet/sing/common/exceptions"
- M "github.com/sagernet/sing/common/metadata"
- N "github.com/sagernet/sing/common/network"
- sHTTP "github.com/sagernet/sing/protocol/http"
- "github.com/sagernet/ws"
- )
// Compile-time assertion that *Client satisfies adapter.V2RayClientTransport.
var _ adapter.V2RayClientTransport = (*Client)(nil)

// Client is the client side of the V2Ray WebSocket transport. It stores the
// dial target and the handshake parameters that NewClient prepares once and
// that every dial reuses.
//
// NOTE: NewClient fills this struct with an unkeyed composite literal, so the
// field order below must stay in sync with that literal.
type Client struct {
	// dialer opens the TCP connection to serverAddr. NewClient wraps it in a
	// TLS dialer when a TLS configuration is supplied.
	dialer N.Dialer
	// serverAddr is the address passed to dialer for every connection.
	serverAddr M.Socksaddr
	// requestURL is the ws:// or wss:// URL used for the upgrade request. Its
	// host is the server address unless a "Host" header overrides it.
	requestURL url.URL
	// headers are the extra headers sent with the upgrade request ("Host" is
	// removed and a default "User-Agent" is filled in by NewClient).
	headers http.Header
	// maxEarlyData selects the dial strategy: zero performs the handshake
	// immediately in DialContext, any other value returns an
	// EarlyWebsocketConn instead.
	maxEarlyData uint32
	// earlyDataHeaderName is copied from the options; it is not read by the
	// code in this part of the file (presumably used by EarlyWebsocketConn).
	earlyDataHeaderName string
}
- func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
- if tlsConfig != nil {
- if len(tlsConfig.NextProtos()) == 0 {
- tlsConfig.SetNextProtos([]string{"http/1.1"})
- }
- dialer = tls.NewDialer(dialer, tlsConfig)
- }
- var requestURL url.URL
- if tlsConfig == nil {
- requestURL.Scheme = "ws"
- } else {
- requestURL.Scheme = "wss"
- }
- requestURL.Host = serverAddr.String()
- requestURL.Path = options.Path
- err := sHTTP.URLSetPath(&requestURL, options.Path)
- if err != nil {
- return nil, E.Cause(err, "parse path")
- }
- if !strings.HasPrefix(requestURL.Path, "/") {
- requestURL.Path = "/" + requestURL.Path
- }
- headers := options.Headers.Build()
- if host := headers.Get("Host"); host != "" {
- headers.Del("Host")
- requestURL.Host = host
- }
- if headers.Get("User-Agent") == "" {
- headers.Set("User-Agent", "Go-http-client/1.1")
- }
- return &Client{
- dialer,
- serverAddr,
- requestURL,
- headers,
- options.MaxEarlyData,
- options.EarlyDataHeaderName,
- }, nil
- }
- func (c *Client) dialContext(ctx context.Context, requestURL *url.URL, headers http.Header) (*WebsocketConn, error) {
- conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr)
- if err != nil {
- return nil, err
- }
- var deadlineConn net.Conn
- if deadline.NeedAdditionalReadDeadline(conn) {
- deadlineConn = deadline.NewConn(conn)
- } else {
- deadlineConn = conn
- }
- deadlineConn.SetDeadline(time.Now().Add(C.TCPTimeout))
- var protocols []string
- if protocolHeader := headers.Get("Sec-WebSocket-Protocol"); protocolHeader != "" {
- protocols = []string{protocolHeader}
- headers.Del("Sec-WebSocket-Protocol")
- }
- reader, _, err := ws.Dialer{Header: ws.HandshakeHeaderHTTP(headers), Protocols: protocols}.Upgrade(deadlineConn, requestURL)
- deadlineConn.SetDeadline(time.Time{})
- if err != nil {
- return nil, err
- }
- if reader != nil {
- buffer := buf.NewSize(reader.Buffered())
- _, err = buffer.ReadFullFrom(reader, buffer.Len())
- if err != nil {
- return nil, err
- }
- conn = bufio.NewCachedConn(conn, buffer)
- }
- return NewConn(conn, nil, ws.StateClientSide), nil
- }
- func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
- if c.maxEarlyData <= 0 {
- conn, err := c.dialContext(ctx, &c.requestURL, c.headers)
- if err != nil {
- return nil, err
- }
- return conn, nil
- } else {
- return &EarlyWebsocketConn{Client: c, ctx: ctx, create: make(chan struct{})}, nil
- }
- }
// Close is a no-op and always returns nil.
func (c *Client) Close() error {
	return nil
}
|