Browse Source

Override destination if replaced in hosts

世界 4 years ago
parent
commit
27224868ab
3 changed files with 44 additions and 6 deletions
  1. 24 6
      app/dispatcher/default.go
  2. 16 0
      app/dns/dns.go
  3. 4 0
      features/dns/client.go

+ 24 - 6
app/dispatcher/default.go

@@ -4,6 +4,7 @@ package dispatcher
 
 import (
 	"context"
+	"github.com/xtls/xray-core/features/dns"
 	"strings"
 	"sync"
 	"time"
@@ -15,7 +16,6 @@ import (
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/core"
-	"github.com/xtls/xray-core/features/dns"
 	"github.com/xtls/xray-core/features/outbound"
 	"github.com/xtls/xray-core/features/policy"
 	"github.com/xtls/xray-core/features/routing"
@@ -92,13 +92,14 @@ type DefaultDispatcher struct {
 	router routing.Router
 	policy policy.Manager
 	stats  stats.Manager
+	hosts  dns.HostsLookup
 }
 
 func init() {
 	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
 		d := new(DefaultDispatcher)
-		if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error {
-			return d.Init(config.(*Config), om, router, pm, sm)
+		if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
+			return d.Init(config.(*Config), om, router, pm, sm, dc)
 		}); err != nil {
 			return nil, err
 		}
@@ -107,11 +108,14 @@ func init() {
 }
 
 // Init initializes DefaultDispatcher.
-func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error {
+func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
 	d.ohm = om
 	d.router = router
 	d.policy = pm
 	d.stats = sm
+	if hosts, ok := dc.(dns.HostsLookup); ok {
+		d.hosts = hosts
+	}
 	return nil
 }
 
@@ -294,7 +298,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
 		result, err := sniffer(ctx, nil, true)
 		if err == nil {
 			content.Protocol = result.Protocol()
-			if shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) {
+			if shouldOverride(ctx, result, sniffingRequest, destination) {
 				domain := result.Domain()
 				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
 				destination.Address = net.ParseAddress(domain)
@@ -316,7 +320,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
 			if err == nil {
 				content.Protocol = result.Protocol()
 			}
-			if err == nil && shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) {
+			if err == nil && shouldOverride(ctx, result, sniffingRequest, destination) {
 				domain := result.Domain()
 				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
 				destination.Address = net.ParseAddress(domain)
@@ -379,6 +383,20 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (Sni
 }
 
 func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
+	ob := session.OutboundFromContext(ctx)
+	if d.hosts != nil && destination.Address.Family().IsDomain() {
+		proxied := d.hosts.LookupHosts(ob.Target.String())
+		if proxied != nil {
+			ro := ob.RouteTarget == destination
+			destination.Address = *proxied
+			if ro {
+				ob.RouteTarget = destination
+			} else {
+				ob.Target = destination
+			}
+		}
+	}
+
 	var handler outbound.Handler
 
 	if d.router != nil {

+ 16 - 0
app/dns/dns.go

@@ -223,6 +223,22 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, error) {
 	return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...))
 }
 
+// LookupHosts implements dns.HostsLookup.
+func (s *DNS) LookupHosts(domain string) *net.Address {
+	domain = strings.TrimSuffix(domain, ".")
+	if domain == "" {
+		return nil
+	}
+	// Normalize the FQDN form query
+	addrs := s.hosts.Lookup(domain, *s.ipOption)
+	if len(addrs) > 0 {
+		newError("domain replaced: ", domain, " -> ", addrs[0].String()).AtInfo().WriteToLog()
+		return &addrs[0]
+	}
+
+	return nil
+}
+
 // GetIPOption implements ClientWithIPOption.
 func (s *DNS) GetIPOption() *dns.IPOption {
 	return s.ipOption

+ 4 - 0
features/dns/client.go

@@ -24,6 +24,10 @@ type Client interface {
 	LookupIP(domain string, option IPOption) ([]net.IP, error)
 }
 
+type HostsLookup interface {
+	LookupHosts(domain string) *net.Address
+}
+
 // ClientType returns the type of Client interface. Can be used for implementing common.HasType.
 //
 // xray:api:beta