浏览代码

Refine DNS strategies

JimhHan 4 年之前
父节点
当前提交
726a722019

+ 99 - 41
app/dns/config.pb.go

@@ -128,6 +128,55 @@ func (QueryStrategy) EnumDescriptor() ([]byte, []int) {
 	return file_app_dns_config_proto_rawDescGZIP(), []int{1}
 }
 
+type CacheStrategy int32
+
+const (
+	CacheStrategy_Cache_ALL     CacheStrategy = 0
+	CacheStrategy_Cache_NOERROR CacheStrategy = 1
+	CacheStrategy_Cache_DISABLE CacheStrategy = 2
+)
+
+// Enum value maps for CacheStrategy.
+var (
+	CacheStrategy_name = map[int32]string{
+		0: "Cache_ALL",
+		1: "Cache_NOERROR",
+		2: "Cache_DISABLE",
+	}
+	CacheStrategy_value = map[string]int32{
+		"Cache_ALL":     0,
+		"Cache_NOERROR": 1,
+		"Cache_DISABLE": 2,
+	}
+)
+
+func (x CacheStrategy) Enum() *CacheStrategy {
+	p := new(CacheStrategy)
+	*p = x
+	return p
+}
+
+func (x CacheStrategy) String() string {
+	return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
+}
+
+func (CacheStrategy) Descriptor() protoreflect.EnumDescriptor {
+	return file_app_dns_config_proto_enumTypes[2].Descriptor()
+}
+
+func (CacheStrategy) Type() protoreflect.EnumType {
+	return &file_app_dns_config_proto_enumTypes[2]
+}
+
+func (x CacheStrategy) Number() protoreflect.EnumNumber {
+	return protoreflect.EnumNumber(x)
+}
+
+// Deprecated: Use CacheStrategy.Descriptor instead.
+func (CacheStrategy) EnumDescriptor() ([]byte, []int) {
+	return file_app_dns_config_proto_rawDescGZIP(), []int{2}
+}
+
 type NameServer struct {
 	state         protoimpl.MessageState
 	sizeCache     protoimpl.SizeCache
@@ -232,7 +281,7 @@ type Config struct {
 	// Tag is the inbound tag of DNS client.
 	Tag string `protobuf:"bytes,6,opt,name=tag,proto3" json:"tag,omitempty"`
 	// DisableCache Disable DNS cache
-	DisableCache  bool          `protobuf:"varint,8,opt,name=disableCache,proto3" json:"disableCache,omitempty"`
+	CacheStrategy CacheStrategy `protobuf:"varint,8,opt,name=cache_strategy,json=cacheStrategy,proto3,enum=xray.app.dns.CacheStrategy" json:"cache_strategy,omitempty"`
 	QueryStrategy QueryStrategy `protobuf:"varint,9,opt,name=query_strategy,json=queryStrategy,proto3,enum=xray.app.dns.QueryStrategy" json:"query_strategy,omitempty"`
 }
 
@@ -312,11 +361,11 @@ func (x *Config) GetTag() string {
 	return ""
 }
 
-func (x *Config) GetDisableCache() bool {
+func (x *Config) GetCacheStrategy() CacheStrategy {
 	if x != nil {
-		return x.DisableCache
+		return x.CacheStrategy
 	}
-	return false
+	return CacheStrategy_Cache_ALL
 }
 
 func (x *Config) GetQueryStrategy() QueryStrategy {
@@ -549,7 +598,7 @@ var file_app_dns_config_proto_rawDesc = []byte{
 	0x69, 0x67, 0x69, 0x6e, 0x61, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x72, 0x75,
 	0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x72, 0x75, 0x6c, 0x65, 0x12, 0x12,
 	0x0a, 0x04, 0x73, 0x69, 0x7a, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x73, 0x69,
-	0x7a, 0x65, 0x22, 0x8d, 0x05, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3f, 0x0a,
+	0x7a, 0x65, 0x22, 0xad, 0x05, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3f, 0x0a,
 	0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03,
 	0x28, 0x0b, 0x32, 0x19, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e,
 	0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x42, 0x02, 0x18,
@@ -568,9 +617,11 @@ var file_app_dns_config_proto_rawDesc = []byte{
 	0x70, 0x70, 0x2e, 0x64, 0x6e, 0x73, 0x2e, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x48, 0x6f,
 	0x73, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x0b, 0x73, 0x74, 0x61, 0x74, 0x69,
 	0x63, 0x48, 0x6f, 0x73, 0x74, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x74, 0x61, 0x67, 0x18, 0x06, 0x20,
-	0x01, 0x28, 0x09, 0x52, 0x03, 0x74, 0x61, 0x67, 0x12, 0x22, 0x0a, 0x0c, 0x64, 0x69, 0x73, 0x61,
-	0x62, 0x6c, 0x65, 0x43, 0x61, 0x63, 0x68, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c,
-	0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, 0x42, 0x0a, 0x0e,
+	0x01, 0x28, 0x09, 0x52, 0x03, 0x74, 0x61, 0x67, 0x12, 0x42, 0x0a, 0x0e, 0x63, 0x61, 0x63, 0x68,
+	0x65, 0x5f, 0x73, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0e,
+	0x32, 0x1b, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x64, 0x6e, 0x73, 0x2e,
+	0x43, 0x61, 0x63, 0x68, 0x65, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x52, 0x0d, 0x63,
+	0x61, 0x63, 0x68, 0x65, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x12, 0x42, 0x0a, 0x0e,
 	0x71, 0x75, 0x65, 0x72, 0x79, 0x5f, 0x73, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x18, 0x09,
 	0x20, 0x01, 0x28, 0x0e, 0x32, 0x1b, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, 0x2e,
 	0x64, 0x6e, 0x73, 0x2e, 0x51, 0x75, 0x65, 0x72, 0x79, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67,
@@ -598,11 +649,16 @@ var file_app_dns_config_proto_rawDesc = []byte{
 	0x72, 0x79, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x12, 0x0a, 0x0a, 0x06, 0x55, 0x53,
 	0x45, 0x5f, 0x49, 0x50, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x53, 0x45, 0x5f, 0x49, 0x50,
 	0x34, 0x10, 0x01, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x53, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10, 0x02,
-	0x42, 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70,
-	0x2e, 0x64, 0x6e, 0x73, 0x50, 0x01, 0x5a, 0x21, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63,
-	0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72,
-	0x65, 0x2f, 0x61, 0x70, 0x70, 0x2f, 0x64, 0x6e, 0x73, 0xaa, 0x02, 0x0c, 0x58, 0x72, 0x61, 0x79,
-	0x2e, 0x41, 0x70, 0x70, 0x2e, 0x44, 0x6e, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+	0x2a, 0x44, 0x0a, 0x0d, 0x43, 0x61, 0x63, 0x68, 0x65, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67,
+	0x79, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x61, 0x63, 0x68, 0x65, 0x5f, 0x41, 0x4c, 0x4c, 0x10, 0x00,
+	0x12, 0x11, 0x0a, 0x0d, 0x43, 0x61, 0x63, 0x68, 0x65, 0x5f, 0x4e, 0x4f, 0x45, 0x52, 0x52, 0x4f,
+	0x52, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x43, 0x61, 0x63, 0x68, 0x65, 0x5f, 0x44, 0x49, 0x53,
+	0x41, 0x42, 0x4c, 0x45, 0x10, 0x02, 0x42, 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72,
+	0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x64, 0x6e, 0x73, 0x50, 0x01, 0x5a, 0x21, 0x67, 0x69,
+	0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72,
+	0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61, 0x70, 0x70, 0x2f, 0x64, 0x6e, 0x73, 0xaa,
+	0x02, 0x0c, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x41, 0x70, 0x70, 0x2e, 0x44, 0x6e, 0x73, 0x62, 0x06,
+	0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
 }
 
 var (
@@ -617,39 +673,41 @@ func file_app_dns_config_proto_rawDescGZIP() []byte {
 	return file_app_dns_config_proto_rawDescData
 }
 
-var file_app_dns_config_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
+var file_app_dns_config_proto_enumTypes = make([]protoimpl.EnumInfo, 3)
 var file_app_dns_config_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
 var file_app_dns_config_proto_goTypes = []interface{}{
 	(DomainMatchingType)(0),           // 0: xray.app.dns.DomainMatchingType
 	(QueryStrategy)(0),                // 1: xray.app.dns.QueryStrategy
-	(*NameServer)(nil),                // 2: xray.app.dns.NameServer
-	(*Config)(nil),                    // 3: xray.app.dns.Config
-	(*NameServer_PriorityDomain)(nil), // 4: xray.app.dns.NameServer.PriorityDomain
-	(*NameServer_OriginalRule)(nil),   // 5: xray.app.dns.NameServer.OriginalRule
-	nil,                               // 6: xray.app.dns.Config.HostsEntry
-	(*Config_HostMapping)(nil),        // 7: xray.app.dns.Config.HostMapping
-	(*net.Endpoint)(nil),              // 8: xray.common.net.Endpoint
-	(*router.GeoIP)(nil),              // 9: xray.app.router.GeoIP
-	(*net.IPOrDomain)(nil),            // 10: xray.common.net.IPOrDomain
+	(CacheStrategy)(0),                // 2: xray.app.dns.CacheStrategy
+	(*NameServer)(nil),                // 3: xray.app.dns.NameServer
+	(*Config)(nil),                    // 4: xray.app.dns.Config
+	(*NameServer_PriorityDomain)(nil), // 5: xray.app.dns.NameServer.PriorityDomain
+	(*NameServer_OriginalRule)(nil),   // 6: xray.app.dns.NameServer.OriginalRule
+	nil,                               // 7: xray.app.dns.Config.HostsEntry
+	(*Config_HostMapping)(nil),        // 8: xray.app.dns.Config.HostMapping
+	(*net.Endpoint)(nil),              // 9: xray.common.net.Endpoint
+	(*router.GeoIP)(nil),              // 10: xray.app.router.GeoIP
+	(*net.IPOrDomain)(nil),            // 11: xray.common.net.IPOrDomain
 }
 var file_app_dns_config_proto_depIdxs = []int32{
-	8,  // 0: xray.app.dns.NameServer.address:type_name -> xray.common.net.Endpoint
-	4,  // 1: xray.app.dns.NameServer.prioritized_domain:type_name -> xray.app.dns.NameServer.PriorityDomain
-	9,  // 2: xray.app.dns.NameServer.geoip:type_name -> xray.app.router.GeoIP
-	5,  // 3: xray.app.dns.NameServer.original_rules:type_name -> xray.app.dns.NameServer.OriginalRule
-	8,  // 4: xray.app.dns.Config.NameServers:type_name -> xray.common.net.Endpoint
-	2,  // 5: xray.app.dns.Config.name_server:type_name -> xray.app.dns.NameServer
-	6,  // 6: xray.app.dns.Config.Hosts:type_name -> xray.app.dns.Config.HostsEntry
-	7,  // 7: xray.app.dns.Config.static_hosts:type_name -> xray.app.dns.Config.HostMapping
-	1,  // 8: xray.app.dns.Config.query_strategy:type_name -> xray.app.dns.QueryStrategy
-	0,  // 9: xray.app.dns.NameServer.PriorityDomain.type:type_name -> xray.app.dns.DomainMatchingType
-	10, // 10: xray.app.dns.Config.HostsEntry.value:type_name -> xray.common.net.IPOrDomain
-	0,  // 11: xray.app.dns.Config.HostMapping.type:type_name -> xray.app.dns.DomainMatchingType
-	12, // [12:12] is the sub-list for method output_type
-	12, // [12:12] is the sub-list for method input_type
-	12, // [12:12] is the sub-list for extension type_name
-	12, // [12:12] is the sub-list for extension extendee
-	0,  // [0:12] is the sub-list for field type_name
+	9,  // 0: xray.app.dns.NameServer.address:type_name -> xray.common.net.Endpoint
+	5,  // 1: xray.app.dns.NameServer.prioritized_domain:type_name -> xray.app.dns.NameServer.PriorityDomain
+	10, // 2: xray.app.dns.NameServer.geoip:type_name -> xray.app.router.GeoIP
+	6,  // 3: xray.app.dns.NameServer.original_rules:type_name -> xray.app.dns.NameServer.OriginalRule
+	9,  // 4: xray.app.dns.Config.NameServers:type_name -> xray.common.net.Endpoint
+	3,  // 5: xray.app.dns.Config.name_server:type_name -> xray.app.dns.NameServer
+	7,  // 6: xray.app.dns.Config.Hosts:type_name -> xray.app.dns.Config.HostsEntry
+	8,  // 7: xray.app.dns.Config.static_hosts:type_name -> xray.app.dns.Config.HostMapping
+	2,  // 8: xray.app.dns.Config.cache_strategy:type_name -> xray.app.dns.CacheStrategy
+	1,  // 9: xray.app.dns.Config.query_strategy:type_name -> xray.app.dns.QueryStrategy
+	0,  // 10: xray.app.dns.NameServer.PriorityDomain.type:type_name -> xray.app.dns.DomainMatchingType
+	11, // 11: xray.app.dns.Config.HostsEntry.value:type_name -> xray.common.net.IPOrDomain
+	0,  // 12: xray.app.dns.Config.HostMapping.type:type_name -> xray.app.dns.DomainMatchingType
+	13, // [13:13] is the sub-list for method output_type
+	13, // [13:13] is the sub-list for method input_type
+	13, // [13:13] is the sub-list for extension type_name
+	13, // [13:13] is the sub-list for extension extendee
+	0,  // [0:13] is the sub-list for field type_name
 }
 
 func init() { file_app_dns_config_proto_init() }
@@ -724,7 +782,7 @@ func file_app_dns_config_proto_init() {
 		File: protoimpl.DescBuilder{
 			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
 			RawDescriptor: file_app_dns_config_proto_rawDesc,
-			NumEnums:      2,
+			NumEnums:      3,
 			NumMessages:   6,
 			NumExtensions: 0,
 			NumServices:   0,

+ 7 - 1
app/dns/config.proto

@@ -42,6 +42,12 @@ enum QueryStrategy {
   USE_IP6 = 2;
 }
 
+enum CacheStrategy {
+  Cache_ALL = 0;
+  Cache_NOERROR = 1;
+  Cache_DISABLE = 2;
+}
+
 message Config {
   // Nameservers used by this DNS. Only traditional UDP servers are support at
   // the moment. A special value 'localhost' as a domain address can be set to
@@ -79,7 +85,7 @@ message Config {
   reserved 7;
 
   // DisableCache Disable DNS cache
-  bool disableCache = 8;
+  CacheStrategy cache_strategy = 8;
 
   QueryStrategy query_strategy = 9;
 }

+ 37 - 38
app/dns/dns.go

@@ -23,7 +23,7 @@ import (
 type DNS struct {
 	sync.Mutex
 	tag           string
-	disableCache  bool
+	cs            CacheStrategy
 	ipOption      *dns.IPOption
 	hosts         *StaticHosts
 	clients       []*Client
@@ -138,7 +138,7 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
 		ctx:           ctx,
 		domainMatcher: domainMatcher,
 		matcherInfos:  matcherInfos,
-		disableCache:  config.DisableCache,
+		cs:            config.CacheStrategy,
 	}, nil
 }
 
@@ -165,11 +165,12 @@ func (s *DNS) IsOwnLink(ctx context.Context) bool {
 
 // LookupIP implements dns.Client.
 func (s *DNS) LookupIP(domain string) ([]net.IP, error) {
-	return s.lookupIPInternal(domain, dns.IPOption{
-		IPv4Enable: true,
-		IPv6Enable: true,
-		FakeEnable: false,
-	})
+	return s.lookupIPInternal(domain, *s.ipOption)
+}
+
+// LookupOptions implements dns.Client.
+func (s *DNS) LookupOptions(domain string, opt dns.IPOption) ([]net.IP, error) {
+	return s.lookupIPInternal(domain, opt)
 }
 
 // LookupIPv4 implements dns.IPv4Lookup.
@@ -194,6 +195,9 @@ func (s *DNS) lookupIPInternal(domain string, option dns.IPOption) ([]net.IP, er
 	if domain == "" {
 		return nil, newError("empty domain name")
 	}
+	if isQuery(option) {
+		return nil, newError("empty option: I'm pretty sure it shouldn't happened.")
+	}
 
 	// Normalize the FQDN form query
 	if strings.HasSuffix(domain, ".") {
@@ -210,19 +214,18 @@ func (s *DNS) lookupIPInternal(domain string, option dns.IPOption) ([]net.IP, er
 		newError("domain replaced: ", domain, " -> ", addrs[0].Domain()).WriteToLog()
 		domain = addrs[0].Domain()
 	default: // Successfully found ip records in static host
-		newError("returning ", len(addrs), " IPs for domain ", domain).WriteToLog()
-		return toNetIP(addrs)
+		if isIPQuery(option) {
+			// maybe our client prefer to query fake dns -_-
+			newError("returning ", len(addrs), " IPs for domain ", domain).WriteToLog()
+			return toNetIP(addrs)
+		}
 	}
 
 	// Name servers lookup
 	errs := []error{}
 	ctx := session.ContextWithInbound(s.ctx, &session.Inbound{Tag: s.tag})
-	for _, client := range s.sortClients(domain) {
-		if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") {
-			newError("skip DNS resolution for domain ", domain, " at server ", client.Name()).AtDebug().WriteToLog()
-			continue
-		}
-		ips, err := client.QueryIP(ctx, domain, option, s.disableCache)
+	for _, client := range s.sortClients(domain, option) {
+		ips, err := client.QueryIP(ctx, domain, option, s.cs)
 		if len(ips) > 0 {
 			return ips, nil
 		}
@@ -238,33 +241,31 @@ func (s *DNS) lookupIPInternal(domain string, option dns.IPOption) ([]net.IP, er
 	return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...))
 }
 
-// GetIPOption implements ClientWithIPOption.
-func (s *DNS) GetIPOption() *dns.IPOption {
-	return s.ipOption
-}
-
-// SetQueryOption implements ClientWithIPOption.
-func (s *DNS) SetQueryOption(isIPv4Enable, isIPv6Enable bool) {
-	s.ipOption.IPv4Enable = isIPv4Enable
-	s.ipOption.IPv6Enable = isIPv6Enable
-}
-
-// SetFakeDNSOption implements ClientWithIPOption.
-func (s *DNS) SetFakeDNSOption(isFakeEnable bool) {
-	s.ipOption.FakeEnable = isFakeEnable
-}
-
-func (s *DNS) sortClients(domain string) []*Client {
+func (s *DNS) sortClients(domain string, option dns.IPOption) []*Client {
 	clients := make([]*Client, 0, len(s.clients))
 	clientUsed := make([]bool, len(s.clients))
 	clientNames := make([]string, 0, len(s.clients))
 	domainRules := []string{}
 
+	defer func() {
+		if len(domainRules) > 0 {
+			newError("domain ", domain, " matches following rules: ", domainRules).AtDebug().WriteToLog()
+		}
+		if len(clientNames) > 0 {
+			newError("domain ", domain, " will use DNS in order: ", clientNames).AtDebug().WriteToLog()
+		}
+	}()
+
 	// Priority domain matching
 	for _, match := range s.domainMatcher.Match(domain) {
 		info := s.matcherInfos[match]
 		client := s.clients[info.clientIdx]
 		domainRule := client.domains[info.domainRuleIdx]
+		if !canQueryOnClient(option, client) {
+			newError("skipping the client " + client.Name()).AtDebug().WriteToLog()
+			continue
+		}
+
 		domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", domainRule, info.clientIdx))
 		if clientUsed[info.clientIdx] {
 			continue
@@ -279,17 +280,15 @@ func (s *DNS) sortClients(domain string) []*Client {
 		if clientUsed[idx] {
 			continue
 		}
+		if !canQueryOnClient(option, client) {
+			newError("skipping the client " + client.Name()).AtDebug().WriteToLog()
+			continue
+		}
 		clientUsed[idx] = true
 		clients = append(clients, client)
 		clientNames = append(clientNames, client.Name())
 	}
 
-	if len(domainRules) > 0 {
-		newError("domain ", domain, " matches following rules: ", domainRules).AtDebug().WriteToLog()
-	}
-	if len(clientNames) > 0 {
-		newError("domain ", domain, " will use DNS in order: ", clientNames).AtDebug().WriteToLog()
-	}
 	return clients
 }
 

+ 3 - 3
app/dns/nameserver.go

@@ -20,7 +20,7 @@ type Server interface {
 	// Name of the Client.
 	Name() string
 	// QueryIP sends IP queries to its configured server.
-	QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns.IPOption, disableCache bool) ([]net.IP, error)
+	QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns.IPOption, cs CacheStrategy) ([]net.IP, error)
 }
 
 // Client is the interface for DNS client.
@@ -177,9 +177,9 @@ func (c *Client) Name() string {
 }
 
 // QueryIP send DNS query to the name server with the client's IP.
-func (c *Client) QueryIP(ctx context.Context, domain string, option dns.IPOption, disableCache bool) ([]net.IP, error) {
+func (c *Client) QueryIP(ctx context.Context, domain string, option dns.IPOption, cs CacheStrategy) ([]net.IP, error) {
 	ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
-	ips, err := c.server.QueryIP(ctx, domain, c.clientIP, option, disableCache)
+	ips, err := c.server.QueryIP(ctx, domain, c.clientIP, option, cs)
 	cancel()
 
 	if err != nil {

+ 7 - 5
app/dns/nameserver_doh.go

@@ -369,17 +369,19 @@ func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOpt
 }
 
 // QueryIP implements Server.
-func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { // nolint: dupl
+func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, cs CacheStrategy) ([]net.IP, error) { // nolint: dupl
 	fqdn := Fqdn(domain)
 
-	if disableCache {
+	if cs == CacheStrategy_Cache_DISABLE {
 		newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog()
 	} else {
 		ips, err := s.findIPsForDomain(fqdn, option)
 		if err != errRecordNotFound {
-			newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
-			log.Record(&log.DNSLog{s.name, domain, ips, log.DNSCacheHit, 0, err})
-			return ips, err
+			if cs == CacheStrategy_Cache_NOERROR && err == nil {
+				newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
+				log.Record(&log.DNSLog{s.name, domain, ips, log.DNSCacheHit, 0, err})
+				return ips, err
+			}
 		}
 	}
 

+ 3 - 3
app/dns/nameserver_doh_test.go

@@ -23,7 +23,7 @@ func TestDOHNameServer(t *testing.T) {
 	ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, false)
+	}, CacheStrategy_Cache_ALL)
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {
@@ -40,7 +40,7 @@ func TestDOHNameServerWithCache(t *testing.T) {
 	ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, false)
+	}, CacheStrategy_Cache_ALL)
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {
@@ -51,7 +51,7 @@ func TestDOHNameServerWithCache(t *testing.T) {
 	ips2, err := s.QueryIP(ctx2, "google.com", net.IP(nil), dns_feature.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, true)
+	}, CacheStrategy_Cache_ALL)
 	cancel()
 	common.Must(err)
 	if r := cmp.Diff(ips2, ips); r != "" {

+ 4 - 2
app/dns/nameserver_fakedns.go

@@ -16,11 +16,13 @@ func NewFakeDNSServer() *FakeDNSServer {
 	return &FakeDNSServer{}
 }
 
+const FakeDNSName = "FakeDNS"
+
 func (FakeDNSServer) Name() string {
-	return "FakeDNS"
+	return FakeDNSName
 }
 
-func (f *FakeDNSServer) QueryIP(ctx context.Context, domain string, _ net.IP, _ dns.IPOption, _ bool) ([]net.IP, error) {
+func (f *FakeDNSServer) QueryIP(ctx context.Context, domain string, _ net.IP, _ dns.IPOption, _ CacheStrategy) ([]net.IP, error) {
 	if f.fakeDNSEngine == nil {
 		if err := core.RequireFeatures(ctx, func(fd dns.FakeDNSEngine) {
 			f.fakeDNSEngine = fd

+ 1 - 1
app/dns/nameserver_local.go

@@ -14,7 +14,7 @@ type LocalNameServer struct {
 }
 
 // QueryIP implements Server.
-func (s *LocalNameServer) QueryIP(_ context.Context, domain string, _ net.IP, option dns.IPOption, _ bool) ([]net.IP, error) {
+func (s *LocalNameServer) QueryIP(_ context.Context, domain string, _ net.IP, option dns.IPOption, _ CacheStrategy) ([]net.IP, error) {
 	var ips []net.IP
 	var err error
 

+ 1 - 1
app/dns/nameserver_local_test.go

@@ -17,7 +17,7 @@ func TestLocalNameServer(t *testing.T) {
 	ips, err := s.QueryIP(ctx, "google.com", net.IP{}, dns.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, false)
+	}, CacheStrategy_Cache_ALL)
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {

+ 7 - 5
app/dns/nameserver_quic.go

@@ -268,17 +268,19 @@ func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOp
 }
 
 // QueryIP is called from dns.Server->queryIPTimeout
-func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) {
+func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, cs CacheStrategy) ([]net.IP, error) {
 	fqdn := Fqdn(domain)
 
-	if disableCache {
+	if cs == CacheStrategy_Cache_DISABLE {
 		newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog()
 	} else {
 		ips, err := s.findIPsForDomain(fqdn, option)
 		if err != errRecordNotFound {
-			newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
-			log.Record(&log.DNSLog{s.name, domain, ips, log.DNSCacheHit, 0, err})
-			return ips, err
+			if cs == CacheStrategy_Cache_NOERROR && err == nil {
+				newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
+				log.Record(&log.DNSLog{s.name, domain, ips, log.DNSCacheHit, 0, err})
+				return ips, err
+			}
 		}
 	}
 

+ 3 - 3
app/dns/nameserver_quic_test.go

@@ -23,7 +23,7 @@ func TestQUICNameServer(t *testing.T) {
 	ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, false)
+	}, CacheStrategy_Cache_ALL)
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {
@@ -40,7 +40,7 @@ func TestQUICNameServerWithCache(t *testing.T) {
 	ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, false)
+	}, CacheStrategy_Cache_ALL)
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {
@@ -51,7 +51,7 @@ func TestQUICNameServerWithCache(t *testing.T) {
 	ips2, err := s.QueryIP(ctx2, "google.com", net.IP(nil), dns_feature.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, true)
+	}, CacheStrategy_Cache_ALL)
 	cancel()
 	common.Must(err)
 	if r := cmp.Diff(ips2, ips); r != "" {

+ 7 - 5
app/dns/nameserver_udp.go

@@ -245,17 +245,19 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option dns_feature.I
 }
 
 // QueryIP implements Server.
-func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) {
+func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, cs CacheStrategy) ([]net.IP, error) {
 	fqdn := Fqdn(domain)
 
-	if disableCache {
+	if cs == CacheStrategy_Cache_DISABLE {
 		newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog()
 	} else {
 		ips, err := s.findIPsForDomain(fqdn, option)
 		if err != errRecordNotFound {
-			newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
-			log.Record(&log.DNSLog{s.name, domain, ips, log.DNSCacheHit, 0, err})
-			return ips, err
+			if cs == CacheStrategy_Cache_NOERROR && err == nil {
+				newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
+				log.Record(&log.DNSLog{s.name, domain, ips, log.DNSCacheHit, 0, err})
+				return ips, err
+			}
 		}
 	}
 

+ 24 - 0
app/dns/options.go

@@ -0,0 +1,24 @@
+package dns
+
+import "github.com/xtls/xray-core/features/dns"
+
+type Option interface {
+	queryIPv4() bool
+	queryIPv6() bool
+	queryIP() bool
+	queryFake() bool
+	canDoQuery(c *Client) bool
+}
+
+func isIPQuery(o dns.IPOption) bool {
+	return o.IPv4Enable || o.IPv6Enable
+}
+
+func canQueryOnClient(o dns.IPOption, c *Client) bool {
+	isIPClient := !(c.Name() == FakeDNSName)
+	return isIPClient && isIPQuery(o)
+}
+
+func isQuery(o dns.IPOption) bool {
+	return !(o.IPv4Enable || o.IPv6Enable || o.FakeEnable)
+}

+ 11 - 14
features/dns/client.go

@@ -22,6 +22,9 @@ type Client interface {
 
 	// LookupIP returns IP address for the given domain. IPs may contain IPv4 and/or IPv6 addresses.
 	LookupIP(domain string) ([]net.IP, error)
+
+	// LookupOptions query IP address for domain with IPOption.
+	LookupOptions(domain string, opt IPOption) ([]net.IP, error)
 }
 
 // IPv4Lookup is an optional feature for querying IPv4 addresses only.
@@ -38,20 +41,6 @@ type IPv6Lookup interface {
 	LookupIPv6(domain string) ([]net.IP, error)
 }
 
-// ClientWithIPOption is an optional feature for querying DNS information.
-//
-// xray:api:beta
-type ClientWithIPOption interface {
-	// GetIPOption returns IPOption for the DNS client.
-	GetIPOption() *IPOption
-
-	// SetQueryOption sets IPv4Enable and IPv6Enable for the DNS client.
-	SetQueryOption(isIPv4Enable, isIPv6Enable bool)
-
-	// SetFakeDNSOption sets FakeEnable option for DNS client.
-	SetFakeDNSOption(isFakeEnable bool)
-}
-
 // ClientType returns the type of Client interface. Can be used for implementing common.HasType.
 //
 // xray:api:beta
@@ -78,3 +67,11 @@ func RCodeFromError(err error) uint16 {
 	}
 	return 0
 }
+
+var (
+	LookupIPv4 = IPOption{IPv4Enable: true}
+	LookupIPv6 = IPOption{IPv6Enable: true}
+	LookupIP   = IPOption{IPv4Enable: true, IPv6Enable: true}
+	LookupFake = IPOption{FakeEnable: true}
+	LookupAll  = IPOption{true, true, true}
+)

+ 5 - 0
features/dns/localdns/client.go

@@ -38,6 +38,11 @@ func (*Client) LookupIP(host string) ([]net.IP, error) {
 	return parsedIPs, nil
 }
 
+// LookupOptions implements Client.
+func (c *Client) LookupOptions(host string, _ dns.IPOption) ([]net.IP, error) {
+	return c.LookupIP(host)
+}
+
 // LookupIPv4 implements IPv4Lookup.
 func (c *Client) LookupIPv4(host string) ([]net.IP, error) {
 	ips, err := c.LookupIP(host)

+ 2 - 30
features/routing/dns/context.go

@@ -26,40 +26,12 @@ func (ctx *ResolvableContext) GetTargetIPs() []net.IP {
 	}
 
 	if domain := ctx.GetTargetDomain(); len(domain) != 0 {
-		var lookupFunc func(string) ([]net.IP, error) = ctx.dnsClient.LookupIP
-		ipOption := &dns.IPOption{
-			IPv4Enable: true,
-			IPv6Enable: true,
-		}
-
-		if c, ok := ctx.dnsClient.(dns.ClientWithIPOption); ok {
-			ipOption = c.GetIPOption()
-			c.SetFakeDNSOption(false) // Skip FakeDNS.
-		} else {
-			newError("ctx.dnsClient doesn't implement ClientWithIPOption").AtDebug().WriteToLog()
-		}
-
-		switch {
-		case ipOption.IPv4Enable && !ipOption.IPv6Enable:
-			if lookupIPv4, ok := ctx.dnsClient.(dns.IPv4Lookup); ok {
-				lookupFunc = lookupIPv4.LookupIPv4
-			} else {
-				newError("ctx.dnsClient doesn't implement IPv4Lookup. Use LookupIP instead.").AtDebug().WriteToLog()
-			}
-		case !ipOption.IPv4Enable && ipOption.IPv6Enable:
-			if lookupIPv6, ok := ctx.dnsClient.(dns.IPv6Lookup); ok {
-				lookupFunc = lookupIPv6.LookupIPv6
-			} else {
-				newError("ctx.dnsClient doesn't implement IPv6Lookup. Use LookupIP instead.").AtDebug().WriteToLog()
-			}
-		}
-
-		ips, err := lookupFunc(domain)
+		ips, err := ctx.dnsClient.LookupIP(domain)
 		if err == nil {
 			ctx.resolvedIPs = ips
 			return ips
 		}
-		newError("resolve ip for ", domain).Base(err).WriteToLog()
+		newError("failed to resolve ip for ", domain).Base(err).WriteToLog()
 	}
 
 	return nil

+ 16 - 2
infra/conf/dns.go

@@ -124,6 +124,7 @@ type DNSConfig struct {
 	ClientIP      *Address            `json:"clientIp"`
 	Tag           string              `json:"tag"`
 	QueryStrategy string              `json:"queryStrategy"`
+	CacheStrategy string              `json:"cacheStrategy"`
 	DisableCache  bool                `json:"disableCache"`
 }
 
@@ -142,8 +143,12 @@ func getHostMapping(addr *Address) *dns.Config_HostMapping {
 // Build implements Buildable
 func (c *DNSConfig) Build() (*dns.Config, error) {
 	config := &dns.Config{
-		Tag:          c.Tag,
-		DisableCache: c.DisableCache,
+		Tag:           c.Tag,
+		CacheStrategy: dns.CacheStrategy_Cache_ALL,
+	}
+
+	if c.DisableCache {
+		config.CacheStrategy = dns.CacheStrategy_Cache_DISABLE
 	}
 
 	if c.ClientIP != nil {
@@ -163,6 +168,15 @@ func (c *DNSConfig) Build() (*dns.Config, error) {
 		config.QueryStrategy = dns.QueryStrategy_USE_IP6
 	}
 
+	switch strings.ToLower(c.CacheStrategy) {
+	case "noerror":
+		config.CacheStrategy = dns.CacheStrategy_Cache_NOERROR
+	case "all":
+		config.CacheStrategy = dns.CacheStrategy_Cache_ALL
+	case "disable", "none":
+		config.CacheStrategy = dns.CacheStrategy_Cache_DISABLE
+	}
+
 	for _, server := range c.Servers {
 		ns, err := server.Build()
 		if err != nil {

+ 1 - 1
infra/conf/dns_test.go

@@ -139,7 +139,7 @@ func TestDNSConfigParsing(t *testing.T) {
 				},
 				ClientIp:      []byte{10, 0, 0, 1},
 				QueryStrategy: dns.QueryStrategy_USE_IP4,
-				DisableCache:  true,
+				CacheStrategy: dns.CacheStrategy_Cache_DISABLE,
 			},
 		},
 	})

+ 6 - 23
proxy/dns/dns.go

@@ -37,8 +37,6 @@ type ownLinkVerifier interface {
 
 type Handler struct {
 	client          dns.Client
-	ipv4Lookup      dns.IPv4Lookup
-	ipv6Lookup      dns.IPv6Lookup
 	ownLinkVerifier ownLinkVerifier
 	server          net.Destination
 }
@@ -46,18 +44,6 @@ type Handler struct {
 func (h *Handler) Init(config *Config, dnsClient dns.Client) error {
 	h.client = dnsClient
 
-	if ipv4lookup, ok := dnsClient.(dns.IPv4Lookup); ok {
-		h.ipv4Lookup = ipv4lookup
-	} else {
-		return newError("dns.Client doesn't implement IPv4Lookup")
-	}
-
-	if ipv6lookup, ok := dnsClient.(dns.IPv6Lookup); ok {
-		h.ipv6Lookup = ipv6lookup
-	} else {
-		return newError("dns.Client doesn't implement IPv6Lookup")
-	}
-
 	if v, ok := dnsClient.(ownLinkVerifier); ok {
 		h.ownLinkVerifier = v
 	}
@@ -213,21 +199,18 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
 	var err error
 
 	var ttl uint32 = 600
-
-	// Do NOT skip FakeDNS
-	if c, ok := h.client.(dns.ClientWithIPOption); ok {
-		c.SetFakeDNSOption(true)
-	} else {
-		newError("dns.Client doesn't implement ClientWithIPOption")
-	}
+	var opt = dns.LookupIP
 
 	switch qType {
 	case dnsmessage.TypeA:
-		ips, err = h.ipv4Lookup.LookupIPv4(domain)
+		opt = dns.LookupIPv4
 	case dnsmessage.TypeAAAA:
-		ips, err = h.ipv6Lookup.LookupIPv6(domain)
+		opt = dns.LookupIPv6
 	}
 
+	opt.FakeEnable = true
+
+	ips, err = h.client.LookupOptions(domain, opt)
 	rcode := dns.RCodeFromError(err)
 	if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse {
 		newError("ip query").Base(err).WriteToLog()

+ 5 - 14
proxy/freedom/freedom.go

@@ -59,24 +59,15 @@ func (h *Handler) policy() policy.Session {
 }
 
 func (h *Handler) resolveIP(ctx context.Context, domain string, localAddr net.Address) net.Address {
-	if c, ok := h.dns.(dns.ClientWithIPOption); ok {
-		c.SetFakeDNSOption(false) // Skip FakeDNS
-	} else {
-		newError("DNS client doesn't implement ClientWithIPOption")
-	}
-
-	var lookupFunc func(string) ([]net.IP, error) = h.dns.LookupIP
+	var opt = dns.LookupIP
 	if h.config.DomainStrategy == Config_USE_IP4 || (localAddr != nil && localAddr.Family().IsIPv4()) {
-		if lookupIPv4, ok := h.dns.(dns.IPv4Lookup); ok {
-			lookupFunc = lookupIPv4.LookupIPv4
-		}
+		opt = dns.LookupIPv4
 	} else if h.config.DomainStrategy == Config_USE_IP6 || (localAddr != nil && localAddr.Family().IsIPv6()) {
-		if lookupIPv6, ok := h.dns.(dns.IPv6Lookup); ok {
-			lookupFunc = lookupIPv6.LookupIPv6
-		}
+		opt = dns.LookupIPv6
 	}
+	opt.FakeEnable = true
 
-	ips, err := lookupFunc(domain)
+	ips, err := h.dns.LookupOptions(domain, opt)
 	if err != nil {
 		newError("failed to get IP address for domain ", domain).Base(err).WriteToLog(session.ExportIDToError(ctx))
 	}

+ 6 - 16
transport/internet/system_dialer.go

@@ -63,30 +63,20 @@ func (d *DefaultSystemDialer) lookupIP(domain string, strategy DomainStrategy, l
 		return nil, nil
 	}
 
-	if c, ok := d.dns.(dns.ClientWithIPOption); ok {
-		c.SetFakeDNSOption(false) // Skip FakeDNS
-	} else {
-		newError("DNS client doesn't implement ClientWithIPOption")
-	}
-
-	var lookupFunc func(string) ([]net.IP, error) = d.dns.LookupIP
+	var opt = dns.LookupIP
 	switch {
 	case strategy == DomainStrategy_USE_IP4 || (localAddr != nil && localAddr.Family().IsIPv4()):
-		if lookupIPv4, ok := d.dns.(dns.IPv4Lookup); ok {
-			lookupFunc = lookupIPv4.LookupIPv4
-		}
+		opt = dns.LookupIPv4
 	case strategy == DomainStrategy_USE_IP6 || (localAddr != nil && localAddr.Family().IsIPv6()):
-		if lookupIPv6, ok := d.dns.(dns.IPv6Lookup); ok {
-			lookupFunc = lookupIPv6.LookupIPv6
-		}
+		opt = dns.LookupIPv6
 	case strategy == DomainStrategy_AS_IS:
 		return nil, nil
 	}
 
-	return lookupFunc(domain)
+	return d.dns.LookupOptions(domain, opt)
 }
 
-func (d *DefaultSystemDialer) doLookupIP(ctx context.Context, dst net.Destination, sockopt *SocketConfig) bool {
+func (d *DefaultSystemDialer) canLookupIP(ctx context.Context, dst net.Destination, sockopt *SocketConfig) bool {
 	if sockopt == nil || dst.Address.Family().IsIP() || d.dns == nil {
 		return false
 	}
@@ -121,7 +111,7 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne
 		}
 	}
 
-	if d.doLookupIP(ctx, dest, sockopt) {
+	if d.canLookupIP(ctx, dest, sockopt) {
 		ips, err := d.lookupIP(dest.Address.String(), sockopt.DomainStrategy, src)
 		if err == nil && len(ips) > 0 {
 			dest.Address = net.IPAddress(ips[dice.Roll(len(ips))])