Просмотр исходного кода

Core: Fix memory leaks with RequireFeatures() (#4095)

Fixes https://github.com/XTLS/Xray-core/issues/4054
Fixes https://github.com/XTLS/Xray-core/issues/3338
Fixes https://github.com/XTLS/Xray-core/issues/3221
yuhan6665 1 год назад
Родитель
Сommit
0e2304c403

+ 1 - 1
app/dispatcher/default.go

@@ -106,7 +106,7 @@ func init() {
 	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
 		d := new(DefaultDispatcher)
 		if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
-			core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
+			core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { // FakeDNSEngine is optional
 				d.fdns = fdns
 			})
 			return d.Init(config.(*Config), om, router, pm, sm, dc)

+ 7 - 3
app/dns/nameserver.go

@@ -35,7 +35,7 @@ type Client struct {
 var errExpectedIPNonMatch = errors.New("expectIPs not match")
 
 // NewServer creates a name server object according to the network destination url.
-func NewServer(dest net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) (Server, error) {
+func NewServer(dest net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy, fd dns.FakeDNSEngine) (Server, error) {
 	if address := dest.Address; address.Family().IsDomain() {
 		u, err := url.Parse(address.Domain())
 		if err != nil {
@@ -55,7 +55,7 @@ func NewServer(dest net.Destination, dispatcher routing.Dispatcher, queryStrateg
 		case strings.EqualFold(u.Scheme, "tcp+local"): // DNS-over-TCP Local mode
 			return NewTCPLocalNameServer(u, queryStrategy)
 		case strings.EqualFold(u.String(), "fakedns"):
-			return NewFakeDNSServer(), nil
+			return NewFakeDNSServer(fd), nil
 		}
 	}
 	if dest.Network == net.Network_Unknown {
@@ -78,9 +78,13 @@ func NewClient(
 ) (*Client, error) {
 	client := &Client{}
 
+	var fd dns.FakeDNSEngine
 	err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error {
+		core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { // FakeDNSEngine is optional
+			fd = fdns
+		})
 		// Create a new server for each client for now
-		server, err := NewServer(ns.Address.AsDestination(), dispatcher, ns.GetQueryStrategy())
+		server, err := NewServer(ns.Address.AsDestination(), dispatcher, ns.GetQueryStrategy(), fd)
 		if err != nil {
 			return errors.New("failed to create nameserver").Base(err).AtWarning()
 		}

+ 4 - 8
app/dns/nameserver_fakedns.go

@@ -5,7 +5,6 @@ import (
 
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
-	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/dns"
 )
 
@@ -13,8 +12,8 @@ type FakeDNSServer struct {
 	fakeDNSEngine dns.FakeDNSEngine
 }
 
-func NewFakeDNSServer() *FakeDNSServer {
-	return &FakeDNSServer{}
+func NewFakeDNSServer(fd dns.FakeDNSEngine) *FakeDNSServer {
+	return &FakeDNSServer{fakeDNSEngine: fd}
 }
 
 func (FakeDNSServer) Name() string {
@@ -23,12 +22,9 @@ func (FakeDNSServer) Name() string {
 
 func (f *FakeDNSServer) QueryIP(ctx context.Context, domain string, _ net.IP, opt dns.IPOption, _ bool) ([]net.IP, error) {
 	if f.fakeDNSEngine == nil {
-		if err := core.RequireFeatures(ctx, func(fd dns.FakeDNSEngine) {
-			f.fakeDNSEngine = fd
-		}); err != nil {
-			return nil, errors.New("Unable to locate a fake DNS Engine").Base(err).AtError()
-		}
+		return nil, errors.New("Unable to locate a fake DNS Engine").AtError()
 	}
+
 	var ips []net.Address
 	if fkr0, ok := f.fakeDNSEngine.(dns.FakeDNSEngineRev0); ok {
 		ips = fkr0.GetFakeIPForDomain3(domain, opt.IPv4Enable, opt.IPv6Enable)

+ 5 - 2
app/observatory/burst/burstobserver.go

@@ -12,6 +12,7 @@ import (
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/extension"
 	"github.com/xtls/xray-core/features/outbound"
+	"github.com/xtls/xray-core/features/routing"
 	"google.golang.org/protobuf/proto"
 )
 
@@ -88,13 +89,15 @@ func (o *Observer) Close() error {
 
 func New(ctx context.Context, config *Config) (*Observer, error) {
 	var outboundManager outbound.Manager
-	err := core.RequireFeatures(ctx, func(om outbound.Manager) {
+	var dispatcher routing.Dispatcher
+	err := core.RequireFeatures(ctx, func(om outbound.Manager, rd routing.Dispatcher) {
 		outboundManager = om
+		dispatcher = rd
 	})
 	if err != nil {
 		return nil, errors.New("Cannot get depended features").Base(err)
 	}
-	hp := NewHealthPing(ctx, config.PingConfig)
+	hp := NewHealthPing(ctx, dispatcher, config.PingConfig)
 	return &Observer{
 		config: config,
 		ctx:    ctx,

+ 5 - 1
app/observatory/burst/healthping.go

@@ -9,6 +9,7 @@ import (
 
 	"github.com/xtls/xray-core/common/dice"
 	"github.com/xtls/xray-core/common/errors"
+	"github.com/xtls/xray-core/features/routing"
 )
 
 // HealthPingSettings holds settings for health Checker
@@ -23,6 +24,7 @@ type HealthPingSettings struct {
 // HealthPing is the health checker for balancers
 type HealthPing struct {
 	ctx         context.Context
+	dispatcher  routing.Dispatcher
 	access      sync.Mutex
 	ticker      *time.Ticker
 	tickerClose chan struct{}
@@ -32,7 +34,7 @@ type HealthPing struct {
 }
 
 // NewHealthPing creates a new HealthPing with settings
-func NewHealthPing(ctx context.Context, config *HealthPingConfig) *HealthPing {
+func NewHealthPing(ctx context.Context, dispatcher routing.Dispatcher, config *HealthPingConfig) *HealthPing {
 	settings := &HealthPingSettings{}
 	if config != nil {
 		settings = &HealthPingSettings{
@@ -65,6 +67,7 @@ func NewHealthPing(ctx context.Context, config *HealthPingConfig) *HealthPing {
 	}
 	return &HealthPing{
 		ctx:      ctx,
+		dispatcher: dispatcher,
 		Settings: settings,
 		Results:  nil,
 	}
@@ -149,6 +152,7 @@ func (h *HealthPing) doCheck(tags []string, duration time.Duration, rounds int)
 		handler := tag
 		client := newPingClient(
 			h.ctx,
+			h.dispatcher,
 			h.Settings.Destination,
 			h.Settings.Timeout,
 			handler,

+ 5 - 4
app/observatory/burst/ping.go

@@ -6,6 +6,7 @@ import (
 	"time"
 
 	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/transport/internet/tagged"
 )
 
@@ -14,10 +15,10 @@ type pingClient struct {
 	httpClient  *http.Client
 }
 
-func newPingClient(ctx context.Context, destination string, timeout time.Duration, handler string) *pingClient {
+func newPingClient(ctx context.Context, dispatcher routing.Dispatcher, destination string, timeout time.Duration, handler string) *pingClient {
 	return &pingClient{
 		destination: destination,
-		httpClient:  newHTTPClient(ctx, handler, timeout),
+		httpClient:  newHTTPClient(ctx, dispatcher, handler, timeout),
 	}
 }
 
@@ -28,7 +29,7 @@ func newDirectPingClient(destination string, timeout time.Duration) *pingClient
 	}
 }
 
-func newHTTPClient(ctxv context.Context, handler string, timeout time.Duration) *http.Client {
+func newHTTPClient(ctxv context.Context, dispatcher routing.Dispatcher, handler string, timeout time.Duration) *http.Client {
 	tr := &http.Transport{
 		DisableKeepAlives: true,
 		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -36,7 +37,7 @@ func newHTTPClient(ctxv context.Context, handler string, timeout time.Duration)
 			if err != nil {
 				return nil, err
 			}
-			return tagged.Dialer(ctxv, dest, handler)
+			return tagged.Dialer(ctxv, dispatcher, dest, handler)
 		},
 	}
 	return &http.Client{

+ 7 - 2
app/observatory/observer.go

@@ -18,6 +18,7 @@ import (
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/extension"
 	"github.com/xtls/xray-core/features/outbound"
+	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/transport/internet/tagged"
 	"google.golang.org/protobuf/proto"
 )
@@ -32,6 +33,7 @@ type Observer struct {
 	finished *done.Instance
 
 	ohm outbound.Manager
+	dispatcher routing.Dispatcher
 }
 
 func (o *Observer) GetObservation(ctx context.Context) (proto.Message, error) {
@@ -131,7 +133,7 @@ func (o *Observer) probe(outbound string) ProbeResult {
 					return errors.New("cannot understand address").Base(err)
 				}
 				trackedCtx := session.TrackedConnectionError(o.ctx, errorCollectorForRequest)
-				conn, err := tagged.Dialer(trackedCtx, dest, outbound)
+				conn, err := tagged.Dialer(trackedCtx, o.dispatcher, dest, outbound)
 				if err != nil {
 					return errors.New("cannot dial remote address ", dest).Base(err)
 				}
@@ -215,8 +217,10 @@ func (o *Observer) findStatusLocationLockHolderOnly(outbound string) int {
 
 func New(ctx context.Context, config *Config) (*Observer, error) {
 	var outboundManager outbound.Manager
-	err := core.RequireFeatures(ctx, func(om outbound.Manager) {
+	var dispatcher routing.Dispatcher
+	err := core.RequireFeatures(ctx, func(om outbound.Manager, rd routing.Dispatcher) {
 		outboundManager = om
+		dispatcher = rd
 	})
 	if err != nil {
 		return nil, errors.New("Cannot get depended features").Base(err)
@@ -225,6 +229,7 @@ func New(ctx context.Context, config *Config) (*Observer, error) {
 		config: config,
 		ctx:    ctx,
 		ohm:    outboundManager,
+		dispatcher: dispatcher,
 	}, nil
 }
 

+ 5 - 7
app/router/balancing.go

@@ -5,7 +5,6 @@ import (
 	sync "sync"
 
 	"github.com/xtls/xray-core/app/observatory"
-	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/extension"
@@ -31,6 +30,11 @@ type RoundRobinStrategy struct {
 
 func (s *RoundRobinStrategy) InjectContext(ctx context.Context) {
 	s.ctx = ctx
+	if len(s.FallbackTag) > 0 {
+		core.RequireFeaturesAsync(s.ctx, func(observatory extension.Observatory) {
+			s.observatory = observatory
+		})
+	}
 }
 
 func (s *RoundRobinStrategy) GetPrincipleTarget(strings []string) []string {
@@ -38,12 +42,6 @@ func (s *RoundRobinStrategy) GetPrincipleTarget(strings []string) []string {
 }
 
 func (s *RoundRobinStrategy) PickOutbound(tags []string) string {
-	if len(s.FallbackTag) > 0 && s.observatory == nil {
-		common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error {
-			s.observatory = observatory
-			return nil
-		}))
-	}
 	if s.observatory != nil {
 		observeReport, err := s.observatory.GetObservation(s.ctx)
 		if err == nil {

+ 7 - 7
app/router/strategy_leastload.go

@@ -7,7 +7,6 @@ import (
 	"time"
 
 	"github.com/xtls/xray-core/app/observatory"
-	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/dice"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/core"
@@ -58,8 +57,11 @@ type node struct {
 	RTTDeviationCost time.Duration
 }
 
-func (l *LeastLoadStrategy) InjectContext(ctx context.Context) {
-	l.ctx = ctx
+func (s *LeastLoadStrategy) InjectContext(ctx context.Context) {
+	s.ctx = ctx
+	core.RequireFeaturesAsync(s.ctx, func(observatory extension.Observatory) {
+		s.observer = observatory
+	})
 }
 
 func (s *LeastLoadStrategy) PickOutbound(candidates []string) string {
@@ -136,10 +138,8 @@ func (s *LeastLoadStrategy) selectLeastLoad(nodes []*node) []*node {
 
 func (s *LeastLoadStrategy) getNodes(candidates []string, maxRTT time.Duration) []*node {
 	if s.observer == nil {
-		common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error {
-			s.observer = observatory
-			return nil
-		}))
+		errors.LogError(s.ctx, "observer is nil")
+		return make([]*node, 0)
 	}
 	observeResult, err := s.observer.GetObservation(s.ctx)
 	if err != nil {

+ 6 - 7
app/router/strategy_leastping.go

@@ -4,7 +4,6 @@ import (
 	"context"
 
 	"github.com/xtls/xray-core/app/observatory"
-	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/extension"
@@ -21,19 +20,19 @@ func (l *LeastPingStrategy) GetPrincipleTarget(strings []string) []string {
 
 func (l *LeastPingStrategy) InjectContext(ctx context.Context) {
 	l.ctx = ctx
+	core.RequireFeaturesAsync(l.ctx, func(observatory extension.Observatory) {
+		l.observatory = observatory
+	})
 }
 
 func (l *LeastPingStrategy) PickOutbound(strings []string) string {
 	if l.observatory == nil {
-		common.Must(core.RequireFeatures(l.ctx, func(observatory extension.Observatory) error {
-			l.observatory = observatory
-			return nil
-		}))
+		errors.LogError(l.ctx, "observer is nil")
+		return ""
 	}
-
 	observeReport, err := l.observatory.GetObservation(l.ctx)
 	if err != nil {
-		errors.LogInfoInner(l.ctx, err, "cannot get observe report")
+		errors.LogInfoInner(l.ctx, err, "cannot get observer report")
 		return ""
 	}
 	outboundsList := outboundList(strings)

+ 5 - 7
app/router/strategy_random.go

@@ -4,7 +4,6 @@ import (
 	"context"
 
 	"github.com/xtls/xray-core/app/observatory"
-	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/dice"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/extension"
@@ -20,6 +19,11 @@ type RandomStrategy struct {
 
 func (s *RandomStrategy) InjectContext(ctx context.Context) {
 	s.ctx = ctx
+	if len(s.FallbackTag) > 0 {
+		core.RequireFeaturesAsync(s.ctx, func(observatory extension.Observatory) {
+			s.observatory = observatory
+		})
+	}
 }
 
 func (s *RandomStrategy) GetPrincipleTarget(strings []string) []string {
@@ -27,12 +31,6 @@ func (s *RandomStrategy) GetPrincipleTarget(strings []string) []string {
 }
 
 func (s *RandomStrategy) PickOutbound(candidates []string) string {
-	if len(s.FallbackTag) > 0 && s.observatory == nil {
-		common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error {
-			s.observatory = observatory
-			return nil
-		}))
-	}
 	if s.observatory != nil {
 		observeReport, err := s.observatory.GetObservation(s.ctx)
 		if err == nil {

+ 37 - 0
core/xray.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"reflect"
 	"sync"
+	"time"
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/errors"
@@ -156,6 +157,12 @@ func RequireFeatures(ctx context.Context, callback interface{}) error {
 	return v.RequireFeatures(callback)
 }
 
+// RequireFeaturesAsync registers a callback, which will be called when all dependent features are registered. The order of app init doesn't matter
+func RequireFeaturesAsync(ctx context.Context, callback interface{}) {
+	v := MustFromContext(ctx)
+	v.RequireFeaturesAsync(callback)
+}
+
 // New returns a new Xray instance based on given configuration.
 // The instance is not started at this point.
 // To ensure Xray instance works properly, the config must contain one Dispatcher, one InboundHandlerManager and one OutboundHandlerManager. Other features are optional.
@@ -290,6 +297,36 @@ func (s *Instance) RequireFeatures(callback interface{}) error {
 	return nil
 }
 
+// RequireFeaturesAsync registers a callback, which will be called when all dependent features are registered. The order of app init doesn't matter
+func (s *Instance) RequireFeaturesAsync(callback interface{}) {
+	callbackType := reflect.TypeOf(callback)
+	if callbackType.Kind() != reflect.Func {
+		panic("not a function")
+	}
+
+	var featureTypes []reflect.Type
+	for i := 0; i < callbackType.NumIn(); i++ {
+		featureTypes = append(featureTypes, reflect.PtrTo(callbackType.In(i)))
+	}
+
+	r := resolution{
+		deps:     featureTypes,
+		callback: callback,
+	}
+	go func() {
+		var finished = false
+		for i := 0; !finished; i++ {
+			if i > 100000 {
+				errors.LogError(s.ctx, "RequireFeaturesAsync failed after count ", i)
+				break;
+			}
+			finished, _ = r.resolve(s.features)
+			time.Sleep(time.Millisecond)
+		}
+		s.featureResolutions = append(s.featureResolutions, r)
+	}()
+}
+
 // AddFeature registers a feature into current Instance.
 func (s *Instance) AddFeature(feature features.Feature) error {
 	s.features = append(s.features, feature)

+ 1 - 1
proxy/dns/dns.go

@@ -27,7 +27,7 @@ func init() {
 	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
 		h := new(Handler)
 		if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error {
-			core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
+			core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { // FakeDNSEngine is optional
 				h.fdns = fdns
 			})
 			return h.Init(config.(*Config), dnsClient, policyManager)

+ 2 - 1
transport/internet/tagged/tagged.go

@@ -4,8 +4,9 @@ import (
 	"context"
 
 	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/features/routing"
 )
 
-type DialFunc func(ctx context.Context, dest net.Destination, tag string) (net.Conn, error)
+type DialFunc func(ctx context.Context, dispatcher routing.Dispatcher, dest net.Destination, tag string) (net.Conn, error)
 
 var Dialer DialFunc

+ 1 - 8
transport/internet/tagged/taggedimpl/impl.go

@@ -12,17 +12,10 @@ import (
 	"github.com/xtls/xray-core/transport/internet/tagged"
 )
 
-func DialTaggedOutbound(ctx context.Context, dest net.Destination, tag string) (net.Conn, error) {
-	var dispatcher routing.Dispatcher
+func DialTaggedOutbound(ctx context.Context, dispatcher routing.Dispatcher, dest net.Destination, tag string) (net.Conn, error) {
 	if core.FromContext(ctx) == nil {
 		return nil, errors.New("Instance context variable is not in context, dial denied. ")
 	}
-	if err := core.RequireFeatures(ctx, func(dispatcherInstance routing.Dispatcher) {
-		dispatcher = dispatcherInstance
-	}); err != nil {
-		return nil, errors.New("Required Feature dispatcher not resolved").Base(err)
-	}
-
 	content := new(session.Content)
 	content.SkipDNSResolve = true