Browse Source

Refine DNS Options

JimhHan 4 years ago
parent
commit
217844cc37

+ 14 - 11
app/dns/dns.go

@@ -167,33 +167,36 @@ func (s *DNS) IsOwnLink(ctx context.Context) bool {
 
 // LookupIP implements dns.Client.
 func (s *DNS) LookupIP(domain string) ([]net.IP, error) {
-	return s.lookupIPInternal(domain, *s.ipOption)
+	return s.lookupIPInternal(domain, s.ipOption.Copy())
 }
 
 // LookupOptions implements dns.Client.
-func (s *DNS) LookupOptions(domain string, opt dns.IPOption) ([]net.IP, error) {
+func (s *DNS) LookupOptions(domain string, opts ...dns.Option) ([]net.IP, error) {
+	opt := s.ipOption.Copy()
+	for _, o := range opts {
+		if o != nil {
+			o(opt)
+		}
+	}
+
 	return s.lookupIPInternal(domain, opt)
 }
 
 // LookupIPv4 implements dns.IPv4Lookup.
 func (s *DNS) LookupIPv4(domain string) ([]net.IP, error) {
-	return s.lookupIPInternal(domain, dns.IPOption{
+	return s.lookupIPInternal(domain, &dns.IPOption{
 		IPv4Enable: true,
-		IPv6Enable: false,
-		FakeEnable: false,
 	})
 }
 
 // LookupIPv6 implements dns.IPv6Lookup.
 func (s *DNS) LookupIPv6(domain string) ([]net.IP, error) {
-	return s.lookupIPInternal(domain, dns.IPOption{
-		IPv4Enable: false,
+	return s.lookupIPInternal(domain, &dns.IPOption{
 		IPv6Enable: true,
-		FakeEnable: false,
 	})
 }
 
-func (s *DNS) lookupIPInternal(domain string, option dns.IPOption) ([]net.IP, error) {
+func (s *DNS) lookupIPInternal(domain string, option *dns.IPOption) ([]net.IP, error) {
 	if domain == "" {
 		return nil, newError("empty domain name")
 	}
@@ -228,7 +231,7 @@ func (s *DNS) lookupIPInternal(domain string, option dns.IPOption) ([]net.IP, er
 	errs := []error{}
 	ctx := session.ContextWithInbound(s.ctx, &session.Inbound{Tag: s.tag})
 	for _, client := range s.sortClients(domain, option) {
-		ips, err := client.QueryIP(ctx, domain, option, s.cs)
+		ips, err := client.QueryIP(ctx, domain, *option, s.cs)
 		if len(ips) > 0 {
 			return ips, nil
 		}
@@ -244,7 +247,7 @@ func (s *DNS) lookupIPInternal(domain string, option dns.IPOption) ([]net.IP, er
 	return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...))
 }
 
-func (s *DNS) sortClients(domain string, option dns.IPOption) []*Client {
+func (s *DNS) sortClients(domain string, option *dns.IPOption) []*Client {
 	clients := make([]*Client, 0, len(s.clients))
 	clientUsed := make([]bool, len(s.clients))
 	clientNames := make([]string, 0, len(s.clients))

+ 3 - 3
app/dns/hosts.go

@@ -74,7 +74,7 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma
 	return sh, nil
 }
 
-func filterIP(ips []net.Address, option dns.IPOption) []net.Address {
+func filterIP(ips []net.Address, option *dns.IPOption) []net.Address {
 	filtered := make([]net.Address, 0, len(ips))
 	for _, ip := range ips {
 		if (ip.Family().IsIPv4() && option.IPv4Enable) || (ip.Family().IsIPv6() && option.IPv6Enable) {
@@ -95,7 +95,7 @@ func (h *StaticHosts) lookupInternal(domain string) []net.Address {
 	return ips
 }
 
-func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) []net.Address {
+func (h *StaticHosts) lookup(domain string, option *dns.IPOption, maxDepth int) []net.Address {
 	switch addrs := h.lookupInternal(domain); {
 	case len(addrs) == 0: // Not recorded in static hosts, return nil
 		return nil
@@ -113,6 +113,6 @@ func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) [
 }
 
 // Lookup returns IP addresses or proxied domain for the given domain, if exists in this StaticHosts.
-func (h *StaticHosts) Lookup(domain string, option dns.IPOption) []net.Address {
+func (h *StaticHosts) Lookup(domain string, option *dns.IPOption) []net.Address {
 	return h.lookup(domain, option, 5)
 }

+ 3 - 3
app/dns/hosts_test.go

@@ -40,7 +40,7 @@ func TestStaticHosts(t *testing.T) {
 	common.Must(err)
 
 	{
-		ips := hosts.Lookup("example.com", dns.IPOption{
+		ips := hosts.Lookup("example.com", &dns.IPOption{
 			IPv4Enable: true,
 			IPv6Enable: true,
 		})
@@ -53,7 +53,7 @@ func TestStaticHosts(t *testing.T) {
 	}
 
 	{
-		ips := hosts.Lookup("www.example.cn", dns.IPOption{
+		ips := hosts.Lookup("www.example.cn", &dns.IPOption{
 			IPv4Enable: true,
 			IPv6Enable: true,
 		})
@@ -66,7 +66,7 @@ func TestStaticHosts(t *testing.T) {
 	}
 
 	{
-		ips := hosts.Lookup("baidu.com", dns.IPOption{
+		ips := hosts.Lookup("baidu.com", &dns.IPOption{
 			IPv4Enable: false,
 			IPv6Enable: true,
 		})

+ 3 - 11
app/dns/options.go

@@ -2,23 +2,15 @@ package dns
 
 import "github.com/xtls/xray-core/features/dns"
 
-type Option interface {
-	queryIPv4() bool
-	queryIPv6() bool
-	queryIP() bool
-	queryFake() bool
-	canDoQuery(c *Client) bool
-}
-
-func isIPQuery(o dns.IPOption) bool {
+func isIPQuery(o *dns.IPOption) bool {
 	return o.IPv4Enable || o.IPv6Enable
 }
 
-func canQueryOnClient(o dns.IPOption, c *Client) bool {
+func canQueryOnClient(o *dns.IPOption, c *Client) bool {
 	isIPClient := !(c.Name() == FakeDNSName)
 	return isIPClient && isIPQuery(o)
 }
 
-func isQuery(o dns.IPOption) bool {
+func isQuery(o *dns.IPOption) bool {
 	return !(o.IPv4Enable || o.IPv6Enable || o.FakeEnable)
 }

+ 37 - 7
features/dns/client.go

@@ -14,6 +14,12 @@ type IPOption struct {
 	FakeEnable bool
 }
 
+func (p *IPOption) Copy() *IPOption {
+	return &IPOption{p.IPv4Enable, p.IPv6Enable, p.FakeEnable}
+}
+
+type Option func(dopt *IPOption) *IPOption
+
 // Client is a Xray feature for querying DNS information.
 //
 // xray:api:stable
@@ -23,8 +29,8 @@ type Client interface {
 	// LookupIP returns IP address for the given domain. IPs may contain IPv4 and/or IPv6 addresses.
 	LookupIP(domain string) ([]net.IP, error)
 
-	// LookupOptions query IP address for domain with IPOption.
-	LookupOptions(domain string, opt IPOption) ([]net.IP, error)
+	// LookupOptions query IP address for domain with *IPOption.
+	LookupOptions(domain string, opt ...Option) ([]net.IP, error)
 }
 
 // IPv4Lookup is an optional feature for querying IPv4 addresses only.
@@ -69,9 +75,33 @@ func RCodeFromError(err error) uint16 {
 }
 
 var (
-	LookupIPv4 = IPOption{IPv4Enable: true}
-	LookupIPv6 = IPOption{IPv6Enable: true}
-	LookupIP   = IPOption{IPv4Enable: true, IPv6Enable: true}
-	LookupFake = IPOption{FakeEnable: true}
-	LookupAll  = IPOption{true, true, true}
+	LookupIPv4Only = func(d *IPOption) *IPOption {
+		d.IPv4Enable = true
+		d.IPv6Enable = false
+		return d
+	}
+	LookupIPv6Only = func(d *IPOption) *IPOption {
+		d.IPv4Enable = false
+		d.IPv6Enable = true
+		return d
+	}
+	LookupIP = func(d *IPOption) *IPOption {
+		d.IPv4Enable = true
+		d.IPv6Enable = true
+		return d
+	}
+	LookupFake = func(d *IPOption) *IPOption {
+		d.FakeEnable = true
+		return d
+	}
+	LookupNoFake = func(d *IPOption) *IPOption {
+		d.FakeEnable = false
+		return d
+	}
+
+	LookupAll = func(d *IPOption) *IPOption {
+		LookupIP(d)
+		LookupFake(d)
+		return d
+	}
 )

+ 1 - 1
features/dns/localdns/client.go

@@ -39,7 +39,7 @@ func (*Client) LookupIP(host string) ([]net.IP, error) {
 }
 
 // LookupOptions implements Client.
-func (c *Client) LookupOptions(host string, _ dns.IPOption) ([]net.IP, error) {
+func (c *Client) LookupOptions(host string, _ ...dns.Option) ([]net.IP, error) {
 	return c.LookupIP(host)
 }
 

+ 4 - 6
proxy/dns/dns.go

@@ -199,18 +199,16 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
 	var err error
 
 	var ttl uint32 = 600
-	var opt = dns.LookupIP
+	var opt dns.Option
 
 	switch qType {
 	case dnsmessage.TypeA:
-		opt = dns.LookupIPv4
+		opt = dns.LookupIPv4Only
 	case dnsmessage.TypeAAAA:
-		opt = dns.LookupIPv6
+		opt = dns.LookupIPv6Only
 	}
 
-	opt.FakeEnable = true
-
-	ips, err = h.client.LookupOptions(domain, opt)
+	ips, err = h.client.LookupOptions(domain, opt, dns.LookupFake)
 	rcode := dns.RCodeFromError(err)
 	if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse {
 		newError("ip query").Base(err).WriteToLog()

+ 4 - 5
proxy/freedom/freedom.go

@@ -59,15 +59,14 @@ func (h *Handler) policy() policy.Session {
 }
 
 func (h *Handler) resolveIP(ctx context.Context, domain string, localAddr net.Address) net.Address {
-	var opt = dns.LookupIP
+	var opt dns.Option
 	if h.config.DomainStrategy == Config_USE_IP4 || (localAddr != nil && localAddr.Family().IsIPv4()) {
-		opt = dns.LookupIPv4
+		opt = dns.LookupIPv4Only
 	} else if h.config.DomainStrategy == Config_USE_IP6 || (localAddr != nil && localAddr.Family().IsIPv6()) {
-		opt = dns.LookupIPv6
+		opt = dns.LookupIPv6Only
 	}
-	opt.FakeEnable = true
 
-	ips, err := h.dns.LookupOptions(domain, opt)
+	ips, err := h.dns.LookupOptions(domain, opt, dns.LookupNoFake)
 	if err != nil {
 		newError("failed to get IP address for domain ", domain).Base(err).WriteToLog(session.ExportIDToError(ctx))
 	}

+ 9 - 4
testing/mocks/dns.go

@@ -65,18 +65,23 @@ func (mr *DNSClientMockRecorder) LookupIP(arg0 interface{}) *gomock.Call {
 }
 
 // LookupOptions mocks base method.
-func (m *DNSClient) LookupOptions(arg0 string, arg1 dns.IPOption) ([]net.IP, error) {
+func (m *DNSClient) LookupOptions(arg0 string, arg1 ...dns.Option) ([]net.IP, error) {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "LookupOptions", arg0, arg1)
+	varargs := []interface{}{arg0}
+	for _, a := range arg1 {
+		varargs = append(varargs, a)
+	}
+	ret := m.ctrl.Call(m, "LookupOptions", varargs...)
 	ret0, _ := ret[0].([]net.IP)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
 
 // LookupOptions indicates an expected call of LookupOptions.
-func (mr *DNSClientMockRecorder) LookupOptions(arg0, arg1 interface{}) *gomock.Call {
+func (mr *DNSClientMockRecorder) LookupOptions(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupOptions", reflect.TypeOf((*DNSClient)(nil).LookupOptions), arg0, arg1)
+	varargs := append([]interface{}{arg0}, arg1...)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupOptions", reflect.TypeOf((*DNSClient)(nil).LookupOptions), varargs...)
 }
 
 // Start mocks base method.

+ 4 - 4
transport/internet/system_dialer.go

@@ -63,17 +63,17 @@ func (d *DefaultSystemDialer) lookupIP(domain string, strategy DomainStrategy, l
 		return nil, nil
 	}
 
-	var opt = dns.LookupIP
+	var opt dns.Option
 	switch {
 	case strategy == DomainStrategy_USE_IP4 || (localAddr != nil && localAddr.Family().IsIPv4()):
-		opt = dns.LookupIPv4
+		opt = dns.LookupIPv4Only
 	case strategy == DomainStrategy_USE_IP6 || (localAddr != nil && localAddr.Family().IsIPv6()):
-		opt = dns.LookupIPv6
+		opt = dns.LookupIPv6Only
 	case strategy == DomainStrategy_AS_IS:
 		return nil, nil
 	}
 
-	return d.dns.LookupOptions(domain, opt)
+	return d.dns.LookupOptions(domain, opt, dns.LookupNoFake)
 }
 
 func (d *DefaultSystemDialer) canLookupIP(ctx context.Context, dst net.Destination, sockopt *SocketConfig) bool {