Browse Source

dns_conf: fix wildcard match issue

Nick Peng 2 years ago
parent
commit
37a87e864e
2 changed files with 119 additions and 60 deletions
  1. 63 59
      src/dns_conf.c
  2. 56 1
      test/cases/test-rule.cc

+ 63 - 59
src/dns_conf.c

@@ -864,20 +864,63 @@ static int _config_domain_rule_add_callback(const char *domain, void *priv)
 	return _config_domain_rule_add(domain, args->type, args->rule);
 }
 
+static int _config_setup_domain_key(const char *domain, char *domain_key, int domain_key_max_len, int *domain_key_len,
+									int *root_rule_only, int *sub_rule_only)
+{
+	int tmp_root_rule_only = 0;
+	int tmp_sub_rule_only = 0;
+
+	int len = strlen(domain);
+	if (len >= domain_key_max_len - 2) {
+		tlog(TLOG_ERROR, "domain %s too long", domain);
+		return -1;
+	}
+
+	reverse_string(domain_key, domain, len, 1);
+	if (domain[0] == '*') {
+		/* prefix wildcard */
+		len--;
+		if (domain[1] == '.') {
+			tmp_sub_rule_only = 1;
+		} else if ((domain[1] == '-') && (domain[2] == '.')) {
+			len--;
+			tmp_sub_rule_only = 1;
+			tmp_root_rule_only = 1;
+		}
+	} else if (domain[0] == '-') {
+		/* root match only */
+		len--;
+		if (domain[1] == '.') {
+			tmp_root_rule_only = 1;
+		}
+	} else {
+		/* suffix match */
+		domain_key[len] = '.';
+		len++;
+	}
+	domain_key[len] = 0;
+
+	*domain_key_len = len;
+	if (root_rule_only) {
+		*root_rule_only = tmp_root_rule_only;
+	}
+
+	if (sub_rule_only) {
+		*sub_rule_only = tmp_sub_rule_only;
+	}
+
+	return 0;
+}
+
 static struct dns_domain_rule *_config_domain_rule_get(const char *domain)
 {
 	char domain_key[DNS_MAX_CONF_CNAME_LEN];
 	int len = 0;
 
-	len = strlen(domain);
-	if (len >= (int)sizeof(domain_key) - 2) {
+	if (_config_setup_domain_key(domain, domain_key, sizeof(domain_key), &len, NULL, NULL) != 0) {
 		return NULL;
 	}
 
-	reverse_string(domain_key, domain, len, 1);
-	domain_key[len] = '.';
-	len++;
-	domain_key[len] = 0;
 	return art_search(&dns_conf_domain_rule, (unsigned char *)domain_key, len);
 }
 
@@ -892,18 +935,6 @@ static int _config_domain_rule_add(const char *domain, enum domain_rule type, vo
 	int sub_rule_only = 0;
 	int root_rule_only = 0;
 
-	/* Reverse string, for suffix match */
-	len = strlen(domain);
-	if (len >= (int)sizeof(domain_key) - 2) {
-		tlog(TLOG_ERROR, "domain name %s too long", domain);
-		goto errout;
-	}
-
-	if (len <= 0) {
-		tlog(TLOG_ERROR, "domain name %s too short", domain);
-		goto errout;
-	}
-
 	if (strncmp(domain, "domain-set:", sizeof("domain-set:") - 1) == 0) {
 		struct dns_set_rule_add_callback_args args;
 		args.type = type;
@@ -912,29 +943,10 @@ static int _config_domain_rule_add(const char *domain, enum domain_rule type, vo
 											&args);
 	}
 
-	reverse_string(domain_key, domain, len, 1);
-	if (domain[0] == '*') {
-		/* prefix wildcard */
-		len--;
-		if (domain[1] == '.') {
-			sub_rule_only = 1;
-		} else if ((domain[1] == '-') && (domain[2] == '.')) {
-			len--;
-			sub_rule_only = 1;
-			root_rule_only = 1;
-		}
-	} else if (domain[0] == '-') {
-		/* root match only */
-		len--;
-		if (domain[1] == '.') {
-			root_rule_only = 1;
-		}
-	} else {
-		/* suffix match */
-		domain_key[len] = '.';
-		len++;
+	/* Reverse string, for suffix match */
+	if (_config_setup_domain_key(domain, domain_key, sizeof(domain_key), &len, &root_rule_only, &sub_rule_only) != 0) {
+		goto errout;
 	}
-	domain_key[len] = 0;
 
 	if (type >= DOMAIN_RULE_MAX) {
 		goto errout;
@@ -991,22 +1003,15 @@ static int _config_domain_rule_delete(const char *domain)
 	char domain_key[DNS_MAX_CONF_CNAME_LEN];
 	int len = 0;
 
-	/* Reverse string, for suffix match */
-	len = strlen(domain);
-	if (len >= (int)sizeof(domain_key)) {
-		tlog(TLOG_ERROR, "domain name %s too long", domain);
-		goto errout;
-	}
-
 	if (strncmp(domain, "domain-set:", sizeof("domain-set:") - 1) == 0) {
 		return _config_domain_rule_set_each(domain + sizeof("domain-set:") - 1, _config_domain_rule_delete_callback,
 											NULL);
 	}
+	/* Reverse string, for suffix match */
 
-	reverse_string(domain_key, domain, len, 1);
-	domain_key[len] = '.';
-	len++;
-	domain_key[len] = 0;
+	if (_config_setup_domain_key(domain, domain_key, sizeof(domain_key), &len, NULL, NULL) != 0) {
+		goto errout;
+	}
 
 	/* delete existing rules */
 	void *rule = art_delete(&dns_conf_domain_rule, (unsigned char *)domain_key, len);
@@ -1036,6 +1041,8 @@ static int _config_domain_rule_flag_set(const char *domain, unsigned int flag, u
 
 	char domain_key[DNS_MAX_CONF_CNAME_LEN];
 	int len = 0;
+	int sub_rule_only = 0;
+	int root_rule_only = 0;
 
 	if (strncmp(domain, "domain-set:", sizeof("domain-set:") - 1) == 0) {
 		struct dns_set_rule_flags_callback_args args;
@@ -1045,15 +1052,9 @@ static int _config_domain_rule_flag_set(const char *domain, unsigned int flag, u
 											&args);
 	}
 
-	len = strlen(domain);
-	if (len >= (int)sizeof(domain_key)) {
-		tlog(TLOG_ERROR, "domain %s too long", domain);
-		return -1;
+	if (_config_setup_domain_key(domain, domain_key, sizeof(domain_key), &len, &root_rule_only, &sub_rule_only) != 0) {
+		goto errout;
 	}
-	reverse_string(domain_key, domain, len, 1);
-	domain_key[len] = '.';
-	len++;
-	domain_key[len] = 0;
 
 	/* Get existing or create domain rule */
 	domain_rule = art_search(&dns_conf_domain_rule, (unsigned char *)domain_key, len);
@@ -1073,6 +1074,9 @@ static int _config_domain_rule_flag_set(const char *domain, unsigned int flag, u
 		domain_rule->rules[DOMAIN_RULE_FLAGS] = (struct dns_rule *)rule_flags;
 	}
 
+	domain_rule->sub_rule_only = sub_rule_only;
+	domain_rule->root_rule_only = root_rule_only;
+
 	rule_flags = (struct dns_rule_flags *)domain_rule->rules[DOMAIN_RULE_FLAGS];
 	if (is_clear == false) {
 		rule_flags->flags |= flag;

+ 56 - 1
test/cases/test-rule.cc

@@ -259,4 +259,59 @@ cache-persist no)""");
 	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
 	EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
 	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
-}
+}
+
+TEST_F(Rule, AAAA_SOA)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::Server server;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype == DNS_T_A) {
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700);
+			return smartdns::SERVER_REQUEST_OK;
+		} else if (request->qtype == DNS_T_AAAA) {
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700);
+			return smartdns::SERVER_REQUEST_OK;
+		}
+		return smartdns::SERVER_REQUEST_SOA;
+	});
+
+	server.Start(R"""(bind [::]:60053
+server 127.0.0.1:61053
+log-num 0
+log-console yes
+log-level debug
+speed-check-mode none
+address /-.a.com/#6
+address /*.b.com/#6
+cache-persist no)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("a.com AAAA", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 0);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+
+	ASSERT_TRUE(client.Query("a.a.com AAAA", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
+	EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::102:304");
+
+	ASSERT_TRUE(client.Query("a.b.com AAAA", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 0);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+
+	ASSERT_TRUE(client.Query("b.com AAAA", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
+	EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::102:304");
+}