Răsfoiți Sursa

mdns: add test for mdns-lookup

Nick Peng 2 ani în urmă
părinte
comite
037f10d3c0
5 a modificat fișierele cu 259 adăugiri și 19 ștergeri
  1. 16 3
      src/dns_client.c
  2. 7 0
      src/dns_client.h
  3. 81 15
      src/dns_server.c
  4. 7 1
      src/util.c
  5. 148 0
      test/cases/test-mdns.cc

+ 16 - 3
src/dns_client.c

@@ -70,9 +70,6 @@
 #define SOCKET_PRIORITY (6)
 #define SOCKET_IP_TOS (IPTOS_LOWDELAY | IPTOS_RELIABILITY)
 
-#define DNS_MDNS_IP "224.0.0.251"
-#define DNS_MDNS_PORT 5353
-
 /* ECS info */
 struct dns_client_ecs {
 	int enable;
@@ -4613,6 +4610,22 @@ static int _dns_client_add_mdns_server(void)
 		goto errout;
 	}
 
+#ifdef TEST
+	ret = _dns_client_server_add(DNS_MDNS_IP, "lo", DNS_MDNS_PORT, DNS_SERVER_MDNS, &server_flags);
+	if (ret != 0) {
+		tlog(TLOG_ERROR, "add mdns server failed.");
+		goto errout;
+	}
+
+	if (dns_client_add_to_group(DNS_SERVER_GROUP_MDNS, DNS_MDNS_IP, DNS_MDNS_PORT, DNS_SERVER_MDNS, &server_flags) !=
+		0) {
+		tlog(TLOG_ERROR, "add mdns server to group failed.");
+		goto errout;
+	}
+
+	return 0;
+#endif
+
 	if (getifaddrs(&ifaddr) == -1) {
 		goto errout;
 	}

+ 7 - 0
src/dns_client.h

@@ -29,6 +29,13 @@ extern "C" {
 #define DNS_SERVER_GROUP_DEFAULT "default"
 #define DNS_SERVER_GROUP_MDNS "mdns"
 #define DNS_SERVER_GROUP_LOCAL "local"
+#ifdef TEST
+#define DNS_MDNS_IP "127.0.0.1"
+#define DNS_MDNS_PORT 55353
+#else
+#define DNS_MDNS_IP "224.0.0.251"
+#define DNS_MDNS_PORT 5353
+#endif
 
 typedef enum {
 	DNS_SERVER_UDP,

+ 81 - 15
src/dns_server.c

@@ -433,6 +433,10 @@ static int _dns_server_get_conf_ttl(struct dns_request *request, int ttl)
 	int rr_ttl_min = dns_conf_rr_ttl_min;
 	int rr_ttl_max = dns_conf_rr_ttl_max;
 
+	if (request->is_mdns_lookup) {
+		rr_ttl_min = DNS_SERVER_ADDR_TTL;
+	}
+
 	struct dns_ttl_rule *ttl_rule = _dns_server_get_dns_rule(request, DOMAIN_RULE_TTL);
 	if (ttl_rule != NULL) {
 		if (ttl_rule->ttl > 0) {
@@ -1407,11 +1411,16 @@ static int _dns_server_get_cache_timeout(struct dns_request *request, struct dns
 {
 	int timeout = 0;
 	int prefetch_time = 0;
+	int is_serve_expired = dns_conf_serve_expired;
 
 	if (request->rcode != DNS_RC_NOERROR) {
 		return ttl + 1;
 	}
 
+	if (request->is_mdns_lookup == 1) {
+		return ttl + 1;
+	}
+
 	if (dns_conf_prefetch && _dns_cache_is_specify_packet(request->qtype) != 0) {
 		prefetch_time = 1;
 	}
@@ -1424,8 +1433,12 @@ static int _dns_server_get_cache_timeout(struct dns_request *request, struct dns
 		prefetch_time = 0;
 	}
 
+	if (request->no_serve_expired) {
+		is_serve_expired = 0;
+	}
+
 	if (prefetch_time == 1) {
-		if (dns_conf_serve_expired) {
+		if (is_serve_expired) {
 			timeout = dns_conf_serve_expired_prefetch_time;
 			if (timeout == 0) {
 				timeout = dns_conf_serve_expired_ttl / 2;
@@ -1451,7 +1464,7 @@ static int _dns_server_get_cache_timeout(struct dns_request *request, struct dns
 		}
 	} else {
 		timeout = ttl;
-		if (dns_conf_serve_expired) {
+		if (is_serve_expired) {
 			timeout += dns_conf_serve_expired_ttl;
 		}
 
@@ -1491,6 +1504,12 @@ static int _dns_server_request_update_cache(struct dns_request *request, int spe
 	cache_key.query_flag = request->server_flags;
 
 	if (request->prefetch) {
+		/* no prefetch for mdns request */
+		if (request->is_mdns_lookup) {
+			ret = 0;
+			goto errout;
+		}
+
 		if (dns_cache_replace(&cache_key, request->rcode, ttl, speed,
 							  _dns_server_get_cache_timeout(request, &cache_key, ttl),
 							  !(request->prefetch_flags & PREFETCH_FLAGS_EXPIRED), cache_data) != 0) {
@@ -1681,6 +1700,12 @@ static int _dns_cache_packet(struct dns_server_post_context *context)
 	cache_key.query_flag = request->server_flags;
 
 	if (request->prefetch) {
+		/* no prefetch for mdns request */
+		if (request->is_mdns_lookup) {
+			ret = 0;
+			goto errout;
+		}
+
 		if (dns_cache_replace(&cache_key, request->rcode, request->ip_ttl, -1,
 							  _dns_server_get_cache_timeout(request, &cache_key, request->ip_ttl),
 							  !(request->prefetch_flags & PREFETCH_FLAGS_EXPIRED), cache_packet) != 0) {
@@ -2219,6 +2244,31 @@ static int _dns_server_reply_all_pending_list(struct dns_request *request, struc
 	return ret;
 }
 
+static void _dns_server_need_append_mdns_local_cname(struct dns_request *request)
+{
+	if (request->is_mdns_lookup == 0) {
+		return;
+	}
+
+	if (request->has_cname != 0) {
+		return;
+	}
+
+	if (request->domain[0] == '\0') {
+		return;
+	}
+
+	if (strstr(request->domain, ".") != NULL) {
+		return;
+	}
+
+	request->has_cname = 1;
+	snprintf(request->cname, sizeof(request->cname), "%.*s.%s",
+			 (int)(sizeof(request->cname) - sizeof(DNS_SERVER_GROUP_LOCAL) - 1), request->domain,
+			 DNS_SERVER_GROUP_LOCAL);
+	return;
+}
+
 static void _dns_server_check_complete_dualstack(struct dns_request *request, struct dns_request *dualstack_request)
 {
 	if (dualstack_request == NULL || request == NULL) {
@@ -2300,6 +2350,10 @@ static int _dns_server_request_complete_with_all_IPs(struct dns_request *request
 		ttl = DNS_SERVER_FAIL_TTL;
 	}
 
+	if (request->ip_ttl == 0) {
+		request->ip_ttl = ttl;
+	}
+
 	if (request->prefetch == 1) {
 		return 0;
 	}
@@ -2320,6 +2374,8 @@ static int _dns_server_request_complete_with_all_IPs(struct dns_request *request
 		goto out;
 	}
 
+	_dns_server_need_append_mdns_local_cname(request);
+
 	if (request->has_soa) {
 		tlog(TLOG_INFO, "result: %s, qtype: %d, SOA", request->domain, request->qtype);
 	} else {
@@ -2559,6 +2615,8 @@ static void _dns_server_complete_with_multi_ipaddress(struct dns_request *reques
 		return;
 	}
 
+	_dns_server_need_append_mdns_local_cname(request);
+
 	_dns_server_post_context_init(&context, request);
 	context.do_cache = 1;
 	context.do_ipset = 1;
@@ -3367,6 +3425,7 @@ static int _dns_server_process_answer(struct dns_request *request, const char *d
 					 domain, request->qtype, request->soa.mname, request->soa.rname, request->soa.serial,
 					 request->soa.refresh, request->soa.retry, request->soa.expire, request->soa.minimum);
 
+				request->ip_ttl = _dns_server_get_conf_ttl(request, ttl);
 				int soa_num = atomic_inc_return(&request->soa_num);
 				if ((soa_num >= ((int)ceil((float)dns_server_alive_num() / 3) + 1) || soa_num > 4) &&
 					atomic_read(&request->ip_map_num) <= 0) {
@@ -3461,7 +3520,7 @@ static int _dns_server_passthrough_rule_check(struct dns_request *request, const
 					}
 				}
 
-				ttl = ttl_tmp;
+				ttl = _dns_server_get_conf_ttl(request, ttl_tmp);
 				_dns_server_request_release(request);
 			} break;
 			case DNS_T_AAAA: {
@@ -3502,7 +3561,7 @@ static int _dns_server_passthrough_rule_check(struct dns_request *request, const
 					}
 				}
 
-				ttl = ttl_tmp;
+				ttl = _dns_server_get_conf_ttl(request, ttl_tmp);
 				_dns_server_request_release(request);
 			} break;
 			case DNS_T_CNAME: {
@@ -3515,7 +3574,7 @@ static int _dns_server_passthrough_rule_check(struct dns_request *request, const
 					char tmpbuf[DNS_MAX_CNAME_LEN];
 					dns_get_CNAME(rrs, tmpname, DNS_MAX_CNAME_LEN, &ttl, tmpbuf, DNS_MAX_CNAME_LEN);
 					if (request->ip_ttl == 0) {
-						request->ip_ttl = ttl;
+						request->ip_ttl = _dns_server_get_conf_ttl(request, ttl);
 					}
 				}
 				break;
@@ -3646,6 +3705,7 @@ static int _dns_server_get_answer(struct dns_server_post_context *context)
 					 "%d, minimum: %d",
 					 request->domain, request->qtype, request->soa.mname, request->soa.rname, request->soa.serial,
 					 request->soa.refresh, request->soa.retry, request->soa.expire, request->soa.minimum);
+				request->ip_ttl = _dns_server_get_conf_ttl(request, ttl);
 			} break;
 			default:
 				break;
@@ -3708,7 +3768,7 @@ static void _dns_server_query_end(struct dns_request *request)
 	if (request->is_mdns_lookup == 1 && request->rcode == DNS_RC_SERVFAIL) {
 		request->rcode = DNS_RC_NOERROR;
 		request->force_soa = 1;
-		request->ip_ttl = _dns_server_get_local_ttl(request);
+		request->ip_ttl = _dns_server_get_conf_ttl(request, DNS_SERVER_ADDR_TTL);
 	}
 
 	pthread_mutex_lock(&request->ip_map_lock);
@@ -5560,6 +5620,16 @@ static int _dns_server_setup_query_option(struct dns_request *request, struct dn
 	return 0;
 }
 
+static void _dns_server_mdns_query_setup_server_group(struct dns_request *request, const char **group_name)
+{
+	if (request->is_mdns_lookup == 0 || group_name == NULL) {
+		return;
+	}
+
+	*group_name = DNS_SERVER_GROUP_MDNS;
+	return;
+}
+
 static int _dns_server_mdns_query_setup(struct dns_request *request, const char *group_name, char **request_domain,
 										char *domain_buffer, int domain_buffer_len)
 {
@@ -5672,6 +5742,10 @@ static int _dns_server_do_query(struct dns_request *request, int skip_notify_eve
 		safe_strncpy(request->dns_group_name, group_name, DNS_GROUP_NAME_LEN);
 	}
 
+	if (_dns_server_mdns_query_setup(request, group_name, &request_domain, domain_buffer, sizeof(domain_buffer)) != 0) {
+		goto errout;
+	}
+
 	if (_dns_server_process_cname_pre(request) != 0) {
 		goto errout;
 	}
@@ -5738,6 +5812,7 @@ static int _dns_server_do_query(struct dns_request *request, int skip_notify_eve
 
 	// setup options
 	_dns_server_setup_query_option(request, &options);
+	_dns_server_mdns_query_setup_server_group(request, &group_name);
 
 	pthread_mutex_lock(&server.request_list_lock);
 	if (list_empty(&server.request_list) && skip_notify_event == 1) {
@@ -5746,15 +5821,6 @@ static int _dns_server_do_query(struct dns_request *request, int skip_notify_eve
 	list_add_tail(&request->list, &server.request_list);
 	pthread_mutex_unlock(&server.request_list_lock);
 
-	if (_dns_server_mdns_query_setup(request, group_name, &request_domain, domain_buffer, sizeof(domain_buffer)) != 0) {
-		goto errout;
-	}
-
-	/* if request MDNS */
-	if (request->is_mdns_lookup) {
-		group_name = DNS_SERVER_GROUP_MDNS;
-	}
-
 	// Get reference for DNS query
 	request->request_wait++;
 	_dns_server_request_get(request);

+ 7 - 1
src/util.c

@@ -449,6 +449,11 @@ int check_is_ipv6(const char *ip)
 			continue;
 		}
 
+		/* scope id, end of ipv6 address*/
+		if (c == '%') {
+			break;
+		}
+
 		if (c == ':') {
 			colon_num++;
 			dig_num = 0;
@@ -1949,7 +1954,8 @@ static int _dns_debug_display(struct dns_packet *packet)
 				char name[DNS_MAX_CNAME_LEN] = {0};
 				char target[DNS_MAX_CNAME_LEN];
 
-				ret = dns_get_SRV(rrs, name, DNS_MAX_CNAME_LEN, &ttl, &priority, &weight, &port, target, DNS_MAX_CNAME_LEN);
+				ret = dns_get_SRV(rrs, name, DNS_MAX_CNAME_LEN, &ttl, &priority, &weight, &port, target,
+								  DNS_MAX_CNAME_LEN);
 				if (ret < 0) {
 					tlog(TLOG_DEBUG, "decode SRV failed, %s", name);
 					return -1;

+ 148 - 0
test/cases/test-mdns.cc

@@ -0,0 +1,148 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "client.h"
+#include "dns.h"
+#include "dns_client.h"
+#include "include/utils.h"
+#include "server.h"
+#include "gtest/gtest.h"
+#include <fstream>
+
+class mDNS : public ::testing::Test
+{
+  protected:
+	virtual void SetUp() {}
+	virtual void TearDown() {}
+};
+
+TEST(mDNS, query)
+{
+	smartdns::MockServer server_upstream1;
+	smartdns::MockServer server_upstream2;
+	smartdns::Server server;
+
+	std::string listen_url = "udp://";
+	listen_url += DNS_MDNS_IP;
+	listen_url += ":" + std::to_string(DNS_MDNS_PORT);
+
+	server_upstream1.Start(listen_url.c_str(), [](struct smartdns::ServerRequestContext *request) {
+		std::string domain = request->domain;
+		if (request->domain.length() == 0) {
+			return smartdns::SERVER_REQUEST_ERROR;
+		}
+
+		if (request->qtype == DNS_T_A) {
+			unsigned char addr[][4] = {{1, 2, 3, 4}};
+			dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]);
+		} else if (request->qtype == DNS_T_AAAA) {
+			unsigned char addr[][16] = {{1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}};
+			dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]);
+		} else {
+			return smartdns::SERVER_REQUEST_ERROR;
+		}
+
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+	server_upstream2.Start("udp://0.0.0.0:61053",
+						   [](struct smartdns::ServerRequestContext *request) { return smartdns::SERVER_REQUEST_SOA; });
+
+	server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 100);
+	server.MockPing(PING_TYPE_ICMP, "102:304:500::1", 60, 100);
+
+	server.Start(R"""(bind [::]:60053
+server 127.0.0.1:61053
+dualstack-ip-selection no
+mdns-lookup yes
+)""");
+	smartdns::Client client;
+
+	ASSERT_TRUE(client.Query("b.com A", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 0);
+	EXPECT_EQ(client.GetStatus(), "NXDOMAIN");
+
+	ASSERT_TRUE(client.Query("host A", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 2);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "host");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "host.local.");
+	EXPECT_EQ(client.GetAnswer()[1].GetName(), "host.local");
+	EXPECT_EQ(client.GetAnswer()[1].GetData(), "1.2.3.4");
+
+	ASSERT_TRUE(client.Query("host AAAA", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 2);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "host");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "host.local.");
+	EXPECT_EQ(client.GetAnswer()[1].GetName(), "host.local");
+	EXPECT_EQ(client.GetAnswer()[1].GetData(), "102:304:500::1");
+
+	ASSERT_TRUE(client.Query("host.local A", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "host.local");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+}
+
+TEST(mDNS, ptr)
+{
+	smartdns::MockServer server_upstream1;
+	smartdns::MockServer server_upstream2;
+	smartdns::Server server;
+
+	std::string listen_url = "udp://";
+	listen_url += DNS_MDNS_IP;
+	listen_url += ":" + std::to_string(DNS_MDNS_PORT);
+
+	server_upstream1.Start(listen_url.c_str(), [](struct smartdns::ServerRequestContext *request) {
+		std::string domain = request->domain;
+		if (request->domain.length() == 0) {
+			return smartdns::SERVER_REQUEST_ERROR;
+		}
+
+		if (request->qtype != DNS_T_PTR) {
+			return smartdns::SERVER_REQUEST_SOA;
+		}
+
+		dns_add_PTR(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, "host.local");
+
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+	server_upstream2.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
+		return smartdns::SERVER_REQUEST_ERROR;
+	});
+
+	server.Start(R"""(bind [::]:60053
+server 127.0.0.1:61053
+dualstack-ip-selection no
+mdns-lookup yes
+)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("-x 192.168.1.1", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "1.1.168.192.in-addr.arpa");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "host.local.");
+}