Przeglądaj źródła

domain-rule: Fix domain rule precedence to allow multiple address rules for same domain with different matching types

Nick Peng 3 miesięcy temu
rodzic
commit
2e60a0cce7

+ 0 - 2
src/dns_conf/dns_conf_group.c

@@ -209,8 +209,6 @@ static int _config_domain_rule_iter_copy(void *data, const unsigned char *key, u
 			new_domain_rule->rules[i] = old_domain_rule->rules[i];
 		}
 	}
-	new_domain_rule->sub_rule_only = old_domain_rule->sub_rule_only;
-	new_domain_rule->root_rule_only = old_domain_rule->root_rule_only;
 
 	old_domain_rule = art_insert(dest_tree, key, key_len, new_domain_rule);
 	if (old_domain_rule) {

+ 15 - 6
src/dns_conf/domain_rule.c

@@ -287,8 +287,18 @@ static int _config_setup_domain_key(const char *domain, char *domain_key, int do
 		}
 	}
 
+	/* add dot to the front when sub rule only */
 	domain_key[0] = '.';
-	domain_key[len + 1] = '\0';
+	if (tmp_sub_rule_only == 1 && tmp_root_rule_only == 0) {
+		domain_key[len + 1] = '\0';
+	} else if (tmp_root_rule_only == 1 && tmp_sub_rule_only == 0) {
+		if (domain_key[len] == '.') {
+			len--;
+		}
+		domain_key[len + 1] = '\0';
+	} else {
+		domain_key[len + 1] = '\0';
+	}
 
 	*domain_key_len = len + 1;
 	if (root_rule_only) {
@@ -422,10 +432,9 @@ int _config_domain_rule_flag_set(const char *domain, unsigned int flag, unsigned
 		domain_rule->rules[DOMAIN_RULE_FLAGS] = (struct dns_rule *)rule_flags;
 	}
 
-	domain_rule->sub_rule_only = sub_rule_only;
-	domain_rule->root_rule_only = root_rule_only;
-
 	rule_flags = (struct dns_rule_flags *)domain_rule->rules[DOMAIN_RULE_FLAGS];
+	rule_flags->head.sub_only = sub_rule_only;
+	rule_flags->head.root_only = root_rule_only;
 	if (is_clear == false) {
 		rule_flags->flags |= flag;
 	} else {
@@ -552,8 +561,8 @@ int _config_domain_rule_add(const char *domain, enum domain_rule type, void *rul
 	}
 
 	domain_rule->rules[type] = rule;
-	domain_rule->sub_rule_only = sub_rule_only;
-	domain_rule->root_rule_only = root_rule_only;
+	((struct dns_rule *)rule)->sub_only = sub_rule_only;
+	((struct dns_rule *)rule)->root_only = root_rule_only;
 	_dns_rule_get(rule);
 
 	/* update domain rule - only for new allocations */

+ 1 - 0
src/dns_server/dns_server.h

@@ -90,6 +90,7 @@ typedef enum DNS_CHILD_POST_RESULT {
 struct rule_walk_args {
 	void *args;
 	int rule_index;
+	uint32_t full_key_len;
 	unsigned char *key[DOMAIN_RULE_MAX];
 	uint32_t key_len[DOMAIN_RULE_MAX];
 };

+ 24 - 10
src/dns_server/rules.c

@@ -112,16 +112,12 @@ static int _dns_server_get_rules(unsigned char *key, uint32_t key_len, int is_su
 		return 0;
 	}
 
-	if (domain_rule->sub_rule_only != domain_rule->root_rule_only) {
-		/* only subkey rule */
-		if (domain_rule->sub_rule_only == 1 && is_subkey == 0) {
-			return 0;
-		}
-
-		/* only root key rule */
-		if (domain_rule->root_rule_only == 1 && is_subkey == 1) {
-			return 0;
-		}
+	/* sub rule flag check */
+	int is_effective_sub = 1;
+	if (key_len == walk_args->full_key_len) {
+		is_effective_sub = 0;
+	} else if (key_len == walk_args->full_key_len - 1 && walk_args->full_key_len > 0) {
+		is_effective_sub = 0;
 	}
 
 	if (walk_args->rule_index >= 0) {
@@ -139,9 +135,26 @@ static int _dns_server_get_rules(unsigned char *key, uint32_t key_len, int is_su
 		}
 
 		if (i == DOMAIN_RULE_FLAGS) {
+			struct dns_rule_flags *rule_flags = (struct dns_rule_flags *)domain_rule->rules[i];
+			if (rule_flags->head.sub_only == 1 && is_effective_sub == 0) {
+				continue;
+			}
+
+			if (rule_flags->head.root_only == 1 && is_effective_sub == 1) {
+				continue;
+			}
+
 			request_domain_rule->flags |= ((struct dns_rule_flags *)domain_rule->rules[i])->flags;
 		}
 
+		if (domain_rule->rules[i]->sub_only == 1 && is_effective_sub == 0) {
+			continue;
+		}
+
+		if (domain_rule->rules[i]->root_only == 1 && is_effective_sub == 1) {
+			continue;
+		}
+
 		request_domain_rule->rules[i] = domain_rule->rules[i];
 		request_domain_rule->is_sub_rule[i] = is_subkey;
 		walk_args->key[i] = key;
@@ -183,6 +196,7 @@ void _dns_server_get_domain_rule_by_domain_ext(struct dns_conf_group *conf,
 	domain_key[0] = '.';
 	domain_len += 2;
 	domain_key[domain_len] = 0;
+	walk_args.full_key_len = domain_len;
 
 	/* find domain rule */
 	art_substring_walk(&conf->domain_rule.tree, (unsigned char *)domain_key, domain_len, _dns_server_get_rules,

+ 4 - 5
src/include/smartdns/dns_conf.h

@@ -189,6 +189,8 @@ enum response_mode_type {
 struct dns_rule {
 	atomic_t refcnt;
 	enum domain_rule rule;
+	unsigned char sub_only : 1;
+	unsigned char root_only : 1;
 };
 
 struct dns_rule_flags {
@@ -268,10 +270,8 @@ extern struct dns_nftset_names dns_conf_nftset_no_speed;
 extern struct dns_nftset_names dns_conf_nftset;
 
 struct dns_domain_rule {
-	unsigned char sub_rule_only : 1;
-	unsigned char root_rule_only : 1;
-	unsigned char capacity : 6; /* Current allocated capacity (max 63) */
-	struct dns_rule *rules[];   /* Flexible array member */
+	unsigned char capacity;   /* Current allocated capacity (max 255) */
+	struct dns_rule *rules[]; /* Flexible array member */
 };
 
 struct dns_nameserver_rule {
@@ -676,7 +676,6 @@ struct dns_conf_plugin {
 	int args_len;
 };
 
-
 struct dns_conf_plugin_table {
 	DECLARE_HASHTABLE(plugins, 4);
 	art_tree plugins_conf;

+ 43 - 1
test/cases/test-rule.cc

@@ -17,9 +17,9 @@
  */
 
 #include "client.h"
-#include "smartdns/dns.h"
 #include "include/utils.h"
 #include "server.h"
+#include "smartdns/dns.h"
 #include "smartdns/util.h"
 #include "gtest/gtest.h"
 #include <fstream>
@@ -354,3 +354,45 @@ address
 	ASSERT_EQ(client.GetAnswerNum(), 0);
 	EXPECT_EQ(client.GetStatus(), "NOERROR");
 }
+
+TEST_F(Rule, root_and_sub)
+{
+	smartdns::Server server;
+	smartdns::MockServer server_upstream;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype == DNS_T_A) {
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "9.9.9.9", 700);
+			return smartdns::SERVER_REQUEST_OK;
+		}
+		return smartdns::SERVER_REQUEST_SOA;
+	});
+
+	server.Start(R"""(bind [::]:60053
+server 127.0.0.1:61053
+log-level debug
+speed-check-mode none
+address /q.a.com/1.2.3.4
+address /-.a.com/4.5.6.7
+address /*.a.com/7.8.9.10
+)""");
+	smartdns::Client client;
+
+	// 1. q.a.com should be 1.2.3.4
+	ASSERT_TRUE(client.Query("q.a.com A", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+
+	// 2. a.com should be 4.5.6.7 (matching -.a.com)
+	ASSERT_TRUE(client.Query("a.com A", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "4.5.6.7");
+
+	// 3. other.a.com should be 7.8.9.10 (matching *.a.com)
+	ASSERT_TRUE(client.Query("other.a.com A", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "7.8.9.10");
+}