Pārlūkot izejas kodu

Add dial parallel for outbound dialer

世界 3 gadi atpakaļ
vecāks
revīzija
3699a57847

+ 1 - 1
common/dialer/default.go

@@ -38,7 +38,7 @@ func NewDefault(options option.DialerOptions) *DefaultDialer {
 		listener.Control = control.Append(listener.Control, ProtectPath(options.ProtectPath))
 	}
 	if options.ConnectTimeout != 0 {
-		dialer.Timeout = time.Duration(options.ConnectTimeout) * time.Second
+		dialer.Timeout = time.Duration(options.ConnectTimeout)
 	}
 	return &DefaultDialer{tfo.Dialer{Dialer: dialer, DisableTFO: !options.TCPFastOpen}, listener}
 }

+ 7 - 1
common/dialer/dialer.go

@@ -1,6 +1,8 @@
 package dialer
 
 import (
+	"time"
+
 	"github.com/sagernet/sing/common"
 	N "github.com/sagernet/sing/common/network"
 
@@ -21,7 +23,11 @@ func NewOutbound(router adapter.Router, options option.OutboundDialerOptions) N.
 	dialer := New(router, options.DialerOptions)
 	domainStrategy := C.DomainStrategy(options.DomainStrategy)
 	if domainStrategy != C.DomainStrategyAsIS || options.Detour == "" && !C.CGO_ENABLED {
-		dialer = NewResolveDialer(router, dialer, domainStrategy)
+		fallbackDelay := time.Duration(options.FallbackDelay)
+		if fallbackDelay == 0 {
+			fallbackDelay = time.Millisecond * 300
+		}
+		dialer = NewResolveDialer(router, dialer, domainStrategy, fallbackDelay)
 	}
 	if options.OverrideOptions.IsValid() {
 		dialer = NewOverride(dialer, common.PtrValueOrDefault(options.OverrideOptions))

+ 91 - 0
common/dialer/parallel.go

@@ -0,0 +1,91 @@
+package dialer
+
+import (
+	"context"
+	"net"
+	"net/netip"
+	"time"
+
+	"github.com/sagernet/sing/common"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+
+	C "github.com/sagernet/sing-box/constant"
+)
+
+func DialParallel(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.DomainStrategy, fallbackDelay time.Duration) (net.Conn, error) {
+	// kanged form net.Dial
+
+	returned := make(chan struct{})
+	defer close(returned)
+
+	addresses4 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
+		return address.Is4() || address.Is4In6()
+	})
+	addresses6 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
+		return address.Is6() && !address.Is4In6()
+	})
+	if len(addresses4) == 0 || len(addresses6) == 0 {
+		return DialSerial(ctx, dialer, network, destination, destinationAddresses)
+	}
+	var primaries, fallbacks []netip.Addr
+	switch strategy {
+	case C.DomainStrategyPreferIPv6:
+		primaries = addresses6
+		fallbacks = addresses4
+	default:
+		primaries = addresses4
+		fallbacks = addresses6
+	}
+	type dialResult struct {
+		net.Conn
+		error
+		primary bool
+		done    bool
+	}
+	results := make(chan dialResult) // unbuffered
+	startRacer := func(ctx context.Context, primary bool) {
+		ras := primaries
+		if !primary {
+			ras = fallbacks
+		}
+		c, err := DialSerial(ctx, dialer, network, destination, ras)
+		select {
+		case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
+		case <-returned:
+			if c != nil {
+				c.Close()
+			}
+		}
+	}
+	var primary, fallback dialResult
+	primaryCtx, primaryCancel := context.WithCancel(ctx)
+	defer primaryCancel()
+	go startRacer(primaryCtx, true)
+	fallbackTimer := time.NewTimer(fallbackDelay)
+	defer fallbackTimer.Stop()
+	for {
+		select {
+		case <-fallbackTimer.C:
+			fallbackCtx, fallbackCancel := context.WithCancel(ctx)
+			defer fallbackCancel()
+			go startRacer(fallbackCtx, false)
+
+		case res := <-results:
+			if res.error == nil {
+				return res.Conn, nil
+			}
+			if res.primary {
+				primary = res
+			} else {
+				fallback = res
+			}
+			if primary.done && fallback.done {
+				return nil, primary.error
+			}
+			if res.primary && fallbackTimer.Stop() {
+				fallbackTimer.Reset(0)
+			}
+		}
+	}
+}

+ 13 - 6
common/dialer/resolve.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"net"
 	"net/netip"
+	"time"
 
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
@@ -13,16 +14,18 @@ import (
 )
 
 type ResolveDialer struct {
-	dialer   N.Dialer
-	router   adapter.Router
-	strategy C.DomainStrategy
+	dialer        N.Dialer
+	router        adapter.Router
+	strategy      C.DomainStrategy
+	fallbackDelay time.Duration
 }
 
-func NewResolveDialer(router adapter.Router, dialer N.Dialer, strategy C.DomainStrategy) *ResolveDialer {
+func NewResolveDialer(router adapter.Router, dialer N.Dialer, strategy C.DomainStrategy, fallbackDelay time.Duration) *ResolveDialer {
 	return &ResolveDialer{
 		dialer,
 		router,
 		strategy,
+		fallbackDelay,
 	}
 }
 
@@ -40,7 +43,7 @@ func (d *ResolveDialer) DialContext(ctx context.Context, network string, destina
 	if err != nil {
 		return nil, err
 	}
-	return DialSerial(ctx, d.dialer, network, destination, addresses)
+	return DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy, d.fallbackDelay)
 }
 
 func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
@@ -57,7 +60,11 @@ func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
 	if err != nil {
 		return nil, err
 	}
-	return ListenSerial(ctx, d.dialer, destination, addresses)
+	conn, err := ListenSerial(ctx, d.dialer, destination, addresses)
+	if err != nil {
+		return nil, err
+	}
+	return NewResolvePacketConn(d.router, d.strategy, conn), nil
 }
 
 func (d *ResolveDialer) Upstream() any {

+ 83 - 0
common/dialer/resolve_conn.go

@@ -0,0 +1,83 @@
+package dialer
+
+import (
+	"context"
+	"net"
+
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+
+	"github.com/sagernet/sing-box/adapter"
+	C "github.com/sagernet/sing-box/constant"
+)
+
+func NewResolvePacketConn(router adapter.Router, strategy C.DomainStrategy, conn net.PacketConn) N.NetPacketConn {
+	if udpConn, ok := conn.(*net.UDPConn); ok {
+		return &ResolveUDPConn{udpConn, router, strategy}
+	} else {
+		return &ResolvePacketConn{conn, router, strategy}
+	}
+}
+
+type ResolveUDPConn struct {
+	*net.UDPConn
+	router   adapter.Router
+	strategy C.DomainStrategy
+}
+
+func (w *ResolveUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
+	n, addr, err := w.ReadFromUDPAddrPort(buffer.FreeBytes())
+	if err != nil {
+		return M.Socksaddr{}, err
+	}
+	buffer.Truncate(n)
+	return M.SocksaddrFromNetIP(addr), nil
+}
+
+func (w *ResolveUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
+	defer buffer.Release()
+	if destination.Family().IsFqdn() {
+		addresses, err := w.router.Lookup(context.Background(), destination.Fqdn, w.strategy)
+		if err != nil {
+			return err
+		}
+		return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), M.SocksaddrFromAddrPort(addresses[0], destination.Port).UDPAddr()))
+	}
+	return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
+}
+
+func (w *ResolveUDPConn) Upstream() any {
+	return w.UDPConn
+}
+
+type ResolvePacketConn struct {
+	net.PacketConn
+	router   adapter.Router
+	strategy C.DomainStrategy
+}
+
+func (w *ResolvePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
+	_, addr, err := buffer.ReadPacketFrom(w)
+	if err != nil {
+		return M.Socksaddr{}, err
+	}
+	return M.SocksaddrFromNet(addr), err
+}
+
+func (w *ResolvePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
+	defer buffer.Release()
+	if destination.Family().IsFqdn() {
+		addresses, err := w.router.Lookup(context.Background(), destination.Fqdn, w.strategy)
+		if err != nil {
+			return err
+		}
+		return common.Error(w.WriteTo(buffer.Bytes(), M.SocksaddrFromAddrPort(addresses[0], destination.Port).UDPAddr()))
+	}
+	return common.Error(w.WriteTo(buffer.Bytes(), destination.UDPAddr()))
+}
+
+func (w *ResolvePacketConn) Upstream() any {
+	return w.PacketConn
+}

+ 1 - 0
common/dialer/serial.go

@@ -18,6 +18,7 @@ func DialSerial(ctx context.Context, dialer N.Dialer, network string, destinatio
 		conn, err = dialer.DialContext(ctx, network, M.SocksaddrFromAddrPort(address, destination.Port))
 		if err != nil {
 			connErrors = append(connErrors, err)
+			continue
 		}
 		return conn, nil
 	}

+ 6 - 2
dns/client.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"net"
 	"net/netip"
+	"strings"
 	"time"
 
 	"github.com/sagernet/sing/common"
@@ -71,11 +72,14 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
 	if !c.disableCache {
 		c.storeCache(question, response)
 	}
-	return message, err
+	return response, err
 }
 
 func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) {
-	dnsName, err := dnsmessage.NewName(domain)
+	if strings.HasPrefix(domain, ".") {
+		domain = domain[:len(domain)-1]
+	}
+	dnsName, err := dnsmessage.NewName(domain + ".")
 	if err != nil {
 		return nil, wrapError(err)
 	}

+ 9 - 2
dns/transport.go

@@ -22,8 +22,15 @@ func NewTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, addre
 	}
 	host := serverURL.Hostname()
 	port := serverURL.Port()
-	if port == "" {
-		port = "53"
+	switch serverURL.Scheme {
+	case "tls":
+		if port == "" {
+			port = "853"
+		}
+	default:
+		if port == "" {
+			port = "53"
+		}
 	}
 	destination := M.ParseSocksaddrHostPortStr(host, port)
 	switch serverURL.Scheme {

+ 2 - 4
dns/transport_tcp.go

@@ -77,10 +77,9 @@ func (t *TCPTransport) offer() (*dnsConnection, error) {
 func (t *TCPTransport) newConnection(conn *dnsConnection) {
 	defer close(conn.done)
 	defer conn.Close()
-	ctx, cancel := context.WithCancel(t.ctx)
-	err := task.Any(t.ctx, func() error {
+	err := task.Any(t.ctx, func(ctx context.Context) error {
 		return t.loopIn(conn)
-	}, func() error {
+	}, func(ctx context.Context) error {
 		select {
 		case <-ctx.Done():
 			return nil
@@ -88,7 +87,6 @@ func (t *TCPTransport) newConnection(conn *dnsConnection) {
 			return os.ErrClosed
 		}
 	})
-	cancel()
 	conn.err = err
 	if err != nil {
 		t.logger.Debug("connection closed: ", err)

+ 2 - 4
dns/transport_tls.go

@@ -85,10 +85,9 @@ func (t *TLSTransport) offer(ctx context.Context) (*dnsConnection, error) {
 func (t *TLSTransport) newConnection(conn *dnsConnection) {
 	defer close(conn.done)
 	defer conn.Close()
-	ctx, cancel := context.WithCancel(t.ctx)
-	err := task.Any(t.ctx, func() error {
+	err := task.Any(t.ctx, func(ctx context.Context) error {
 		return t.loopIn(conn)
-	}, func() error {
+	}, func(ctx context.Context) error {
 		select {
 		case <-ctx.Done():
 			return nil
@@ -96,7 +95,6 @@ func (t *TLSTransport) newConnection(conn *dnsConnection) {
 			return os.ErrClosed
 		}
 	})
-	cancel()
 	conn.err = err
 	if err != nil {
 		t.logger.Debug("connection closed: ", err)

+ 2 - 4
dns/transport_udp.go

@@ -73,10 +73,9 @@ func (t *UDPTransport) offer() (*dnsConnection, error) {
 func (t *UDPTransport) newConnection(conn *dnsConnection) {
 	defer close(conn.done)
 	defer conn.Close()
-	ctx, cancel := context.WithCancel(t.ctx)
-	err := task.Any(t.ctx, func() error {
+	err := task.Any(t.ctx, func(ctx context.Context) error {
 		return t.loopIn(conn)
-	}, func() error {
+	}, func(ctx context.Context) error {
 		select {
 		case <-ctx.Done():
 			return nil
@@ -84,7 +83,6 @@ func (t *UDPTransport) newConnection(conn *dnsConnection) {
 			return os.ErrClosed
 		}
 	})
-	cancel()
 	conn.err = err
 	if err != nil {
 		t.logger.Debug("connection closed: ", err)

+ 2 - 2
go.mod

@@ -7,13 +7,13 @@ require (
 	github.com/goccy/go-json v0.9.8
 	github.com/logrusorgru/aurora v2.0.3+incompatible
 	github.com/oschwald/maxminddb-golang v1.9.0
-	github.com/sagernet/sing v0.0.0-20220707133944-6a0987c52ae4
+	github.com/sagernet/sing v0.0.0-20220708041648-04e100e91a92
 	github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649
 	github.com/sirupsen/logrus v1.8.1
 	github.com/spf13/cobra v1.5.0
 	github.com/stretchr/testify v1.8.0
 	golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d
-	golang.org/x/net v0.0.0-20220630215102-69896b714898
+	golang.org/x/net v0.0.0-20220706163947-c90051bbdb60
 )
 
 require (

+ 4 - 4
go.sum

@@ -23,8 +23,8 @@ github.com/oschwald/maxminddb-golang v1.9.0/go.mod h1:TK+s/Z2oZq0rSl4PSeAEoP0bgm
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
-github.com/sagernet/sing v0.0.0-20220707133944-6a0987c52ae4 h1:nV/DyNi+O1VxNoChD5E9M6Y0VoFdVr0UEW9h9JnqxNs=
-github.com/sagernet/sing v0.0.0-20220707133944-6a0987c52ae4/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c=
+github.com/sagernet/sing v0.0.0-20220708041648-04e100e91a92 h1:c+Jg/o4UBZ+7CFdKWy8XhPN5X1rtulYdMqdgjx6PNUo=
+github.com/sagernet/sing v0.0.0-20220708041648-04e100e91a92/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c=
 github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649 h1:whNDUGOAX5GPZkSy4G3Gv9QyIgk5SXRyjkRuP7ohF8k=
 github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649/go.mod h1:MuyT+9fEPjvauAv0fSE0a6Q+l0Tv2ZrAafTkYfnxBFw=
 github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
@@ -41,8 +41,8 @@ github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PK
 github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
 golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY=
 golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/net v0.0.0-20220630215102-69896b714898 h1:K7wO6V1IrczY9QOQ2WkVpw4JQSwCd52UsxVEirZUfiw=
-golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
+golang.org/x/net v0.0.0-20220706163947-c90051bbdb60 h1:8NSylCMxLW4JvserAndSgFL7aPli6A68yf0bYFTcWCM=
+golang.org/x/net v0.0.0-20220706163947-c90051bbdb60/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
 golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b h1:2n253B2r0pYSmEV+UNCQoPfU/FiaizQEK5Gu4Bq4JE8=
 golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

+ 1 - 0
inbound/default.go

@@ -142,6 +142,7 @@ func (a *myInboundAdapter) loopTCPIn() {
 			a.logger.WithContext(ctx).Info("inbound connection from ", metadata.Source)
 			hErr := a.connHandler.NewConnection(ctx, conn, metadata)
 			if hErr != nil {
+				conn.Close()
 				a.NewError(ctx, E.Cause(hErr, "process connection from ", metadata.Source))
 			}
 		}()

+ 8 - 7
option/outbound.go

@@ -67,19 +67,20 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error {
 }
 
 type DialerOptions struct {
-	Detour         string `json:"detour,omitempty"`
-	BindInterface  string `json:"bind_interface,omitempty"`
-	ProtectPath    string `json:"protect_path,omitempty"`
-	RoutingMark    int    `json:"routing_mark,omitempty"`
-	ReuseAddr      bool   `json:"reuse_addr,omitempty"`
-	ConnectTimeout int    `json:"connect_timeout,omitempty"`
-	TCPFastOpen    bool   `json:"tcp_fast_open,omitempty"`
+	Detour         string   `json:"detour,omitempty"`
+	BindInterface  string   `json:"bind_interface,omitempty"`
+	ProtectPath    string   `json:"protect_path,omitempty"`
+	RoutingMark    int      `json:"routing_mark,omitempty"`
+	ReuseAddr      bool     `json:"reuse_addr,omitempty"`
+	ConnectTimeout Duration `json:"connect_timeout,omitempty"`
+	TCPFastOpen    bool     `json:"tcp_fast_open,omitempty"`
 }
 
 type OutboundDialerOptions struct {
 	DialerOptions
 	OverrideOptions *OverrideStreamOptions `json:"override,omitempty"`
 	DomainStrategy  DomainStrategy         `json:"domain_strategy,omitempty"`
+	FallbackDelay   Duration               `json:"fallback_delay,omitempty"`
 }
 
 type OverrideStreamOptions struct {

+ 21 - 0
option/types.go

@@ -3,6 +3,7 @@ package option
 import (
 	"net/netip"
 	"strings"
+	"time"
 
 	E "github.com/sagernet/sing/common/exceptions"
 
@@ -135,3 +136,23 @@ func (s *DomainStrategy) UnmarshalJSON(bytes []byte) error {
 	}
 	return nil
 }
+
+type Duration time.Duration
+
+func (d Duration) MarshalJSON() ([]byte, error) {
+	return json.Marshal((time.Duration)(d).String())
+}
+
+func (d *Duration) UnmarshalJSON(bytes []byte) error {
+	var value string
+	err := json.Unmarshal(bytes, &value)
+	if err != nil {
+		return err
+	}
+	duration, err := time.ParseDuration(value)
+	if err != nil {
+		return err
+	}
+	*d = Duration(duration)
+	return nil
+}

+ 0 - 2
route/router.go

@@ -450,7 +450,6 @@ func (r *Router) match(ctx context.Context, metadata adapter.InboundContext, def
 			r.logger.WithContext(ctx).Error("outbound not found: ", detour)
 		}
 	}
-	r.logger.WithContext(ctx).Info("no match")
 	return defaultOutbound
 }
 
@@ -470,7 +469,6 @@ func (r *Router) matchDNS(ctx context.Context) adapter.DNSTransport {
 			r.dnsLogger.WithContext(ctx).Error("transport not found: ", detour)
 		}
 	}
-	r.dnsLogger.WithContext(ctx).Info("no match")
 	return r.defaultTransport
 }