Browse Source

domain-rules: support rule on root domain.

Nick Peng 2 years ago
parent
commit
4d396ff688
5 changed files with 92 additions and 58 deletions
  1. 10 4
      etc/smartdns/smartdns.conf
  2. 23 28
      src/dns_conf.c
  3. 0 2
      src/dns_conf.h
  4. 5 24
      src/dns_server.c
  5. 54 0
      test/cases/test-rule.cc

+ 10 - 4
etc/smartdns/smartdns.conf

@@ -258,7 +258,8 @@ log-level info
 #   proxy-server http://user:[email protected]:3128 -name proxy
 
 # specific nameserver to domain
-# nameserver /domain/[group|-]
+# nameserver [/domain/][group|-]
+# nameserer group, set the domain name to use the appropriate server group.
 # nameserver /www.example.com/office, Set the domain name to use the appropriate server group.
 # nameserver /www.example.com/-, ignore this domain
 
@@ -266,7 +267,10 @@ log-level info
 # expand-ptr-from-address yes
 
 # specific address to domain
-# address /domain/[ip1,ip2|-|-4|-6|#|#4|#6]
+# address [/domain/][ip1,ip2|-|-4|-6|#|#4|#6]
+# address #, block all A and AAAA request.
+# address #6, block all AAAA request.
+# address -6, allow all AAAA request.
 # address /www.example.com/1.2.3.4, return ip 1.2.3.4 to client
 # address /www.example.com/1.2.3.4,5.6.7.8, return multiple ip addresses
 # address /www.example.com/-, ignore address, query from upstream, suffix 4, for ipv4, 6 for ipv6, none for all
@@ -288,10 +292,11 @@ log-level info
 # ipset-timeout [yes]
 
 # specific ipset to domain
-# ipset /domain/[ipsetname|#4:v4setname|#6:v6setname|-|#4:-|#6:-]
+# ipset [/domain/][ipsetname|#4:v4setname|#6:v6setname|-|#4:-|#6:-]
 # ipset [ipsetname|#4:v4setname|#6:v6setname], set global ipset.
 # ipset /www.example.com/block, set ipset with ipset name of block. 
 # ipset /www.example.com/-, ignore this domain.
+# ipset ipsetname, set global ipset.
 
 # add to ipset when ping is unreachable
 # ipset-no-speed ipsetname
@@ -310,11 +315,12 @@ log-level info
 # nftset-debug yes
 
 # specific nftset to domain
-# nftset /domain/[#4:ip#table#set,#6:ipv6#table#setv6]
+# nftset [/domain/][#4:ip#table#set,#6:ipv6#table#setv6]
 # nftset [#4:ip#table#set,#6:ipv6#table#setv6] set global nftset.
 # nftset /www.example.com/ip#table#set, equivalent to 'nft add element ip table set { ... }'
 # nftset /www.example.com/-, ignore this domain
 # nftset /www.example.com/#6:-, ignore ipv6
+# nftset #6:ip#table#set, set global nftset.
 
 # set ddns domain
 # ddns-domain domain

+ 23 - 28
src/dns_conf.c

@@ -306,7 +306,8 @@ static int _get_domain(char *value, char *domain, int max_domain_size, char **pt
 	/* first field */
 	begin = strstr(value, "/");
 	if (begin == NULL) {
-		goto errout;
+		safe_strncpy(domain, ".", max_domain_size);
+		return 0;
 	}
 
 	/* second field */
@@ -318,6 +319,9 @@ static int _get_domain(char *value, char *domain, int max_domain_size, char **pt
 
 	/* remove prefix . */
 	while (*begin == '.') {
+		if (begin + 1 == end) {
+			break;
+		}
 		begin++;
 	}
 
@@ -1218,15 +1222,21 @@ static int _config_setup_domain_key(const char *domain, char *domain_key, int do
 {
 	int tmp_root_rule_only = 0;
 	int tmp_sub_rule_only = 0;
+	int domain_len = 0;
 
 	int len = strlen(domain);
-	if (len >= domain_key_max_len - 2) {
+	domain_len = len;
+	if (len >= domain_key_max_len - 3) {
 		tlog(TLOG_ERROR, "domain %s too long", domain);
 		return -1;
 	}
 
-	reverse_string(domain_key, domain, len, 1);
-	if (domain[0] == '*') {
+	while (len > 0 && domain[len - 1] == '.') {
+		len--;
+	}
+
+	reverse_string(domain_key + 1, domain, len, 1);
+	if (domain[0] == '*' && domain_len > 1) {
 		/* prefix wildcard */
 		len--;
 		if (domain[1] == '.') {
@@ -1236,20 +1246,22 @@ static int _config_setup_domain_key(const char *domain, char *domain_key, int do
 			tmp_sub_rule_only = 1;
 			tmp_root_rule_only = 1;
 		}
-	} else if (domain[0] == '-') {
+	} else if (domain[0] == '-' && domain_len > 1) {
 		/* root match only */
 		len--;
 		if (domain[1] == '.') {
 			tmp_root_rule_only = 1;
 		}
-	} else {
+	} else if (len > 0) {
 		/* suffix match */
-		domain_key[len] = '.';
+		domain_key[len + 1] = '.';
 		len++;
 	}
-	domain_key[len] = 0;
 
-	*domain_key_len = len;
+	domain_key[len + 1] = 0;
+	domain_key[0] = '.';
+
+	*domain_key_len = len + 1;
 	if (root_rule_only) {
 		*root_rule_only = tmp_root_rule_only;
 	}
@@ -1656,16 +1668,7 @@ static int _config_ipset(void *data, int argc, char *argv[])
 	}
 
 	if (_get_domain(value, domain, DNS_MAX_CONF_CNAME_LEN, &value) != 0) {
-		if (strstr(value, "/")) {
-			goto errout;
-		}
-
-		if (_config_ipset_setvalue(&_config_current_rule_group()->ipset_nftset.ipset, value) != 0) {
-			ret = -1;
-			goto errout;
-		}
-
-		return 0;
+		goto errout;
 	}
 
 	ret = _conf_domain_rule_ipset(domain, value);
@@ -1868,15 +1871,7 @@ static int _config_nftset(void *data, int argc, char *argv[])
 	}
 
 	if (_get_domain(value, domain, DNS_MAX_CONF_CNAME_LEN, &value) != 0) {
-		if (strstr(value, "/")) {
-			goto errout;
-		}
-		if (_config_nftset_setvalue(&_config_current_rule_group()->ipset_nftset.nftset, value) != 0) {
-			ret = -1;
-			goto errout;
-		}
-
-		return 0;
+		goto errout;
 	}
 
 	return _conf_domain_rule_nftset(domain, value);

+ 0 - 2
src/dns_conf.h

@@ -407,10 +407,8 @@ struct dns_conf_domain_rule {
 struct dns_conf_ipset_nftset {
 	int ipset_timeout_enable;
 	struct dns_ipset_names ipset_no_speed;
-	struct dns_ipset_names ipset;
 	int nftset_timeout_enable;
 	struct dns_nftset_names nftset_no_speed;
-	struct dns_nftset_names nftset;
 };
 
 struct dns_conf_group {

+ 5 - 24
src/dns_server.c

@@ -1896,10 +1896,6 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context
 			ipset_rule = _dns_server_get_bind_ipset_nftset_rule(request, DOMAIN_RULE_IPSET);
 		}
 
-		if (ipset_rule == NULL && conf->ipset_nftset.ipset.inet_enable) {
-			ipset_rule = &conf->ipset_nftset.ipset.inet;
-		}
-
 		if (ipset_rule == NULL && check_no_speed_rule && conf->ipset_nftset.ipset_no_speed.inet_enable) {
 			ipset_rule_v4 = &conf->ipset_nftset.ipset_no_speed.inet;
 		}
@@ -1911,10 +1907,6 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context
 			ipset_rule_v4 = _dns_server_get_bind_ipset_nftset_rule(request, DOMAIN_RULE_IPSET_IPV4);
 		}
 
-		if (ipset_rule_v4 == NULL && ipset_rule == NULL && conf->ipset_nftset.ipset.ipv4_enable) {
-			ipset_rule_v4 = &conf->ipset_nftset.ipset.ipv4;
-		}
-
 		if (ipset_rule_v4 == NULL && check_no_speed_rule && conf->ipset_nftset.ipset_no_speed.ipv4_enable) {
 			ipset_rule_v4 = &conf->ipset_nftset.ipset_no_speed.ipv4;
 		}
@@ -1926,10 +1918,6 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context
 			ipset_rule_v6 = _dns_server_get_bind_ipset_nftset_rule(request, DOMAIN_RULE_IPSET_IPV6);
 		}
 
-		if (ipset_rule_v6 == NULL && ipset_rule == NULL && conf->ipset_nftset.ipset.ipv6_enable) {
-			ipset_rule_v6 = &conf->ipset_nftset.ipset.ipv6;
-		}
-
 		if (ipset_rule_v6 == NULL && check_no_speed_rule && conf->ipset_nftset.ipset_no_speed.ipv6_enable) {
 			ipset_rule_v6 = &conf->ipset_nftset.ipset_no_speed.ipv6;
 		}
@@ -1941,10 +1929,6 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context
 			nftset_ip = _dns_server_get_bind_ipset_nftset_rule(request, DOMAIN_RULE_NFTSET_IP);
 		}
 
-		if (nftset_ip == NULL && conf->ipset_nftset.nftset.ip_enable) {
-			nftset_ip = &conf->ipset_nftset.nftset.ip;
-		}
-
 		if (nftset_ip == NULL && check_no_speed_rule && conf->ipset_nftset.nftset_no_speed.ip_enable) {
 			nftset_ip = &conf->ipset_nftset.nftset_no_speed.ip;
 		}
@@ -1957,10 +1941,6 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context
 			nftset_ip6 = _dns_server_get_bind_ipset_nftset_rule(request, DOMAIN_RULE_NFTSET_IP6);
 		}
 
-		if (nftset_ip6 == NULL && conf->ipset_nftset.nftset.ip6_enable) {
-			nftset_ip6 = &conf->ipset_nftset.nftset.ip6;
-		}
-
 		if (nftset_ip6 == NULL && check_no_speed_rule && conf->ipset_nftset.nftset_no_speed.ip6_enable) {
 			nftset_ip6 = &conf->ipset_nftset.nftset_no_speed.ip6;
 		}
@@ -4664,13 +4644,14 @@ static void _dns_server_get_domain_rule_by_domain(struct dns_request *request, c
 
 	/* reverse domain string */
 	domain_len = strlen(domain);
-	if (domain_len >= (int)sizeof(domain_key) - 2) {
+	if (domain_len >= (int)sizeof(domain_key) - 3) {
 		return;
 	}
 
-	reverse_string(domain_key, domain, domain_len, 1);
-	domain_key[domain_len] = '.';
-	domain_len++;
+	reverse_string(domain_key + 1, domain, domain_len, 1);
+	domain_key[domain_len + 1] = '.';
+	domain_key[0] = '.';
+	domain_len += 2;
 	domain_key[domain_len] = 0;
 
 	/* find domain rule */

+ 54 - 0
test/cases/test-rule.cc

@@ -300,3 +300,57 @@ address /*.b.com/#6
 	EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
 	EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::102:304");
 }
+
+TEST_F(Rule, root)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::Server server;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype == DNS_T_A) {
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700);
+			return smartdns::SERVER_REQUEST_OK;
+		} else if (request->qtype == DNS_T_AAAA) {
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700);
+			return smartdns::SERVER_REQUEST_OK;
+		}
+		return smartdns::SERVER_REQUEST_SOA;
+	});
+
+	server.Start(R"""(bind [::]:60053
+server 127.0.0.1:61053
+speed-check-mode none
+address #6
+address /-.a.com/-6
+address 
+)""");
+	smartdns::Client client;
+
+	ASSERT_TRUE(client.Query("a.com A", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
+	EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+
+	ASSERT_TRUE(client.Query("a.com AAAA", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
+	EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::102:304");
+
+	ASSERT_TRUE(client.Query("a.a.com AAAA", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 0);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+
+	ASSERT_TRUE(client.Query("b.com AAAA", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 0);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+}