2
0
Эх сурвалжийг харах

feature: supports setting the maximum number of concurrent queries

Nick Peng 2 жил өмнө
parent
commit
60b5e2b643

+ 4 - 0
etc/smartdns/smartdns.conf

@@ -151,6 +151,10 @@ force-qtype-SOA 65
 # example:
 # max-reply-ip-num 1
 
+# Maximum number of queries per second|0|number of queries, 0 means no limit.
+# example:
+# max-query-limit 65535
+
 # response mode
 # Experimental feature
 # response-mode [first-ping|fastest-ip|fastest-response]

+ 3 - 1
src/dns_conf.c

@@ -79,6 +79,7 @@ char dns_conf_bind_ca_key_pass[DNS_MAX_PATH];
 char dns_conf_need_cert = 0;
 
 int dns_conf_max_reply_ip_num = DNS_MAX_REPLY_IP_NUM;
+int dns_conf_max_query_limit = DNS_MAX_QUERY_LIMIT;
 
 static struct config_enum_list dns_conf_response_mode_enum[] = {
 	{"first-ping", DNS_RESPONSE_MODE_FIRST_PING_IP},
@@ -1901,7 +1902,7 @@ struct dns_srv_records *dns_server_get_srv_record(const char *domain)
 }
 
 static int _confg_srv_record_add(const char *domain, const char *host, unsigned short priority, unsigned short weight,
-							   unsigned short port)
+								 unsigned short port)
 {
 	struct dns_srv_records *srv_records = NULL;
 	struct dns_srv_record *srv_record = NULL;
@@ -4346,6 +4347,7 @@ static struct config_item _config_item[] = {
 	CONF_INT("rr-ttl-reply-max", &dns_conf_rr_ttl_reply_max, 0, CONF_INT_MAX),
 	CONF_INT("local-ttl", &dns_conf_local_ttl, 0, CONF_INT_MAX),
 	CONF_INT("max-reply-ip-num", &dns_conf_max_reply_ip_num, 1, CONF_INT_MAX),
+	CONF_INT("max-query-limit", &dns_conf_max_query_limit, 0, CONF_INT_MAX),
 	CONF_ENUM("response-mode", &dns_conf_response_mode, &dns_conf_response_mode_enum),
 	CONF_YESNO("force-AAAA-SOA", &dns_conf_force_AAAA_SOA),
 	CONF_YESNO("force-no-CNAME", &dns_conf_force_no_cname),

+ 2 - 0
src/dns_conf.h

@@ -59,6 +59,7 @@ extern "C" {
 #define DNS_MAX_CONF_CNAME_LEN 256
 #define MAX_QTYPE_NUM 65535
 #define DNS_MAX_REPLY_IP_NUM 8
+#define DNS_MAX_QUERY_LIMIT 65535
 #define DNS_DEFAULT_CHECKPOINT_TIME (3600 * 24)
 
 #define SMARTDNS_CONF_FILE "/etc/smartdns/smartdns.conf"
@@ -567,6 +568,7 @@ extern int dns_conf_dualstack_ip_allow_force_AAAA;
 extern int dns_conf_dualstack_ip_selection_threshold;
 
 extern int dns_conf_max_reply_ip_num;
+extern int dns_conf_max_query_limit;
 extern enum response_mode_type dns_conf_response_mode;
 
 extern int dns_conf_rr_ttl;

+ 19 - 1
src/dns_server.c

@@ -343,6 +343,7 @@ struct dns_server {
 	/* dns request list */
 	pthread_mutex_t request_list_lock;
 	struct list_head request_list;
+	atomic_t request_num;
 
 	DECLARE_HASHTABLE(request_pending, 4);
 	pthread_mutex_t request_pending_lock;
@@ -2595,6 +2596,7 @@ static void _dns_server_delete_request(struct dns_request *request)
 	pthread_mutex_destroy(&request->ip_map_lock);
 	memset(request, 0, sizeof(*request));
 	free(request);
+	atomic_dec(&server.request_num);
 }
 
 static void _dns_server_complete_with_multi_ipaddress(struct dns_request *request)
@@ -2787,6 +2789,7 @@ static struct dns_request *_dns_server_new_request(void)
 	INIT_LIST_HEAD(&request->check_list);
 	hash_init(request->ip_map);
 	_dns_server_request_get(request);
+	atomic_add(1, &server.request_num);
 
 	return request;
 errout:
@@ -5983,7 +5986,21 @@ static int _dns_server_recv(struct dns_server_conn_head *conn, unsigned char *in
 		goto errout;
 	}
 
-	tlog(TLOG_DEBUG, "query %s from %s, qtype: %d, id: %d\n", request->domain, name, request->qtype, request->id);
+	tlog(TLOG_DEBUG, "query %s from %s, qtype: %d, id: %d, query-num: %d", request->domain, name, request->qtype,
+		 request->id, atomic_read(&server.request_num));
+
+	if (atomic_read(&server.request_num) > dns_conf_max_query_limit && dns_conf_max_query_limit > 0) {
+		static time_t last_log_time = 0;
+		time_t now = time(NULL);
+		if (now - last_log_time > 120) {
+			last_log_time = now;
+			tlog(TLOG_WARN, "maximum number of dns queries reached, max: %d", dns_conf_max_query_limit);
+		}
+		request->send_tick = get_tick_count();
+		request->rcode = DNS_RC_REFUSED;
+		ret = 0;
+		goto errout;
+	}
 
 	ret = _dns_server_do_query(request, 1);
 	if (ret != 0) {
@@ -7860,6 +7877,7 @@ int dns_server_init(void)
 	pthread_attr_init(&attr);
 	INIT_LIST_HEAD(&server.conn_list);
 	time(&server.cache_save_time);
+	atomic_set(&server.request_num, 0);
 
 	epollfd = epoll_create1(EPOLL_CLOEXEC);
 	if (epollfd < 0) {

+ 0 - 1
test/cases/test-same-pending-query.cc

@@ -73,7 +73,6 @@ log-level error
 )""");
 
 	std::vector<std::thread> threads;
-	uint64_t tick = get_tick_count();
 	for (int i = 0; i < 5; i++) {
 		auto t = std::thread([&]() {
 			for (int j = 0; j < 10; j++) {

+ 61 - 0
test/cases/test-server.cc

@@ -157,4 +157,65 @@ server 127.0.0.1:61053
 	EXPECT_EQ(client.GetStatus(), "SERVFAIL");
 	EXPECT_GE(client.GetQueryTime(), 1500);
 	EXPECT_GE(count, 4);
+}
+
+TEST_F(Server, max_queries)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::MockServer server_upstream1;
+	smartdns::Server server;
+	int count = 0;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
+		smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
+		sleep(1);
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+	server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 128, 10);
+
+	server.Start(R"""(bind [::]:60053
+bind-tcp [::]:60053
+server 127.0.0.1:61053
+dualstack-ip-selection no
+max-query-limit 2
+)""");
+
+	std::vector<std::thread> threads;
+	int success_num = 0;
+	int refused_num = 0;
+	for (int i = 0; i < 5; i++) {
+		auto t = std::thread([&]() {
+			smartdns::Client client;
+			ASSERT_TRUE(client.Query("a.com", 60053));
+			if (client.GetStatus() == "NOERROR") {
+				success_num++;
+				EXPECT_EQ(client.GetStatus(), "NOERROR");
+				ASSERT_EQ(client.GetAnswerNum(), 1);
+				EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+				EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+			} else if (client.GetStatus() == "REFUSED") {
+				refused_num++;
+			} else {
+				FAIL();
+			}
+		});
+		threads.push_back(std::move(t));
+	}
+
+	for (auto &t : threads) {
+		t.join();
+	}
+
+	EXPECT_EQ(success_num, 2);
+	EXPECT_EQ(refused_num, 3);
+
+	for (int i = 0; i < 5; i++) {
+		smartdns::Client client;
+		ASSERT_TRUE(client.Query("a.com", 60053));
+		EXPECT_EQ(client.GetStatus(), "NOERROR");
+		ASSERT_EQ(client.GetAnswerNum(), 1);
+		EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+		EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+	}
 }