Browse Source

Improve connection timeout

世界 3 years ago
parent
commit
c7fabe40ed

+ 4 - 1
box.go

@@ -143,9 +143,12 @@ func (s *Box) Start() error {
 	if err != nil {
 		return err
 	}
-	for _, in := range s.inbounds {
+	for i, in := range s.inbounds {
 		err = in.Start()
 		if err != nil {
+			for g := 0; g < i; g++ {
+				s.inbounds[g].Close()
+			}
 			return err
 		}
 	}

+ 2 - 0
common/dialer/default.go

@@ -59,6 +59,8 @@ func NewDefault(router adapter.Router, options option.DialerOptions) *DefaultDia
 	}
 	if options.ConnectTimeout != 0 {
 		dialer.Timeout = time.Duration(options.ConnectTimeout)
+	} else {
+		dialer.Timeout = C.DefaultTCPTimeout
 	}
 	return &DefaultDialer{tfo.Dialer{Dialer: dialer, DisableTFO: !options.TCPFastOpen}, listener}
 }

+ 0 - 4
common/dialer/dialer.go

@@ -6,7 +6,6 @@ import (
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-dns"
-	"github.com/sagernet/sing/common"
 	N "github.com/sagernet/sing/common/network"
 )
 
@@ -24,8 +23,5 @@ func NewOutbound(router adapter.Router, options option.OutboundDialerOptions) N.
 	if domainStrategy != dns.DomainStrategyAsIS || options.Detour == "" {
 		dialer = NewResolveDialer(router, dialer, domainStrategy, time.Duration(options.FallbackDelay))
 	}
-	if options.OverrideOptions.IsValid() {
-		dialer = NewOverride(dialer, common.PtrValueOrDefault(options.OverrideOptions))
-	}
 	return dialer
 }

+ 0 - 69
common/dialer/override.go

@@ -1,69 +0,0 @@
-package dialer
-
-import (
-	"context"
-	"crypto/tls"
-	"net"
-
-	C "github.com/sagernet/sing-box/constant"
-	"github.com/sagernet/sing-box/option"
-	M "github.com/sagernet/sing/common/metadata"
-	N "github.com/sagernet/sing/common/network"
-	"github.com/sagernet/sing/common/uot"
-)
-
-var _ N.Dialer = (*OverrideDialer)(nil)
-
-type OverrideDialer struct {
-	upstream   N.Dialer
-	tlsEnabled bool
-	tlsConfig  tls.Config
-	uotEnabled bool
-}
-
-func NewOverride(upstream N.Dialer, options option.OverrideStreamOptions) N.Dialer {
-	return &OverrideDialer{
-		upstream,
-		options.TLS,
-		tls.Config{
-			ServerName:         options.TLSServerName,
-			InsecureSkipVerify: options.TLSInsecure,
-		},
-		options.UDPOverTCP,
-	}
-}
-
-func (d *OverrideDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
-	switch network {
-	case C.NetworkTCP:
-		conn, err := d.upstream.DialContext(ctx, C.NetworkTCP, destination)
-		if err != nil {
-			return nil, err
-		}
-		return tls.Client(conn, &d.tlsConfig), nil
-	case C.NetworkUDP:
-		if d.uotEnabled {
-			tcpConn, err := d.upstream.DialContext(ctx, C.NetworkTCP, destination)
-			if err != nil {
-				return nil, err
-			}
-			return uot.NewClientConn(tcpConn), nil
-		}
-	}
-	return d.upstream.DialContext(ctx, network, destination)
-}
-
-func (d *OverrideDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
-	if d.uotEnabled {
-		tcpConn, err := d.upstream.DialContext(ctx, C.NetworkTCP, destination)
-		if err != nil {
-			return nil, err
-		}
-		return uot.NewClientConn(tcpConn), nil
-	}
-	return d.upstream.ListenPacket(ctx, destination)
-}
-
-func (d *OverrideDialer) Upstream() any {
-	return d.upstream
-}

+ 86 - 0
common/dialer/tls.go

@@ -0,0 +1,86 @@
+package dialer
+
+import (
+	"context"
+	"crypto/tls"
+	"crypto/x509"
+	"net"
+	"net/netip"
+	"os"
+
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/option"
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+)
+
+type TLSDialer struct {
+	dialer N.Dialer
+	config *tls.Config
+}
+
+func NewTLS(dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
+	if !options.Enabled {
+		return dialer, nil
+	}
+
+	var serverName string
+	if options.ServerName != "" {
+		serverName = options.ServerName
+	} else if serverAddress != "" {
+		if _, err := netip.ParseAddr(serverName); err != nil {
+			serverName = serverAddress
+		}
+	}
+	if serverName == "" && options.Insecure {
+		return nil, E.New("missing server_name or insecure=true")
+	}
+
+	var tlsConfig tls.Config
+	if options.DisableSNI {
+		tlsConfig.ServerName = "127.0.0.1"
+	} else {
+		tlsConfig.ServerName = serverName
+	}
+	if options.Insecure {
+		tlsConfig.InsecureSkipVerify = options.Insecure
+	} else if options.DisableSNI {
+		tlsConfig.InsecureSkipVerify = true
+		tlsConfig.VerifyConnection = func(state tls.ConnectionState) error {
+			verifyOptions := x509.VerifyOptions{
+				DNSName:       serverName,
+				Intermediates: x509.NewCertPool(),
+			}
+			for _, cert := range state.PeerCertificates[1:] {
+				verifyOptions.Intermediates.AddCert(cert)
+			}
+			_, err := state.PeerCertificates[0].Verify(verifyOptions)
+			return err
+		}
+	}
+
+	return &TLSDialer{
+		dialer: dialer,
+		config: &tlsConfig,
+	}, nil
+}
+
+func (d *TLSDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
+	if network != C.NetworkTCP {
+		return nil, os.ErrInvalid
+	}
+	conn, err := d.dialer.DialContext(ctx, network, destination)
+	if err != nil {
+		return nil, err
+	}
+	tlsConn := tls.Client(conn, d.config)
+	ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
+	defer cancel()
+	err = tlsConn.HandshakeContext(ctx)
+	return tlsConn, err
+}
+
+func (d *TLSDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
+	return nil, os.ErrInvalid
+}

+ 2 - 1
common/sniff/sniff.go

@@ -9,6 +9,7 @@ import (
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
+	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing/common/buf"
 	E "github.com/sagernet/sing/common/exceptions"
 )
@@ -19,7 +20,7 @@ type (
 )
 
 func PeekStream(ctx context.Context, conn net.Conn, buffer *buf.Buffer, sniffers ...StreamSniffer) (*adapter.InboundContext, error) {
-	err := conn.SetReadDeadline(time.Now().Add(300 * time.Millisecond))
+	err := conn.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
 	if err != nil {
 		return nil, err
 	}

+ 8 - 0
constant/timeout.go

@@ -0,0 +1,8 @@
+package constant
+
+import "time"
+
+const (
+	DefaultTCPTimeout  = 5 * time.Second
+	ReadPayloadTimeout = 300 * time.Millisecond
+)

+ 19 - 22
option/outbound.go

@@ -82,20 +82,8 @@ type DialerOptions struct {
 
 type OutboundDialerOptions struct {
 	DialerOptions
-	OverrideOptions *OverrideStreamOptions `json:"override,omitempty"`
-	DomainStrategy  DomainStrategy         `json:"domain_strategy,omitempty"`
-	FallbackDelay   Duration               `json:"fallback_delay,omitempty"`
-}
-
-type OverrideStreamOptions struct {
-	TLS           bool   `json:"tls,omitempty"`
-	TLSServerName string `json:"tls_servername,omitempty"`
-	TLSInsecure   bool   `json:"tls_insecure,omitempty"`
-	UDPOverTCP    bool   `json:"udp_over_tcp,omitempty"`
-}
-
-func (o *OverrideStreamOptions) IsValid() bool {
-	return o != nil && (o.TLS || o.UDPOverTCP)
+	DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"`
+	FallbackDelay  Duration       `json:"fallback_delay,omitempty"`
 }
 
 type ServerOptions struct {
@@ -125,8 +113,16 @@ type SocksOutboundOptions struct {
 type HTTPOutboundOptions struct {
 	OutboundDialerOptions
 	ServerOptions
-	Username string `json:"username,omitempty"`
-	Password string `json:"password,omitempty"`
+	Username   string              `json:"username,omitempty"`
+	Password   string              `json:"password,omitempty"`
+	TLSOptions *OutboundTLSOptions `json:"tls,omitempty"`
+}
+
+type OutboundTLSOptions struct {
+	Enabled    bool   `json:"enabled,omitempty"`
+	DisableSNI bool   `json:"disable_sni,omitempty"`
+	ServerName string `json:"server_name,omitempty"`
+	Insecure   bool   `json:"insecure,omitempty"`
 }
 
 type ShadowsocksOutboundOptions struct {
@@ -140,10 +136,11 @@ type ShadowsocksOutboundOptions struct {
 type VMessOutboundOptions struct {
 	OutboundDialerOptions
 	ServerOptions
-	UUID                string      `json:"uuid"`
-	Security            string      `json:"security"`
-	AlterId             int         `json:"alter_id,omitempty"`
-	GlobalPadding       bool        `json:"global_padding,omitempty"`
-	AuthenticatedLength bool        `json:"authenticated_length,omitempty"`
-	Network             NetworkList `json:"network,omitempty"`
+	UUID                string              `json:"uuid"`
+	Security            string              `json:"security"`
+	AlterId             int                 `json:"alter_id,omitempty"`
+	GlobalPadding       bool                `json:"global_padding,omitempty"`
+	AuthenticatedLength bool                `json:"authenticated_length,omitempty"`
+	Network             NetworkList         `json:"network,omitempty"`
+	TLSOptions          *OutboundTLSOptions `json:"tls,omitempty"`
 }

+ 1 - 1
outbound/builder.go

@@ -21,7 +21,7 @@ func New(router adapter.Router, logger log.ContextLogger, options option.Outboun
 	case C.TypeSocks:
 		return NewSocks(router, logger, options.Tag, options.SocksOptions)
 	case C.TypeHTTP:
-		return NewHTTP(router, logger, options.Tag, options.HTTPOptions), nil
+		return NewHTTP(router, logger, options.Tag, options.HTTPOptions)
 	case C.TypeShadowsocks:
 		return NewShadowsocks(router, logger, options.Tag, options.ShadowsocksOptions)
 	case C.TypeVMess:

+ 1 - 1
outbound/default.go

@@ -93,7 +93,7 @@ func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) erro
 	}
 	_payload := buf.StackNew()
 	payload := common.Dup(_payload)
-	err := conn.SetReadDeadline(time.Now().Add(300 * time.Millisecond))
+	err := conn.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
 	if err != nil {
 		return err
 	}

+ 8 - 3
outbound/http.go

@@ -10,6 +10,7 @@ import (
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing/common"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/protocol/http"
@@ -22,7 +23,11 @@ type HTTP struct {
 	client *http.Client
 }
 
-func NewHTTP(router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPOutboundOptions) *HTTP {
+func NewHTTP(router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPOutboundOptions) (*HTTP, error) {
+	detour, err := dialer.NewTLS(dialer.NewOutbound(router, options.OutboundDialerOptions), options.Server, common.PtrValueOrDefault(options.TLSOptions))
+	if err != nil {
+		return nil, err
+	}
 	return &HTTP{
 		myOutboundAdapter{
 			protocol: C.TypeHTTP,
@@ -30,8 +35,8 @@ func NewHTTP(router adapter.Router, logger log.ContextLogger, tag string, option
 			tag:      tag,
 			network:  []string{C.NetworkTCP},
 		},
-		http.NewClient(dialer.NewOutbound(router, options.OutboundDialerOptions), options.ServerOptions.Build(), options.Username, options.Password),
-	}
+		http.NewClient(detour, options.ServerOptions.Build(), options.Username, options.Password),
+	}, nil
 }
 
 func (h *HTTP) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {

+ 6 - 1
outbound/vmess.go

@@ -10,6 +10,7 @@ import (
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-vmess"
+	"github.com/sagernet/sing/common"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 )
@@ -35,6 +36,10 @@ func NewVMess(router adapter.Router, logger log.ContextLogger, tag string, optio
 	if err != nil {
 		return nil, err
 	}
+	detour, err := dialer.NewTLS(dialer.NewOutbound(router, options.OutboundDialerOptions), options.Server, common.PtrValueOrDefault(options.TLSOptions))
+	if err != nil {
+		return nil, err
+	}
 	return &VMess{
 		myOutboundAdapter{
 			protocol: C.TypeDirect,
@@ -42,7 +47,7 @@ func NewVMess(router adapter.Router, logger log.ContextLogger, tag string, optio
 			tag:      tag,
 			network:  options.Network.Build(),
 		},
-		dialer.NewOutbound(router, options.OutboundDialerOptions),
+		detour,
 		client,
 		options.ServerOptions.Build(),
 	}, nil

+ 13 - 6
test/box_test.go

@@ -35,13 +35,20 @@ func mkPort(t *testing.T) uint16 {
 }
 
 func startInstance(t *testing.T, options option.Options) {
-	instance, err := box.New(context.Background(), options)
+	var err error
+	for retry := 0; retry < 3; retry++ {
+		instance, err := box.New(context.Background(), options)
+		require.NoError(t, err)
+		err = instance.Start()
+		if err != nil {
+			time.Sleep(5 * time.Millisecond)
+			continue
+		}
+		t.Cleanup(func() {
+			instance.Close()
+		})
+	}
 	require.NoError(t, err)
-	require.NoError(t, instance.Start())
-	t.Cleanup(func() {
-		instance.Close()
-	})
-	time.Sleep(time.Second)
 }
 
 func testSuit(t *testing.T, clientPort uint16, testPort uint16) {

+ 2 - 2
test/clash_test.go

@@ -484,7 +484,7 @@ func listen(network, address string) (net.Listener, error) {
 		}
 
 		lastErr = err
-		time.Sleep(time.Millisecond * 200)
+		time.Sleep(5 * time.Millisecond)
 	}
 	return nil, lastErr
 }
@@ -500,7 +500,7 @@ func listenPacket(network, address string) (net.PacketConn, error) {
 		}
 
 		lastErr = err
-		time.Sleep(time.Millisecond * 200)
+		time.Sleep(5 * time.Millisecond)
 	}
 	return nil, lastErr
 }