Browse Source

test: add test for mock server and bind option

Nick Peng 2 years ago
parent
commit
455dca9ae4

+ 2 - 1
src/fast_ping.c

@@ -1894,6 +1894,7 @@ int fast_ping_init(void)
 	ping.ident = (getpid() & 0XFFFF);
 	atomic_set(&ping.run, 1);
 
+	ping.epoll_fd = epollfd;
 	ret = pthread_create(&ping.tid, &attr, _fast_ping_work, NULL);
 	if (ret != 0) {
 		tlog(TLOG_ERROR, "create ping work thread failed, %s\n", strerror(ret));
@@ -1906,7 +1907,6 @@ int fast_ping_init(void)
 		goto errout;
 	}
 
-	ping.epoll_fd = epollfd;
 	ret = _fast_ping_init_wakeup_event();
 	if (ret != 0) {
 		tlog(TLOG_ERROR, "init wakeup event failed, %s\n", strerror(errno));
@@ -1933,6 +1933,7 @@ errout:
 
 	if (epollfd > 0) {
 		close(epollfd);
+		ping.epoll_fd = -1;
 	}
 
 	if (ping.event_fd) {

+ 37 - 6
test/cases/test-bind.cc

@@ -35,15 +35,15 @@ cache-persist no)""");
 
 TEST(Bind, udp_tcp)
 {
-    smartdns::MockServer server_upstream;
+	smartdns::MockServer server_upstream;
 	smartdns::MockServer server_upstream2;
 	smartdns::Server server;
 
 	server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
-        unsigned char addr[4] = {1, 2, 3, 4};
+		unsigned char addr[4] = {1, 2, 3, 4};
 		dns_add_A(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 611, addr);
 		request->response_packet->head.rcode = DNS_RC_NOERROR;
-		return true;
+		return smartdns::SERVER_REQUEST_OK;
 	});
 
 	server.Start(R"""(
@@ -59,15 +59,46 @@ cache-persist no)""");
 	std::cout << client.GetResult() << std::endl;
 	ASSERT_EQ(client.GetAnswerNum(), 1);
 	EXPECT_EQ(client.GetStatus(), "NOERROR");
-    EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
+	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
 	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
 
 	ASSERT_TRUE(client.Query("a.com", 60053));
 	std::cout << client.GetResult() << std::endl;
 	ASSERT_EQ(client.GetAnswerNum(), 1);
 	EXPECT_EQ(client.GetStatus(), "NOERROR");
-    EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 611);
+	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 611);
 	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
-
 }
 
+TEST(Bind, self)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::Server server;
+
+	server_upstream.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype == DNS_T_A) {
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
+			return smartdns::SERVER_REQUEST_OK;
+		}
+		return smartdns::SERVER_REQUEST_SOA;
+	});
+
+	server.Start(R"""(
+bind [::]:60053 -group self
+server 127.0.0.1:61053 -group self
+bind [::]:61053 -group upstream  
+server 127.0.0.1:62053 -group upstream
+log-num 0
+log-console yes
+log-level info
+cache-persist no)""");
+	smartdns::Client client;
+
+	ASSERT_TRUE(client.Query("a.com", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_LT(client.GetQueryTime(), 100);
+	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+}

+ 3 - 3
test/cases/test-cname.cc

@@ -12,7 +12,7 @@ TEST(server, cname)
     server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
         std::string domain = request->domain;
         if (request->domain.length() == 0) {
-            return false;
+            return smartdns::SERVER_REQUEST_ERROR;
         }
 
         if (request->qtype == DNS_T_A) {
@@ -22,13 +22,13 @@ TEST(server, cname)
             unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
             dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr);
         } else {
-            return false;
+            return smartdns::SERVER_REQUEST_ERROR;
         }
 
         EXPECT_EQ(domain, "e.com");
 
         request->response_packet->head.rcode = DNS_RC_NOERROR;
-        return true;
+        return smartdns::SERVER_REQUEST_OK;
     });
 
 	server.Start(R"""(bind [::]:60053

+ 4 - 4
test/cases/test-discard-block-ip.cc

@@ -14,7 +14,7 @@ TEST(DiscardBlockIP, first_ping)
         unsigned char addr[4] = {0, 0, 0, 0};
 		dns_add_A(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 611, addr);
 		request->response_packet->head.rcode = DNS_RC_NOERROR;
-		return true;
+		return smartdns::SERVER_REQUEST_OK;
 	});
 
     server_upstream2.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) {
@@ -22,7 +22,7 @@ TEST(DiscardBlockIP, first_ping)
         usleep(20000);
 		dns_add_A(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 611, addr);
 		request->response_packet->head.rcode = DNS_RC_NOERROR;
-		return true;
+		return smartdns::SERVER_REQUEST_OK;
 	});
 
 	server.Start(R"""(bind [::]:60053
@@ -51,7 +51,7 @@ TEST(DiscardBlockIP, first_response)
         unsigned char addr[4] = {0, 0, 0, 0};
 		dns_add_A(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 611, addr);
 		request->response_packet->head.rcode = DNS_RC_NOERROR;
-		return true;
+		return smartdns::SERVER_REQUEST_OK;
 	});
 
     server_upstream2.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) {
@@ -59,7 +59,7 @@ TEST(DiscardBlockIP, first_response)
         usleep(20000);
 		dns_add_A(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 611, addr);
 		request->response_packet->head.rcode = DNS_RC_NOERROR;
-		return true;
+		return smartdns::SERVER_REQUEST_OK;
 	});
 
 	server.Start(R"""(bind [::]:60053

+ 28 - 1
test/cases/test-mock-server.cc

@@ -9,9 +9,36 @@ TEST(MockServer, query_fail)
 	smartdns::Client client;
 	server.Start("udp://0.0.0.0:7053", [](struct smartdns::ServerRequestContext *request) {
 		request->response_data_len = 0;
-		return false;
+		return smartdns::SERVER_REQUEST_ERROR;
 	});
 
 	ASSERT_TRUE(client.Query("example.com", 7053));
+	std::cout << client.GetResult() << std::endl;
 	EXPECT_EQ(client.GetStatus(), "SERVFAIL");
 }
+
+TEST(MockServer, soa)
+{
+	smartdns::MockServer server;
+	smartdns::Client client;
+	server.Start("udp://0.0.0.0:7053", [](struct smartdns::ServerRequestContext *request) {
+		return smartdns::SERVER_REQUEST_SOA;
+	});
+
+	ASSERT_TRUE(client.Query("example.com", 7053));
+	std::cout << client.GetResult() << std::endl;
+	EXPECT_EQ(client.GetStatus(), "NXDOMAIN");
+}
+
+TEST(MockServer, noerror)
+{
+	smartdns::MockServer server;
+	smartdns::Client client;
+	server.Start("udp://0.0.0.0:7053", [](struct smartdns::ServerRequestContext *request) {
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+	ASSERT_TRUE(client.Query("example.com", 7053));
+	std::cout << client.GetResult() << std::endl;
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+}

+ 53 - 2
test/server.cc

@@ -142,13 +142,37 @@ void MockServer::Run()
 			head.aa = 0;
 			head.rd = 0;
 			head.ra = 1;
-			head.rcode = DNS_RC_SERVFAIL;
+			head.rcode = DNS_RC_NOERROR;
 			dns_packet_init(request.response_packet, sizeof(response_packet_buff), &head);
 
 			auto callback_ret = callback_(&request);
-			if (callback_ret == false || request.response_data_len == 0) {
+			if (callback_ret == SERVER_REQUEST_ERROR) {
+				dns_packet_init(request.response_packet, sizeof(response_packet_buff), &head);
+				request.response_packet->head.rcode = DNS_RC_SERVFAIL;
+				dns_add_domain(request.response_packet, request.domain.c_str(), request.qtype, request.qclass);
 				request.response_data_len =
 					dns_encode(request.response_data, request.response_data_max_len, request.response_packet);
+			} else if (request.response_data_len == 0) {
+				if (callback_ret == SERVER_REQUEST_OK) {
+					request.response_data_len =
+						dns_encode(request.response_data, request.response_data_max_len, request.response_packet);
+				} else if (callback_ret == SERVER_REQUEST_SOA) {
+					struct dns_soa soa;
+					memset(&soa, 0, sizeof(soa));
+					strncpy(soa.mname, "ns1.example.com", sizeof(soa.mname));
+					strncpy(soa.rname, "hostmaster.example.com", sizeof(soa.rname));
+					soa.serial = 1;
+					soa.refresh = 3600;
+					soa.retry = 600;
+					soa.expire = 86400;
+					soa.minimum = 3600;
+					dns_packet_init(request.response_packet, sizeof(response_packet_buff), &head);
+					dns_add_domain(request.response_packet, request.domain.c_str(), request.qtype, request.qclass);
+					request.response_packet->head.rcode = DNS_RC_NXDOMAIN;
+					dns_add_SOA(request.response_packet, DNS_RRS_AN, request.domain.c_str(), 1, &soa);
+					request.response_data_len =
+						dns_encode(request.response_data, request.response_data_max_len, request.response_packet);
+				}
 			}
 
 			sendto(fd_, request.response_data, request.response_data_len, MSG_NOSIGNAL, (struct sockaddr *)&from,
@@ -157,6 +181,33 @@ void MockServer::Run()
 	}
 }
 
+bool MockServer::AddIP(struct ServerRequestContext *request, const std::string &domain, const std::string &ip, int ttl)
+{
+	struct sockaddr_storage addr;
+	socklen_t addrlen = sizeof(addr);
+	memset(&addr, 0, sizeof(addr));
+
+	if (GetAddr(ip, "53", SOCK_DGRAM, IPPROTO_UDP, &addr, &addrlen)) {
+		if (addr.ss_family == AF_INET) {
+			struct sockaddr_in *addr4 = (struct sockaddr_in *)&addr;
+			dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), ttl,
+					  (unsigned char *)&addr4->sin_addr.s_addr);
+		} else if (addr.ss_family == AF_INET6) {
+			struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *)&addr;
+			if (IN6_IS_ADDR_V4MAPPED(&addr6->sin6_addr)) {
+				dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), ttl,
+						  (unsigned char *)&addr6->sin6_addr.s6_addr[12]);
+				return true;
+			}
+			dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), ttl,
+						 (unsigned char *)&addr6->sin6_addr.s6_addr);
+		}
+		return true;
+	}
+
+	return false;
+}
+
 bool MockServer::GetAddr(const std::string &host, const std::string port, int type, int protocol,
 						 struct sockaddr_storage *addr, socklen_t *addrlen)
 

+ 10 - 2
test/server.h

@@ -65,7 +65,13 @@ struct ServerRequestContext {
 	int response_data_len;
 };
 
-using ServerRequest = std::function<bool(struct ServerRequestContext *request)>;
+typedef enum {
+	SERVER_REQUEST_OK = 0,
+	SERVER_REQUEST_ERROR,
+	SERVER_REQUEST_SOA,
+} ServerRequestResult;
+
+using ServerRequest = std::function<ServerRequestResult(struct ServerRequestContext *request)>;
 
 class MockServer
 {
@@ -77,10 +83,12 @@ class MockServer
 	void Stop();
 	bool IsRunning();
 
+	static bool AddIP(struct ServerRequestContext *request, const std::string &domain, const std::string &ip, int ttl = 60);
+
   private:
 	void Run();
 
-	bool GetAddr(const std::string &host, const std::string port, int type, int protocol, struct sockaddr_storage *addr,
+	static bool GetAddr(const std::string &host, const std::string port, int type, int protocol, struct sockaddr_storage *addr,
 				 socklen_t *addrlen);
 	int fd_;
 	std::thread thread_;