Browse Source

cache: modify cache ver check method, add ipset, nftset after restart.

Nick Peng 2 năm trước cách đây
mục cha
commit
582cdfb879

+ 45 - 9
src/dns_cache.c

@@ -147,7 +147,7 @@ void dns_cache_data_free(struct dns_cache_data *data)
 	free(data);
 }
 
-struct dns_cache_data *dns_cache_new_data(void)
+struct dns_cache_data *dns_cache_new_data_addr(void)
 {
 	struct dns_cache_addr *cache_addr = malloc(sizeof(struct dns_cache_addr));
 	memset(cache_addr, 0, sizeof(struct dns_cache_addr));
@@ -157,6 +157,7 @@ struct dns_cache_data *dns_cache_new_data(void)
 
 	cache_addr->head.cache_type = CACHE_TYPE_NONE;
 	cache_addr->head.size = sizeof(struct dns_cache_addr) - sizeof(struct dns_cache_data_head);
+	cache_addr->head.magic = MAGIC_CACHE_DATA;
 
 	return (struct dns_cache_data *)cache_addr;
 }
@@ -241,6 +242,7 @@ struct dns_cache_data *dns_cache_new_data_packet(void *packet, size_t packet_len
 
 	cache_packet->head.cache_type = CACHE_TYPE_PACKET;
 	cache_packet->head.size = packet_len;
+	cache_packet->head.magic = MAGIC_CACHE_DATA;
 
 	return (struct dns_cache_data *)cache_packet;
 }
@@ -274,6 +276,7 @@ static int _dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int spee
 	dns_cache->info.ttl = ttl;
 	dns_cache->info.speed = speed;
 	dns_cache->info.no_inactive = no_inactive;
+	dns_cache->info.is_visited = 1;
 	old_cache_data = dns_cache->cache_data;
 	dns_cache->cache_data = cache_data;
 	list_del_init(&dns_cache->list);
@@ -294,12 +297,14 @@ static int _dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int spee
 	return 0;
 }
 
-int dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive, struct dns_cache_data *cache_data)
+int dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive,
+					  struct dns_cache_data *cache_data)
 {
 	return _dns_cache_replace(cache_key, ttl, speed, no_inactive, 0, cache_data);
 }
 
-int dns_cache_replace_inactive(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive, struct dns_cache_data *cache_data)
+int dns_cache_replace_inactive(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive,
+							   struct dns_cache_data *cache_data)
 {
 	return _dns_cache_replace(cache_key, ttl, speed, no_inactive, 1, cache_data);
 }
@@ -391,7 +396,8 @@ errout:
 	return -1;
 }
 
-int dns_cache_insert(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive, struct dns_cache_data *cache_data)
+int dns_cache_insert(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive,
+					 struct dns_cache_data *cache_data)
 {
 	struct dns_cache_info info;
 
@@ -418,6 +424,7 @@ int dns_cache_insert(struct dns_cache_key *cache_key, int ttl, int speed, int no
 	info.hitnum_update_add = DNS_CACHE_HITNUM_STEP;
 	info.speed = speed;
 	info.no_inactive = no_inactive;
+	info.is_visited = 1;
 	time(&info.insert_time);
 	time(&info.replace_time);
 
@@ -541,6 +548,11 @@ struct dns_cache_data *dns_cache_get_data(struct dns_cache *dns_cache)
 	return dns_cache->cache_data;
 }
 
+int dns_cache_is_visited(struct dns_cache *dns_cache)
+{
+	return dns_cache->info.is_visited;
+}
+
 void dns_cache_delete(struct dns_cache *dns_cache)
 {
 	pthread_mutex_lock(&dns_cache_head.lock);
@@ -574,6 +586,7 @@ void dns_cache_update(struct dns_cache *dns_cache)
 		if (dns_cache->info.hitnum_update_add < DNS_CACHE_HITNUM_STEP_MAX) {
 			dns_cache->info.hitnum_update_add++;
 		}
+		dns_cache->info.is_visited = 1;
 	}
 	pthread_mutex_unlock(&dns_cache_head.lock);
 }
@@ -707,15 +720,18 @@ static int _dns_cache_read_record(int fd, uint32_t cache_number)
 			goto errout;
 		}
 
-		if (cache_record.magic != MAGIC_CACHE_DATA) {
+		if (cache_record.magic != MAGIC_RECORD) {
 			tlog(TLOG_ERROR, "magic is invalid.");
 			goto errout;
 		}
 
 		if (cache_record.type == CACHE_RECORD_TYPE_ACTIVE) {
 			head = &dns_cache_head.cache_list;
-		} else {
+		} else if (cache_record.type == CACHE_RECORD_TYPE_INACTIVE) {
 			head = &dns_cache_head.inactive_list;
+		} else {
+			tlog(TLOG_ERROR, "read cache record type is invalid.");
+			goto errout;
 		}
 
 		ret = read(fd, &data_head, sizeof(data_head));
@@ -724,6 +740,11 @@ static int _dns_cache_read_record(int fd, uint32_t cache_number)
 			goto errout;
 		}
 
+		if (data_head.magic != MAGIC_CACHE_DATA) {
+			tlog(TLOG_ERROR, "data magic is invalid.");
+			goto errout;
+		}
+
 		if (data_head.size > 1024 * 8) {
 			tlog(TLOG_ERROR, "data may invalid, skip load cache.");
 			goto errout;
@@ -742,6 +763,15 @@ static int _dns_cache_read_record(int fd, uint32_t cache_number)
 			goto errout;
 		}
 
+		/* set cache unvisited, so that when refreshing ipset/nftset, reload ipset list by restarting smartdns */
+		cache_record.info.is_visited = 0;
+		cache_record.info.domain[DNS_MAX_CNAME_LEN - 1] = '\0';
+		cache_record.info.dns_group_name[DNS_GROUP_NAME_LEN - 1] = '\0';
+		if (cache_record.type >= CACHE_RECORD_TYPE_END) {
+			tlog(TLOG_ERROR, "read cache record type is invalid.");
+			goto errout;
+		}
+
 		if (_dns_cache_insert(&cache_record.info, cache_data, head) != 0) {
 			tlog(TLOG_ERROR, "insert cache data failed.");
 			cache_data = NULL;
@@ -786,7 +816,7 @@ int dns_cache_load(const char *file)
 		goto errout;
 	}
 
-	if (strncmp(cache_file.version, __TIMESTAMP__, DNS_CACHE_VERSION_LEN - 1) != 0) {
+	if (strncmp(cache_file.version, dns_cache_file_version(), DNS_CACHE_VERSION_LEN) != 0) {
 		tlog(TLOG_WARN, "cache version is different, skip load cache.");
 		goto errout;
 	}
@@ -815,7 +845,7 @@ static int _dns_cache_write_record(int fd, uint32_t *cache_number, enum CACHE_RE
 	pthread_mutex_lock(&dns_cache_head.lock);
 	list_for_each_entry_safe_reverse(dns_cache, tmp, head, list)
 	{
-		cache_record.magic = MAGIC_CACHE_DATA;
+		cache_record.magic = MAGIC_RECORD;
 		cache_record.type = type;
 		memcpy(&cache_record.info, &dns_cache->info, sizeof(struct dns_cache_info));
 		ssize_t ret = write(fd, &cache_record, sizeof(cache_record));
@@ -871,7 +901,7 @@ int dns_cache_save(const char *file)
 	struct dns_cache_file cache_file;
 	memset(&cache_file, 0, sizeof(cache_file));
 	cache_file.magic = MAGIC_NUMBER;
-	safe_strncpy(cache_file.version, __TIMESTAMP__, DNS_CACHE_VERSION_LEN);
+	safe_strncpy(cache_file.version, dns_cache_file_version(), DNS_CACHE_VERSION_LEN);
 	cache_file.cache_number = 0;
 
 	if (lseek(fd, sizeof(cache_file), SEEK_SET) < 0) {
@@ -926,3 +956,9 @@ void dns_cache_destroy(void)
 
 	pthread_mutex_destroy(&dns_cache_head.lock);
 }
+
+const char *dns_cache_file_version(void)
+{
+	const char *version = "cache ver 1.0";
+	return version;
+}

+ 14 - 4
src/dns_cache.h

@@ -36,7 +36,8 @@ extern "C" {
 #define DNS_CACHE_VERSION_LEN 32
 #define DNS_CACHE_GROUP_NAME_LEN 32
 #define MAGIC_NUMBER 0x6548634163536e44
-#define MAGIC_CACHE_DATA 0x44615461
+#define MAGIC_CACHE_DATA 0x61546144
+#define MAGIC_RECORD 0x64526352
 
 enum CACHE_TYPE {
 	CACHE_TYPE_NONE,
@@ -47,12 +48,14 @@ enum CACHE_TYPE {
 enum CACHE_RECORD_TYPE {
 	CACHE_RECORD_TYPE_ACTIVE,
 	CACHE_RECORD_TYPE_INACTIVE,
+	CACHE_RECORD_TYPE_END,
 };
 
 struct dns_cache_data_head {
 	enum CACHE_TYPE cache_type;
 	int is_soa;
 	ssize_t size;
+	uint32_t magic;
 };
 
 struct dns_cache_data {
@@ -89,6 +92,7 @@ struct dns_cache_info {
 	int speed;
 	int no_inactive;
 	int hitnum_update_add;
+	int is_visited;
 	time_t insert_time;
 	time_t replace_time;
 };
@@ -136,9 +140,11 @@ struct dns_cache_data *dns_cache_new_data_packet(void *packet, size_t packet_len
 
 int dns_cache_init(int size, int enable_inactive, int inactive_list_expired);
 
-int dns_cache_replace(struct dns_cache_key *key, int ttl, int speed, int no_inactive, struct dns_cache_data *cache_data);
+int dns_cache_replace(struct dns_cache_key *key, int ttl, int speed, int no_inactive,
+					  struct dns_cache_data *cache_data);
 
-int dns_cache_replace_inactive(struct dns_cache_key *key, int ttl, int speed, int no_inactive, struct dns_cache_data *cache_data);
+int dns_cache_replace_inactive(struct dns_cache_key *key, int ttl, int speed, int no_inactive,
+							   struct dns_cache_data *cache_data);
 
 int dns_cache_insert(struct dns_cache_key *key, int ttl, int speed, int no_inactive, struct dns_cache_data *cache_data);
 
@@ -152,6 +158,8 @@ void dns_cache_release(struct dns_cache *dns_cache);
 
 int dns_cache_hitnum_dec_get(struct dns_cache *dns_cache);
 
+int dns_cache_is_visited(struct dns_cache *dns_cache);
+
 void dns_cache_update(struct dns_cache *dns_cache);
 
 typedef void dns_cache_callback(struct dns_cache *dns_cache);
@@ -165,7 +173,7 @@ int dns_cache_get_cname_ttl(struct dns_cache *dns_cache);
 
 int dns_cache_is_soa(struct dns_cache *dns_cache);
 
-struct dns_cache_data *dns_cache_new_data(void);
+struct dns_cache_data *dns_cache_new_data_addr(void);
 
 struct dns_cache_data *dns_cache_get_data(struct dns_cache *dns_cache);
 
@@ -180,6 +188,8 @@ int dns_cache_load(const char *file);
 
 int dns_cache_save(const char *file);
 
+const char *dns_cache_file_version(void);
+
 #ifdef __cplusplus
 }
 #endif

+ 21 - 0
src/dns_client.c

@@ -123,6 +123,7 @@ struct dns_server_info {
 	time_t last_recv;
 	unsigned long send_tick;
 	int prohibit;
+	int is_already_prohibit;
 
 	/* server addr info */
 	unsigned short ai_family;
@@ -200,6 +201,7 @@ struct dns_client {
 	struct list_head dns_request_list;
 	atomic_t run_period;
 	atomic_t dns_server_num;
+	atomic_t dns_server_prohibit_num;
 
 	/* ECS */
 	struct dns_client_ecs ecs_ipv4;
@@ -1413,6 +1415,11 @@ int dns_server_num(void)
 	return atomic_read(&client.dns_server_num);
 }
 
+int dns_server_alive_num(void)
+{
+	return atomic_read(&client.dns_server_num) - atomic_read(&client.dns_server_prohibit_num);
+}
+
 static void _dns_client_query_get(struct dns_query_struct *query)
 {
 	if (atomic_inc_return(&query->refcnt) <= 0) {
@@ -3338,6 +3345,7 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet,
 	query->send_tick = get_tick_count();
 
 	/* send query to all dns servers */
+	atomic_inc(&query->dns_request_sent);
 	for (i = 0; i < 2; i++) {
 		total_server = 0;
 		pthread_mutex_lock(&client.server_list_lock);
@@ -3345,12 +3353,19 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet,
 		{
 			server_info = group_member->server;
 			if (server_info->prohibit) {
+				if (server_info->is_already_prohibit == 0) {
+					server_info->is_already_prohibit = 1;
+					atomic_inc(&client.dns_server_prohibit_num);
+				}
+	
 				time_t now = 0;
 				time(&now);
 				if ((now - 60 < server_info->last_send) && (now - 5 > server_info->last_recv)) {
 					continue;
 				}
 				server_info->prohibit = 0;
+				server_info->is_already_prohibit = 0;
+				atomic_dec(&client.dns_server_prohibit_num);
 				if (now - 60 > server_info->last_send) {
 					_dns_client_close_socket(server_info);
 				}
@@ -3428,6 +3443,11 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet,
 		}
 	}
 
+	int num  = atomic_dec_return(&query->dns_request_sent);
+	if (num == 0) {
+		_dns_client_query_remove(query);
+	}
+
 	if (send_count <= 0) {
 		tlog(TLOG_WARN, "Send query to upstream server failed, total server number %d", total_server);
 		return -1;
@@ -4194,6 +4214,7 @@ int dns_client_init(void)
 	memset(&client, 0, sizeof(client));
 	pthread_attr_init(&attr);
 	atomic_set(&client.dns_server_num, 0);
+	atomic_set(&client.dns_server_prohibit_num, 0);
 	atomic_set(&client.run_period, 0);
 
 	epollfd = epoll_create1(EPOLL_CLOEXEC);

+ 2 - 0
src/dns_client.h

@@ -141,6 +141,8 @@ int dns_client_remove_from_group(const char *group_name, char *server_ip, int po
 
 int dns_client_remove_group(const char *group_name);
 
+int dns_server_alive_num(void);
+
 int dns_server_num(void);
 
 #ifdef __cplusplus

+ 22 - 11
src/dns_server.c

@@ -1576,6 +1576,7 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context
 	struct dns_request *request = context->request;
 	char name[DNS_MAX_CNAME_LEN] = {0};
 	int rr_count = 0;
+	int timeout_value = 0;
 	int i = 0;
 	int j = 0;
 	struct dns_rrs *rrs = NULL;
@@ -1642,6 +1643,11 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context
 		return 0;
 	}
 
+	timeout_value = request->ip_ttl * 3;
+	if (timeout_value == 0) {
+		timeout_value = _dns_server_get_conf_ttl(request, 0) * 3;
+	}
+
 	for (j = 1; j < DNS_RRS_END; j++) {
 		rrs = dns_get_rrs_start(context->packet, j, &rr_count);
 		for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(context->packet, rrs)) {
@@ -1659,7 +1665,7 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context
 					/* add IPV4 to ipset */
 					tlog(TLOG_DEBUG, "IPSET-MATCH: domain: %s, ipset: %s, IP: %d.%d.%d.%d", request->domain,
 						 rule->ipsetname, addr[0], addr[1], addr[2], addr[3]);
-					ipset_add(rule->ipsetname, addr, DNS_RR_A_LEN, request->ip_ttl * 2);
+					ipset_add(rule->ipsetname, addr, DNS_RR_A_LEN, timeout_value);
 				}
 
 				if (nftset_ip != NULL) {
@@ -1668,7 +1674,7 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context
 						 nftset_ip->familyname, nftset_ip->nfttablename, nftset_ip->nftsetname, addr[0], addr[1],
 						 addr[2], addr[3]);
 					nftset_add(nftset_ip->familyname, nftset_ip->nfttablename, nftset_ip->nftsetname, addr,
-							   DNS_RR_A_LEN, request->ip_ttl * 2);
+							   DNS_RR_A_LEN, timeout_value);
 				}
 			} break;
 			case DNS_T_AAAA: {
@@ -1687,7 +1693,7 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context
 						 request->domain, rule->ipsetname, addr[0], addr[1], addr[2], addr[3], addr[4], addr[5],
 						 addr[6], addr[7], addr[8], addr[9], addr[10], addr[11], addr[12], addr[13], addr[14],
 						 addr[15]);
-					ipset_add(rule->ipsetname, addr, DNS_RR_AAAA_LEN, request->ip_ttl * 2);
+					ipset_add(rule->ipsetname, addr, DNS_RR_AAAA_LEN, timeout_value);
 				}
 
 				if (nftset_ip6 != NULL) {
@@ -1699,7 +1705,7 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context
 						 addr[0], addr[1], addr[2], addr[3], addr[4], addr[5], addr[6], addr[7], addr[8], addr[9],
 						 addr[10], addr[11], addr[12], addr[13], addr[14], addr[15]);
 					nftset_add(nftset_ip6->familyname, nftset_ip6->nfttablename, nftset_ip6->nftsetname, addr,
-							   DNS_RR_AAAA_LEN, request->ip_ttl * 2);
+							   DNS_RR_AAAA_LEN, timeout_value);
 				}
 			} break;
 			default:
@@ -2803,7 +2809,7 @@ static int _dns_server_process_answer_A(struct dns_rrs *rrs, struct dns_request
 	/* Ad blocking result */
 	if (addr[0] == 0 || addr[0] == 127) {
 		/* If half of the servers return the same result, then ignore this address */
-		if (atomic_inc_return(&request->adblock) <= (dns_server_num() / 2 + dns_server_num() % 2)) {
+		if (atomic_inc_return(&request->adblock) <= (dns_server_alive_num() / 2 + dns_server_alive_num() % 2)) {
 			request->rcode = DNS_RC_NOERROR;
 			_dns_server_request_release(request);
 			return -1;
@@ -2880,7 +2886,7 @@ static int _dns_server_process_answer_AAAA(struct dns_rrs *rrs, struct dns_reque
 	/* Ad blocking result */
 	if (_dns_server_is_adblock_ipv6(addr) == 0) {
 		/* If half of the servers return the same result, then ignore this address */
-		if (atomic_inc_return(&request->adblock) <= (dns_server_num() / 2 + dns_server_num() % 2)) {
+		if (atomic_inc_return(&request->adblock) <= (dns_server_alive_num() / 2 + dns_server_alive_num() % 2)) {
 			request->rcode = DNS_RC_NOERROR;
 			_dns_server_request_release(request);
 			return -1;
@@ -2989,7 +2995,8 @@ static int _dns_server_process_answer(struct dns_request *request, const char *d
 					 request->soa.refresh, request->soa.retry, request->soa.expire, request->soa.minimum);
 
 				int soa_num = atomic_inc_return(&request->soa_num);
-				if ((soa_num >= (dns_server_num() / 3) + 1 || soa_num > 4) && atomic_read(&request->ip_map_num) <= 0) {
+				if ((soa_num >= (dns_server_alive_num() / 3) + 1 || soa_num > 4) &&
+					atomic_read(&request->ip_map_num) <= 0) {
 					request->ip_ttl = ttl;
 					_dns_server_request_complete(request);
 				}
@@ -3072,7 +3079,7 @@ static int _dns_server_passthrough_rule_check(struct dns_request *request, const
 				/* Ad blocking result */
 				if (addr[0] == 0 || addr[0] == 127) {
 					/* If half of the servers return the same result, then ignore this address */
-					if (atomic_read(&request->adblock) <= (dns_server_num() / 2 + dns_server_num() % 2)) {
+					if (atomic_read(&request->adblock) <= (dns_server_alive_num() / 2 + dns_server_alive_num() % 2)) {
 						_dns_server_request_release(request);
 						return 0;
 					}
@@ -3116,7 +3123,7 @@ static int _dns_server_passthrough_rule_check(struct dns_request *request, const
 				/* Ad blocking result */
 				if (_dns_server_is_adblock_ipv6(addr) == 0) {
 					/* If half of the servers return the same result, then ignore this address */
-					if (atomic_read(&request->adblock) <= (dns_server_num() / 2 + dns_server_num() % 2)) {
+					if (atomic_read(&request->adblock) <= (dns_server_alive_num() / 2 + dns_server_alive_num() % 2)) {
 						_dns_server_request_release(request);
 						return 0;
 					}
@@ -3384,7 +3391,7 @@ static void _dns_server_passthrough_may_complete(struct dns_request *request)
 		addr = request->ip_addr;
 		if (addr[0] == 0 || addr[0] == 127) {
 			/* If half of the servers return the same result, then ignore this address */
-			if (atomic_read(&request->adblock) <= (dns_server_num() / 2 + dns_server_num() % 2)) {
+			if (atomic_read(&request->adblock) <= (dns_server_alive_num() / 2 + dns_server_alive_num() % 2)) {
 				return;
 			}
 		}
@@ -3394,7 +3401,7 @@ static void _dns_server_passthrough_may_complete(struct dns_request *request)
 		addr = request->ip_addr;
 		if (_dns_server_is_adblock_ipv6(addr) == 0) {
 			/* If half of the servers return the same result, then ignore this address */
-			if (atomic_read(&request->adblock) <= (dns_server_num() / 2 + dns_server_num() % 2)) {
+			if (atomic_read(&request->adblock) <= (dns_server_alive_num() / 2 + dns_server_alive_num() % 2)) {
 				return;
 			}
 		}
@@ -4551,6 +4558,10 @@ static int _dns_server_process_cache_packet(struct dns_request *request, struct
 		return -1;
 	}
 
+	if (dns_cache_is_visited(dns_cache) == 0) {
+		do_ipset = 1;
+	}
+
 	if (dns_cache->info.qtype != request->qtype) {
 		return -1;
 	}

+ 0 - 2
test/cases/test-address.cc

@@ -35,7 +35,6 @@ TEST_F(Address, soa)
 {
 	smartdns::MockServer server_upstream;
 	smartdns::Server server;
-	std::map<int, int> qid_map;
 
 	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
 		if (request->qtype == DNS_T_A) {
@@ -122,7 +121,6 @@ TEST_F(Address, ip)
 {
 	smartdns::MockServer server_upstream;
 	smartdns::Server server;
-	std::map<int, int> qid_map;
 
 	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
 		if (request->qtype == DNS_T_A) {

+ 77 - 0
test/cases/test-cache.cc

@@ -21,7 +21,10 @@
 #include "include/utils.h"
 #include "server.h"
 #include "gtest/gtest.h"
+#include <fcntl.h>
 #include <fstream>
+#include <sys/stat.h>
+#include <sys/types.h>
 
 /* clang-format off */
 #include "dns_cache.h"
@@ -286,3 +289,77 @@ dualstack-ip-selection no
 	EXPECT_EQ(head.magic, MAGIC_NUMBER);
 	EXPECT_EQ(head.cache_number, 1);
 }
+
+TEST_F(Cache, corrupt_file)
+{
+	smartdns::MockServer server_upstream;
+	auto cache_file = "/tmp/smartdns_cache." + smartdns::GenerateRandomString(10);
+	std::string conf = R"""(
+bind [::]:60053@lo
+server 127.0.0.1:62053
+log-num 0
+log-console yes
+log-level debug
+dualstack-ip-selection no
+cache-persist yes
+)""";
+
+	conf += "cache-file " + cache_file;
+	Defer
+	{
+		unlink(cache_file.c_str());
+	};
+
+	server_upstream.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype == DNS_T_A) {
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
+			return smartdns::SERVER_REQUEST_OK;
+		}
+		return smartdns::SERVER_REQUEST_SOA;
+	});
+	{
+		smartdns::Server server;
+		server.Start(conf);
+		smartdns::Client client;
+
+		ASSERT_TRUE(client.Query("a.com", 60053));
+		std::cout << client.GetResult() << std::endl;
+		ASSERT_EQ(client.GetAnswerNum(), 1);
+		EXPECT_EQ(client.GetStatus(), "NOERROR");
+		EXPECT_LT(client.GetQueryTime(), 100);
+		EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
+		EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+		server.Stop();
+		usleep(200 * 1000);
+	}
+
+	ASSERT_EQ(access(cache_file.c_str(), F_OK), 0);
+
+	int fd = open(cache_file.c_str(), O_RDWR);
+	ASSERT_NE(fd, -1);
+	srandom(time(NULL));
+	off_t file_size = lseek(fd, 0, SEEK_END);
+	off_t offset = random() % (file_size - 300);
+	std::cout << "try make corrupt at " << offset << ", file size: " << file_size << std::endl;
+	lseek(fd, offset, SEEK_SET);
+	for (int i = 0; i < 300; i++) {
+		unsigned char c = random() % 256;
+		write(fd, &c, 1);
+	}
+	close(fd);
+	{
+		smartdns::Server server;
+		server.Start(conf);
+		smartdns::Client client;
+
+		ASSERT_TRUE(client.Query("a.com", 60053));
+		std::cout << client.GetResult() << std::endl;
+		ASSERT_EQ(client.GetAnswerNum(), 1);
+		EXPECT_EQ(client.GetStatus(), "NOERROR");
+		EXPECT_LT(client.GetQueryTime(), 100);
+		EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
+		EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+		server.Stop();
+		usleep(200 * 1000);
+	}
+}

+ 0 - 1
test/cases/test-dns64.cc

@@ -35,7 +35,6 @@ TEST_F(DNS64, no_dualstack)
 {
 	smartdns::MockServer server_upstream;
 	smartdns::Server server;
-	std::map<int, int> qid_map;
 
 	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
 		if (request->qtype == DNS_T_A) {

+ 0 - 1
test/cases/test-domain-set.cc

@@ -35,7 +35,6 @@ TEST_F(DomainSet, set_add)
 {
 	smartdns::MockServer server_upstream;
 	smartdns::Server server;
-	std::map<int, int> qid_map;
 	smartdns::TempFile file_set;
 	std::vector<std::string> domain_list;
 	int count = 16;

+ 0 - 1
test/cases/test-dualstack.cc

@@ -35,7 +35,6 @@ TEST_F(DualStack, ipv4_prefer)
 {
 	smartdns::MockServer server_upstream;
 	smartdns::Server server;
-	std::map<int, int> qid_map;
 
 	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
 		if (request->qtype == DNS_T_A) {

+ 0 - 2
test/cases/test-qtype-soa.cc

@@ -155,7 +155,6 @@ TEST_F(QtypeSOA, force_AAAA_SOA)
 {
 	smartdns::MockServer server_upstream;
 	smartdns::Server server;
-	std::map<int, int> qid_map;
 
 	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
 		if (request->qtype == DNS_T_A) {
@@ -199,7 +198,6 @@ TEST_F(QtypeSOA, bind_force_AAAA_SOA)
 {
 	smartdns::MockServer server_upstream;
 	smartdns::Server server;
-	std::map<int, int> qid_map;
 
 	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
 		if (request->qtype == DNS_T_A) {

+ 69 - 4
test/cases/test-speed-check.cc

@@ -35,7 +35,6 @@ TEST_F(SpeedCheck, response_mode)
 {
 	smartdns::MockServer server_upstream;
 	smartdns::Server server;
-	std::map<int, int> qid_map;
 
 	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
 		if (request->qtype == DNS_T_A) {
@@ -78,7 +77,6 @@ TEST_F(SpeedCheck, none)
 {
 	smartdns::MockServer server_upstream;
 	smartdns::Server server;
-	std::map<int, int> qid_map;
 
 	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
 		if (request->qtype == DNS_T_A) {
@@ -120,7 +118,6 @@ TEST_F(SpeedCheck, domain_rules_none)
 {
 	smartdns::MockServer server_upstream;
 	smartdns::Server server;
-	std::map<int, int> qid_map;
 
 	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
 		if (request->qtype == DNS_T_A) {
@@ -162,7 +159,6 @@ TEST_F(SpeedCheck, only_ping)
 {
 	smartdns::MockServer server_upstream;
 	smartdns::Server server;
-	std::map<int, int> qid_map;
 
 	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
 		if (request->qtype == DNS_T_A) {
@@ -190,6 +186,75 @@ cache-persist no)""");
 	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
 }
 
+TEST_F(SpeedCheck, no_ping_fallback_tcp)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::Server server;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype == DNS_T_A) {
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8");
+			return smartdns::SERVER_REQUEST_OK;
+		}
+		return smartdns::SERVER_REQUEST_SOA;
+	});
+
+	server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 1000);
+	server.MockPing(PING_TYPE_TCP, "5.6.7.8:80", 60, 100);
+	server.Start(R"""(bind [::]:60053
+server 127.0.0.1:61053
+log-num 0
+log-console yes
+speed-check-mode ping,tcp:80
+log-level debug
+cache-persist no)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("a.com", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_LT(client.GetQueryTime(), 500);
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8");
+}
+
+
+TEST_F(SpeedCheck, tcp_faster_than_ping)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::Server server;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype == DNS_T_A) {
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
+			smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8");
+			return smartdns::SERVER_REQUEST_OK;
+		}
+		return smartdns::SERVER_REQUEST_SOA;
+	});
+
+	server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 300);
+	server.MockPing(PING_TYPE_TCP, "5.6.7.8:80", 60, 10);
+	server.Start(R"""(bind [::]:60053
+server 127.0.0.1:61053
+log-num 0
+log-console yes
+speed-check-mode ping,tcp:80
+log-level debug
+cache-persist no)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("a.com", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_LT(client.GetQueryTime(), 500);
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8");
+}
+
 TEST_F(SpeedCheck, fastest_ip)
 {
 	smartdns::MockServer server_upstream;