Browse Source

dns_cache: make cache case insensitive

Nick Peng 2 months ago
parent
commit
9e2474caf1

+ 1 - 0
.gitignore

@@ -10,3 +10,4 @@ test.bin
 package/target
 package/*.gz
 package/*.ipk
+target

+ 125 - 6
src/dns.c

@@ -2952,6 +2952,122 @@ int dns_encode(unsigned char *data, int size, struct dns_packet *packet)
 	return context.ptr - context.data;
 }
 
+static int _dns_update_domain(struct dns_context *context, const char *domain)
+{
+	int len = 0;
+	int ptr_jump = 0;
+	int output_len = 0;
+	unsigned char *ptr = context->ptr;
+	unsigned char *packet = context->data;
+	int packet_size = context->maxsize;
+	int domain_len = strlen(domain);
+	int processed_len = 0;
+
+	while (1) {
+		if (ptr >= packet + packet_size || ptr < packet || ptr_jump > 32 || processed_len > domain_len + 1) {
+			return -1;
+		}
+
+		len = *ptr;
+		if (len == 0) {
+			ptr++;
+			processed_len++;
+			break;
+		}
+
+		/* compressed domain */
+		if (len >= 0xC0) {
+			if ((ptr + 2) > (packet + packet_size)) {
+				return -1;
+			}
+
+			/* read offset */
+			len = _dns_read_short(&ptr) & 0x3FFF;
+			ptr = packet + len;
+			if (ptr > packet + packet_size) {
+				return -1;
+			}
+
+			ptr_jump++;
+			continue;
+		}
+
+		ptr_jump = 0;
+
+		if (output_len > 0) {
+			output_len += 1;
+		}
+
+		if (ptr > packet + packet_size) {
+			return -1;
+		}
+
+		ptr++;
+		/* update domain */
+		memcpy(ptr, domain + processed_len, len);
+		ptr += len;
+		processed_len += len + 1;
+		output_len += len;
+	}
+
+	if (output_len != domain_len) {
+		tlog(TLOG_DEBUG, "update domain failed, length mismatch. output_len: %d, domain_len: %d", output_len,
+			 domain_len);
+		return -1;
+	}
+
+	return 0;
+}
+
+static int _dns_update_rr_domain(struct dns_context *context, unsigned char *rr_start, const char *domain,
+								 struct dns_update_param *param)
+{
+	const char *query_domain = param->query_domain;
+	unsigned char *old_context_ptr = context->ptr;
+
+	if (param->query_domain == NULL) {
+		return 0;
+	}
+
+	if (strncasecmp(domain, query_domain, DNS_MAX_CNAME_LEN) != 0) {
+		return 0;
+	}
+
+	context->ptr = rr_start;
+	if (_dns_update_domain(context, query_domain) != 0) {
+		tlog(TLOG_DEBUG, "update domain failed, %s", domain);
+		context->ptr = old_context_ptr;
+		return -1;
+	}
+
+	context->ptr = old_context_ptr;
+
+	return 0;
+}
+
+static int _dns_update_qd(struct dns_context *context, dns_rr_type type, struct dns_update_param *param)
+{
+	char domain[DNS_MAX_CNAME_LEN];
+	int qtype = 0;
+	int qclass = 0;
+	int len = 0;
+	unsigned char *rr_start = context->ptr;
+
+	len = _dns_decode_qr_head(context, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass);
+	if (len < 0) {
+		tlog(TLOG_DEBUG, "update qd failed.");
+		return -1;
+	}
+
+	int ret = _dns_update_rr_domain(context, rr_start, domain, param);
+	if (ret < 0) {
+		tlog(TLOG_DEBUG, "domain not match, %s", domain);
+		return -1;
+	}
+
+	return 0;
+}
+
 static int _dns_update_an(struct dns_context *context, dns_rr_type type, struct dns_update_param *param)
 {
 	int ret = 0;
@@ -2961,6 +3077,7 @@ static int _dns_update_an(struct dns_context *context, dns_rr_type type, struct
 	int rr_len = 0;
 	char domain[DNS_MAX_CNAME_LEN];
 	unsigned char *start = NULL;
+	unsigned char *rr_start = context->ptr;
 
 	/* decode rr head */
 	ret = _dns_decode_rr_head(context, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass, &ttl, &rr_len);
@@ -2969,6 +3086,12 @@ static int _dns_update_an(struct dns_context *context, dns_rr_type type, struct
 		return -1;
 	}
 
+	ret = _dns_update_rr_domain(context, rr_start, domain, param);
+	if (ret < 0) {
+		tlog(TLOG_DEBUG, "domain not match, %s", domain);
+		return -1;
+	}
+
 	start = context->ptr;
 	switch (qtype) {
 	case DNS_T_OPT:
@@ -3001,12 +3124,8 @@ static int _dns_update_body(struct dns_context *context, struct dns_update_param
 	count = head->qdcount;
 	head->qdcount = 0;
 	for (i = 0; i < count; i++) {
-		char domain[DNS_MAX_CNAME_LEN];
-		int qtype = 0;
-		int qclass = 0;
-		int len = 0;
-		len = _dns_decode_qr_head(context, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass);
-		if (len < 0) {
+		ret = _dns_update_qd(context, DNS_RRS_QD, param);
+		if (ret < 0) {
 			tlog(TLOG_DEBUG, "update qd failed.");
 			return -1;
 		}

+ 5 - 5
src/dns_cache.c

@@ -223,7 +223,7 @@ static struct dns_cache *_dns_cache_lookup(struct dns_cache_key *cache_key)
 	struct dns_cache *dns_cache_ret = NULL;
 	time_t now = 0;
 
-	key = hash_string(cache_key->domain);
+	key = hash_string_case(cache_key->domain);
 	key = jhash(&cache_key->qtype, sizeof(cache_key->qtype), key);
 	key = hash_string_initval(cache_key->dns_group_name, key);
 	key = jhash(&cache_key->query_flag, sizeof(cache_key->query_flag), key);
@@ -237,7 +237,7 @@ static struct dns_cache *_dns_cache_lookup(struct dns_cache_key *cache_key)
 			continue;
 		}
 
-		if (strncmp(cache_key->domain, dns_cache->info.domain, DNS_MAX_CNAME_LEN) != 0) {
+		if (strncasecmp(cache_key->domain, dns_cache->info.domain, DNS_MAX_CNAME_LEN) != 0) {
 			continue;
 		}
 
@@ -345,7 +345,7 @@ static void _dns_cache_remove_by_domain(struct dns_cache_key *cache_key)
 	uint32_t key = 0;
 	struct dns_cache *dns_cache = NULL;
 
-	key = hash_string(cache_key->domain);
+	key = hash_string_case(cache_key->domain);
 	key = jhash(&cache_key->qtype, sizeof(cache_key->qtype), key);
 	key = hash_string_initval(cache_key->dns_group_name, key);
 	key = jhash(&cache_key->query_flag, sizeof(cache_key->query_flag), key);
@@ -361,7 +361,7 @@ static void _dns_cache_remove_by_domain(struct dns_cache_key *cache_key)
 			continue;
 		}
 
-		if (strncmp(cache_key->domain, dns_cache->info.domain, DNS_MAX_CNAME_LEN) != 0) {
+		if (strncasecmp(cache_key->domain, dns_cache->info.domain, DNS_MAX_CNAME_LEN) != 0) {
 			continue;
 		}
 
@@ -401,7 +401,7 @@ static int _dns_cache_insert(struct dns_cache_info *info, struct dns_cache_data
 	}
 
 	memset(dns_cache, 0, sizeof(*dns_cache));
-	key = hash_string(info->domain);
+	key = hash_string_case(info->domain);
 	key = jhash(&info->qtype, sizeof(info->qtype), key);
 	key = hash_string_initval(info->dns_group_name, key);
 	key = jhash(&info->query_flag, sizeof(info->query_flag), key);

+ 7 - 1
src/dns_server/cache.c

@@ -413,7 +413,13 @@ static int _dns_server_process_cache_packet(struct dns_request *request, struct
 
 	struct dns_server_post_context context;
 	_dns_server_post_context_init(&context, request);
-	context.inpacket = cache_packet->data;
+
+	if (request->original_domain != NULL && cache_packet->head.size < DNS_IN_PACKSIZE) {
+		context.inpacket = context.inpacket_buff;
+		memcpy(context.inpacket, cache_packet->data, cache_packet->head.size);
+	} else {
+		context.inpacket = cache_packet->data;
+	}
 	context.inpacket_len = cache_packet->head.size;
 	request->ping_time = dns_cache->info.speed;
 

+ 5 - 3
src/dns_server/context.c

@@ -728,7 +728,7 @@ static int _dns_result_child_post(struct dns_server_post_context *context)
 	return 0;
 }
 
-static int _dns_request_update_id_ttl(struct dns_server_post_context *context)
+static int _dns_request_update_id_ttl_domain(struct dns_server_post_context *context)
 {
 	int ttl = context->reply_ttl;
 	struct dns_request *request = context->request;
@@ -758,6 +758,7 @@ static int _dns_request_update_id_ttl(struct dns_server_post_context *context)
 	param.id = request->id;
 	param.cname_ttl = ttl;
 	param.ip_ttl = ttl;
+	param.query_domain = request->original_domain;
 	if (dns_packet_update(context->inpacket, context->inpacket_len, &param) != 0) {
 		tlog(TLOG_DEBUG, "update packet info failed.");
 	}
@@ -820,7 +821,7 @@ int _dns_request_post(struct dns_server_post_context *context)
 		return 0;
 	}
 
-	ret = _dns_request_update_id_ttl(context);
+	ret = _dns_request_update_id_ttl_domain(context);
 	if (ret != 0) {
 		tlog(TLOG_ERROR, "update packet ttl failed.");
 		return -1;
@@ -1047,11 +1048,12 @@ int _dns_server_reply_passthrough(struct dns_server_post_context *context)
 		char clientip[DNS_MAX_CNAME_LEN] = {0};
 
 		/* When passthrough, modify the id to be the id of the client request. */
-		int ret = _dns_request_update_id_ttl(context);
+		int ret = _dns_request_update_id_ttl_domain(context);
 		if (ret != 0) {
 			tlog(TLOG_ERROR, "update packet ttl failed.");
 			return -1;
 		}
+
 		_dns_reply_inpacket(request, context->inpacket, context->inpacket_len);
 
 		tlog(TLOG_INFO, "result: %s, client: %s, qtype: %d, id: %d, group: %s, time: %lums", request->domain,

+ 1 - 0
src/dns_server/dns_server.h

@@ -272,6 +272,7 @@ struct dns_request {
 
 	/* dns query */
 	char domain[DNS_MAX_CNAME_LEN];
+	char *original_domain;
 	dns_type_t qtype;
 	int qclass;
 	unsigned long send_tick;

+ 19 - 1
src/dns_server/request.c

@@ -156,6 +156,11 @@ static void _dns_server_delete_request(struct dns_request *request)
 	if (request->https_svcb) {
 		free(request->https_svcb);
 	}
+
+	if (request->original_domain) {
+		free(request->original_domain);
+	}
+
 	memset(request, 0, sizeof(*request));
 	free(request);
 	atomic_dec(&server.request_num);
@@ -1207,7 +1212,20 @@ int _dns_server_parser_request(struct dns_request *request, struct dns_packet *p
 		}
 
 		// Only support one question.
-		safe_strncpy(request->domain, domain, sizeof(request->domain));
+		int case_changed = 0;
+		safe_strncpy_lower(request->domain, domain, sizeof(request->domain), &case_changed);
+
+		/* support draft dns0x20 */
+		if (case_changed) {
+			request->original_domain = malloc(DNS_MAX_CNAME_LEN);
+			if (request->original_domain == NULL) {
+				tlog(TLOG_ERROR, "malloc failed.\n");
+				goto errout;
+			}
+
+			safe_strncpy(request->original_domain, domain, DNS_MAX_CNAME_LEN);
+			tlog(TLOG_DEBUG, "query %s by origin domain %s", request->domain, request->original_domain);
+		}
 		request->qtype = qtype;
 		break;
 	}

+ 1 - 0
src/include/smartdns/dns.h

@@ -335,6 +335,7 @@ struct dns_update_param {
 	int id;
 	int ip_ttl;
 	int cname_ttl;
+	const char *query_domain;
 };
 
 int dns_packet_update(unsigned char *data, int size, struct dns_update_param *param);

+ 27 - 0
src/include/smartdns/lib/stringutil.h

@@ -47,5 +47,32 @@ static inline char *safe_strncpy(char *dest, const char *src, size_t n)
 	return ret;
 }
 
+static inline char *safe_strncpy_lower(char *dest, const char *src, size_t n, int *case_changed) 
+{
+	if (src == NULL) {
+		dest[0] = '\0';
+		return dest;
+	}
+
+	if (n <= 0) {
+		return NULL;
+	}
+	
+	while (n > 1 && *src) {
+		*dest = *src;
+		if (*dest >= 'A' && *dest <= 'Z') {
+			*dest = *dest + 32;
+			if (case_changed) {
+				*case_changed = 1;
+			}
+		}
+		dest++;
+		src++;
+		n--;
+	}
+	*dest = '\0';
+	return dest;
+}
+
 
 #endif

+ 2 - 2
test/cases/test-dns64.cc

@@ -247,7 +247,7 @@ server 127.0.0.1:61053
 	EXPECT_EQ(client.GetStatus(), "NOERROR");
 	EXPECT_LT(client.GetQueryTime(), 1200);
 	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
-	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
+	EXPECT_GT(client.GetAnswer()[0].GetTTL(), 590);
 	EXPECT_EQ(client.GetAnswer()[0].GetData(), "::ffff:1.2.3.4");
 
 	ASSERT_TRUE(client.Query("a.com A", 60053));
@@ -256,6 +256,6 @@ server 127.0.0.1:61053
 	EXPECT_EQ(client.GetStatus(), "NOERROR");
 	EXPECT_LT(client.GetQueryTime(), 1200);
 	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
-	EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
+	EXPECT_GT(client.GetAnswer()[0].GetTTL(), 590);
 	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
 }

+ 36 - 0
test/cases/test-server.cc

@@ -552,4 +552,40 @@ speed-check-mode none
 	EXPECT_EQ(client.GetStatus(), "NOERROR");
 	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
 	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.6");
+}
+
+TEST_F(Server, case_insensitive)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::Server server;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype != DNS_T_A) {
+			return smartdns::SERVER_REQUEST_SOA;
+		}
+		usleep(100000);
+		smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+	server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 128, 1000);
+	server.Start(R"""(bind [::]:60053
+bind-tcp [::]:60053
+server 127.0.0.1:61053 
+)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("a.com", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+
+	ASSERT_TRUE(client.Query("A.cOm", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_LE(client.GetQueryTime(), 5);
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "A.cOm");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
 }