Просмотр исходного кода

test: return NXDOMAIN when block ad, add some test cases for ip-rule.

Nick Peng 2 лет назад
Родитель
Сommit
2dbde718a7

+ 1 - 1
etc/smartdns/smartdns.conf

@@ -40,7 +40,6 @@
 #   -no-rule-soa: Skip address SOA(#) rules.
 #   -no-dualstack-selection: Disable dualstack ip selection.
 #   -force-aaaa-soa: force AAAA query return SOA.
-#   -set-mark: set mark on packets.
 # example: 
 #  IPV4: 
 #    bind :53
@@ -185,6 +184,7 @@ log-level info
 #   -exclude-default-group: exclude this server from default group.
 #   -proxy [proxy-name]: use proxy to connect to server.
 #   -bootstrap-dns: set as bootstrap dns server.
+#   -set-mark: set mark on packets.
 # server 8.8.8.8 -blacklist-ip -check-edns -group g1 -group g2
 # server tls://dns.google:853 
 # server https://dns.google/dns-query

+ 27 - 14
src/dns_server.c

@@ -516,14 +516,14 @@ static void _dns_server_set_dualstack_selection(struct dns_request *request)
 	request->dualstack_selection = dns_conf_dualstack_ip_selection;
 }
 
-static int _dns_server_is_return_soa(struct dns_request *request)
+static int _dns_server_is_return_soa_qtype(struct dns_request *request, dns_type_t qtype)
 {
 	struct dns_rule_flags *rule_flag = NULL;
 	unsigned int flags = 0;
 
 	if (_dns_server_has_bind_flag(request, BIND_FLAG_NO_RULE_SOA) == 0) {
 		/* when both has no rule SOA and force AAAA soa, force AAAA soa has high priority */
-		if (request->qtype == DNS_T_AAAA && _dns_server_has_bind_flag(request, BIND_FLAG_FORCE_AAAA_SOA) == 0) {
+		if (qtype == DNS_T_AAAA && _dns_server_has_bind_flag(request, BIND_FLAG_FORCE_AAAA_SOA) == 0) {
 			return 1;
 		}
 
@@ -542,7 +542,7 @@ static int _dns_server_is_return_soa(struct dns_request *request)
 			return 0;
 		}
 
-		switch (request->qtype) {
+		switch (qtype) {
 		case DNS_T_A:
 			if (flags & DOMAIN_FLAG_ADDR_IPV4_SOA) {
 				return 1;
@@ -568,7 +568,7 @@ static int _dns_server_is_return_soa(struct dns_request *request)
 		}
 	}
 
-	if (request->qtype == DNS_T_AAAA) {
+	if (qtype == DNS_T_AAAA) {
 		if (_dns_server_has_bind_flag(request, BIND_FLAG_FORCE_AAAA_SOA) == 0 || dns_conf_force_AAAA_SOA == 1) {
 			return 1;
 		}
@@ -577,6 +577,11 @@ static int _dns_server_is_return_soa(struct dns_request *request)
 	return 0;
 }
 
+static int _dns_server_is_return_soa(struct dns_request *request)
+{
+	return _dns_server_is_return_soa_qtype(request, request->qtype);
+}
+
 static void _dns_server_post_context_init(struct dns_server_post_context *context, struct dns_request *request)
 {
 	memset(context, 0, sizeof(*context));
@@ -2687,6 +2692,9 @@ static int _dns_server_ip_rule_check(struct dns_request *request, unsigned char
 	/* bogus-nxdomain */
 	rule = node->data;
 	if (rule->bogus) {
+		request->rcode = DNS_RC_NXDOMAIN;
+		request->has_soa = 1;
+		_dns_server_setup_soa(request);
 		goto match;
 	}
 
@@ -3931,10 +3939,14 @@ static int _dns_server_pre_process_rule_flags(struct dns_request *request)
 {
 	struct dns_rule_flags *rule_flag = NULL;
 	unsigned int flags = 0;
+	int rcode = DNS_RC_NOERROR;
 
 	/* get domain rule flag */
 	rule_flag = _dns_server_get_dns_rule(request, DOMAIN_RULE_FLAGS);
 	if (rule_flag == NULL) {
+		if (_dns_server_is_return_soa(request)) {
+			goto soa;
+		}
 		goto out;
 	}
 
@@ -3952,10 +3964,6 @@ static int _dns_server_pre_process_rule_flags(struct dns_request *request)
 		goto out;
 	}
 
-	if (_dns_server_is_return_soa(request)) {
-		goto soa;
-	}
-
 	/* return specific type of address */
 	switch (request->qtype) {
 	case DNS_T_A:
@@ -3966,6 +3974,9 @@ static int _dns_server_pre_process_rule_flags(struct dns_request *request)
 
 		if (_dns_server_is_return_soa(request)) {
 			/* return SOA for A request */
+			if (_dns_server_is_return_soa_qtype(request, DNS_T_AAAA)) {
+				rcode = DNS_RC_NXDOMAIN;
+			}
 			goto soa;
 		}
 		break;
@@ -3977,6 +3988,9 @@ static int _dns_server_pre_process_rule_flags(struct dns_request *request)
 
 		if (_dns_server_is_return_soa(request)) {
 			/* return SOA for A request */
+			if (_dns_server_is_return_soa_qtype(request, DNS_T_A)) {
+				rcode = DNS_RC_NXDOMAIN;
+			}
 			goto soa;
 		}
 
@@ -3990,12 +4004,16 @@ static int _dns_server_pre_process_rule_flags(struct dns_request *request)
 		break;
 	}
 
+	if (_dns_server_is_return_soa(request)) {
+		goto soa;
+	}
+
 out:
 	return -1;
 
 soa:
 	/* return SOA */
-	_dns_server_reply_SOA(DNS_RC_NOERROR, request);
+	_dns_server_reply_SOA(rcode, request);
 	return 0;
 }
 
@@ -4842,11 +4860,6 @@ static int _dns_server_process_special_query(struct dns_request *request)
 	case DNS_T_A:
 		break;
 	case DNS_T_AAAA:
-		/* force return SOA */
-		if (_dns_server_is_return_soa(request)) {
-			_dns_server_reply_SOA(DNS_RC_NOERROR, request);
-			goto clean_exit;
-		}
 
 		break;
 	default:

+ 2 - 2
test/cases/test-address.cc

@@ -100,7 +100,7 @@ cache-persist no)""");
 	ASSERT_TRUE(client.Query("c.com A", 60053));
 	std::cout << client.GetResult() << std::endl;
 	ASSERT_EQ(client.GetAnswerNum(), 0);
-	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetStatus(), "NXDOMAIN");
 	EXPECT_EQ(client.GetAuthority()[0].GetName(), "c.com");
 	EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30);
 	EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA");
@@ -110,7 +110,7 @@ cache-persist no)""");
 	ASSERT_TRUE(client.Query("c.com AAAA", 60053));
 	std::cout << client.GetResult() << std::endl;
 	ASSERT_EQ(client.GetAnswerNum(), 0);
-	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetStatus(), "NXDOMAIN");
 	EXPECT_EQ(client.GetAuthority()[0].GetName(), "c.com");
 	EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30);
 	EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA");

+ 73 - 0
test/cases/test-domain-rule.cc

@@ -0,0 +1,73 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "client.h"
+#include "dns.h"
+#include "include/utils.h"
+#include "server.h"
+#include "gtest/gtest.h"
+
+class DomainRule : public ::testing::Test
+{
+  protected:
+	virtual void SetUp() {}
+	virtual void TearDown() {}
+};
+
+TEST_F(DomainRule, bogus_nxdomain)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::MockServer server_upstream2;
+	smartdns::Server server;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype != DNS_T_A) {
+			return smartdns::SERVER_REQUEST_SOA;
+		}
+        if (request->domain == "a.com") {
+            smartdns::MockServer::AddIP(request, request->domain.c_str(), "10.11.12.13", 611);
+            return smartdns::SERVER_REQUEST_OK;
+        }
+
+		smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+    /* this ip will be discard, but is reachable */
+    server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
+
+	server.Start(R"""(bind [::]:60053
+server udp://127.0.0.1:61053 -blacklist-ip
+bogus-nxdomain 10.0.0.0/8
+log-num 0
+log-console yes
+log-level debug
+cache-persist no)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("a.com", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAuthorityNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NXDOMAIN");
+
+    ASSERT_TRUE(client.Query("b.com", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+}

+ 156 - 0
test/cases/test-ip-rule.cc

@@ -0,0 +1,156 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "client.h"
+#include "dns.h"
+#include "include/utils.h"
+#include "server.h"
+#include "gtest/gtest.h"
+
+class IPRule : public ::testing::Test
+{
+  protected:
+	virtual void SetUp() {}
+	virtual void TearDown() {}
+};
+
+TEST_F(IPRule, white_list)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::MockServer server_upstream2;
+	smartdns::Server server;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype != DNS_T_A) {
+			return smartdns::SERVER_REQUEST_SOA;
+		}
+
+		smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+	server_upstream2.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype != DNS_T_A) {
+			return smartdns::SERVER_REQUEST_SOA;
+		}
+
+		smartdns::MockServer::AddIP(request, request->domain.c_str(), "4.5.6.7", 611);
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+    /* this ip will be discard, but is reachable */
+    server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
+
+	server.Start(R"""(bind [::]:60053
+server udp://127.0.0.1:61053 -whitelist-ip
+server udp://127.0.0.1:62053 -whitelist-ip
+whitelist-ip 4.5.6.7/24
+log-num 0
+log-console yes
+log-level debug
+cache-persist no)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("a.com", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "4.5.6.7");
+}
+
+TEST_F(IPRule, black_list)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::MockServer server_upstream2;
+	smartdns::Server server;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype != DNS_T_A) {
+			return smartdns::SERVER_REQUEST_SOA;
+		}
+
+		smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+	server_upstream2.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype != DNS_T_A) {
+			return smartdns::SERVER_REQUEST_SOA;
+		}
+
+		smartdns::MockServer::AddIP(request, request->domain.c_str(), "4.5.6.7", 611);
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+    /* this ip will be discard, but is reachable */
+    server.MockPing(PING_TYPE_ICMP, "4.5.6.7", 60, 10);
+
+	server.Start(R"""(bind [::]:60053
+server udp://127.0.0.1:61053 -blacklist-ip
+server udp://127.0.0.1:62053 -blacklist-ip
+blacklist-ip 4.5.6.7/24
+log-num 0
+log-console yes
+log-level debug
+cache-persist no)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("a.com", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+}
+
+TEST_F(IPRule, ignore_ip)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::MockServer server_upstream2;
+	smartdns::Server server;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype != DNS_T_A) {
+			return smartdns::SERVER_REQUEST_SOA;
+		}
+
+		smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
+		smartdns::MockServer::AddIP(request, request->domain.c_str(), "4.5.6.7", 611);
+		smartdns::MockServer::AddIP(request, request->domain.c_str(), "7.8.9.10", 611);
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+    /* this ip will be discard, but is reachable */
+    server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
+    server.MockPing(PING_TYPE_ICMP, "4.5.6.7", 60, 90);
+    server.MockPing(PING_TYPE_ICMP, "7.8.9.10", 60, 40);
+
+	server.Start(R"""(bind [::]:60053
+server udp://127.0.0.1:61053 -blacklist-ip
+ignore-ip 1.2.3.0/24
+log-num 0
+log-console yes
+log-level debug
+cache-persist no)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("a.com", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "7.8.9.10");
+}

+ 107 - 0
test/cases/test-qtype-soa.cc

@@ -76,3 +76,110 @@ cache-persist no)""");
 	EXPECT_EQ(client.GetStatus(), "NOERROR");
 	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
 }
+
+TEST_F(QtypeSOA, force_AAAA_SOA)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::Server server;
+	std::map<int, int> qid_map;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype == DNS_T_A) {
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700);
+			return smartdns::SERVER_REQUEST_OK;
+		} else if (request->qtype == DNS_T_AAAA) {
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700);
+			return smartdns::SERVER_REQUEST_OK;
+		}
+		return smartdns::SERVER_REQUEST_SOA;
+	});
+
+	server.Start(R"""(bind [::]:60053
+server 127.0.0.1:61053
+log-num 0
+log-console yes
+log-level debug
+speed-check-mode none
+force-AAAA-SOA yes
+cache-persist no)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("a.com A", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
+	EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+
+	ASSERT_TRUE(client.Query("a.com AAAA", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAuthorityNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAuthority()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30);
+	EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA");
+}
+
+TEST_F(QtypeSOA, bind_force_AAAA_SOA)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::Server server;
+	std::map<int, int> qid_map;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype == DNS_T_A) {
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700);
+			return smartdns::SERVER_REQUEST_OK;
+		} else if (request->qtype == DNS_T_AAAA) {
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700);
+			return smartdns::SERVER_REQUEST_OK;
+		}
+		return smartdns::SERVER_REQUEST_SOA;
+	});
+
+	server.Start(R"""(
+bind [::]:60053
+bind [::]:60153 -force-aaaa-soa
+server 127.0.0.1:61053
+log-num 0
+log-console yes
+log-level debug
+speed-check-mode none
+cache-persist no)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("a.com A", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
+	EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+
+	ASSERT_TRUE(client.Query("a.com AAAA", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
+	EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::102:304");
+
+	ASSERT_TRUE(client.Query("a.com A", 60153));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
+	EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+
+	ASSERT_TRUE(client.Query("a.com AAAA", 60153));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAuthorityNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAuthority()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30);
+	EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA");
+}