Przeglądaj źródła

dns_cache: try to fix crash issue caused by cache.

Nick Peng 1 rok temu
rodzic
commit
53aa5121ba
3 zmienionych plików z 44 dodań i 16 usunięć
  1. 33 9
      src/dns_cache.c
  2. 6 6
      src/dns_server.c
  3. 5 1
      src/lib/timer_wheel.c

+ 33 - 9
src/dns_cache.c

@@ -52,6 +52,7 @@ static struct dns_cache_head dns_cache_head;
 int dns_cache_init(int size, dns_cache_callback timeout_callback)
 {
 	int bits = 0;
+	pthread_mutexattr_t mta;
 	if (is_cache_init == 1) {
 		return -1;
 	}
@@ -69,7 +70,10 @@ int dns_cache_init(int size, dns_cache_callback timeout_callback)
 	atomic_set(&dns_cache_head.num, 0);
 	dns_cache_head.size = size;
 	dns_cache_head.timeout_callback = timeout_callback;
-	pthread_mutex_init(&dns_cache_head.lock, NULL);
+	pthread_mutexattr_init(&mta);
+	pthread_mutexattr_settype(&mta, PTHREAD_MUTEX_RECURSIVE);
+	pthread_mutex_init(&dns_cache_head.lock, &mta);
+	pthread_mutexattr_destroy(&mta);
 
 	is_cache_init = 1;
 	return 0;
@@ -82,10 +86,15 @@ static struct dns_cache *_dns_cache_first(void)
 
 static void _dns_cache_delete(struct dns_cache *dns_cache)
 {
+	pthread_mutex_lock(&dns_cache_head.lock);
 	hash_del(&dns_cache->node);
 	list_del_init(&dns_cache->list);
+	pthread_mutex_unlock(&dns_cache_head.lock);
 	atomic_dec(&dns_cache_head.num);
-	dns_cache_data_put(dns_cache->cache_data);
+	if (dns_cache->cache_data) {
+		dns_cache_data_put(dns_cache->cache_data);
+	}
+
 	dns_cache->cache_data = NULL;
 	free(dns_cache);
 }
@@ -156,7 +165,7 @@ struct dns_cache_data *dns_cache_new_data_packet(void *packet, size_t packet_len
 static void dns_cache_timer_release(struct tw_base *base, struct tw_timer_list *timer, void *data)
 {
 	struct dns_cache *dns_cache = data;
-	dns_cache_release(dns_cache);
+	dns_cache_delete(dns_cache);
 }
 
 static void dns_cache_expired(struct tw_base *base, struct tw_timer_list *timer, void *data, unsigned long timestamp)
@@ -164,7 +173,7 @@ static void dns_cache_expired(struct tw_base *base, struct tw_timer_list *timer,
 	struct dns_cache *dns_cache = data;
 
 	if (dns_cache->del_pending == 1) {
-		dns_cache_release(dns_cache);
+		dns_cache_delete(dns_cache);
 		return;
 	}
 
@@ -174,7 +183,7 @@ static void dns_cache_expired(struct tw_base *base, struct tw_timer_list *timer,
 		case DNS_CACHE_TMOUT_ACTION_OK:
 			break;
 		case DNS_CACHE_TMOUT_ACTION_DEL:
-			dns_cache_release(dns_cache);
+			dns_cache_delete(dns_cache);
 			return;
 		case DNS_CACHE_TMOUT_ACTION_RETRY:
 			dns_timer_mod(&dns_cache->timer, DNS_CACHE_FAIL_TIMEOUT);
@@ -222,6 +231,7 @@ static int _dns_cache_replace(struct dns_cache_key *cache_key, int rcode, int tt
 		old_cache_data = dns_cache->cache_data;
 		dns_cache->cache_data = cache_data;
 	}
+
 	if (update_time) {
 		time(&dns_cache->info.insert_time);
 	}
@@ -286,6 +296,10 @@ static int _dns_cache_insert(struct dns_cache_info *info, struct dns_cache_data
 	uint32_t key = 0;
 	struct dns_cache *dns_cache = NULL;
 
+	if (cache_data == NULL || info == NULL) {
+		goto errout;
+	}
+
 	/* if cache already exists, free */
 	struct dns_cache_key cache_key;
 	cache_key.qtype = info->qtype;
@@ -312,10 +326,11 @@ static int _dns_cache_insert(struct dns_cache_info *info, struct dns_cache_data
 	dns_cache->timer.del_function = dns_cache_timer_release;
 	dns_cache->timer.expires = timeout;
 	dns_cache->timer.data = dns_cache;
+	INIT_LIST_HEAD(&dns_cache->check_list);
+
 	pthread_mutex_lock(&dns_cache_head.lock);
 	hash_table_add(dns_cache_head.cache_hash, &dns_cache->node, key);
 	list_add_tail(&dns_cache->list, head);
-	INIT_LIST_HEAD(&dns_cache->check_list);
 
 	/* Release extra cache, remove oldest cache record */
 	if (atomic_inc_return(&dns_cache_head.num) > dns_cache_head.size) {
@@ -460,6 +475,11 @@ struct dns_cache_data *dns_cache_get_data(struct dns_cache *dns_cache)
 {
 	struct dns_cache_data *cache_data;
 	pthread_mutex_lock(&dns_cache_head.lock);
+	if (dns_cache->cache_data == NULL) {
+		pthread_mutex_unlock(&dns_cache_head.lock);
+		return NULL;
+	}
+
 	dns_cache_data_get(dns_cache->cache_data);
 	cache_data = dns_cache->cache_data;
 	pthread_mutex_unlock(&dns_cache_head.lock);
@@ -567,14 +587,14 @@ static int _dns_cache_read_to_cache(struct dns_cache_record *cache_record, struc
 		timeout = DNS_CACHE_READ_TIMEOUT + (rand() % DNS_CACHE_READ_TIMEOUT);
 	}
 
+	dns_cache_data_get(cache_data);
 	if (_dns_cache_insert(&cache_record->info, cache_data, head, timeout) != 0) {
 		tlog(TLOG_ERROR, "insert cache data failed.");
+		dns_cache_data_put(cache_data);
 		cache_data = NULL;
 		goto errout;
 	}
 
-	dns_cache_data_get(cache_data);
-
 	daemon_keepalive();
 
 	return 0;
@@ -715,6 +735,11 @@ static int _dns_cache_write_record(int fd, uint32_t *cache_number, struct list_h
 	pthread_mutex_lock(&dns_cache_head.lock);
 	list_for_each_entry_safe(dns_cache, tmp, head, list)
 	{
+		struct dns_cache_data *cache_data = dns_cache->cache_data;
+		if (cache_data == NULL) {
+			continue;
+		}
+
 		cache_record.magic = MAGIC_RECORD;
 		memcpy(&cache_record.info, &dns_cache->info, sizeof(struct dns_cache_info));
 		ssize_t ret = write(fd, &cache_record, sizeof(cache_record));
@@ -723,7 +748,6 @@ static int _dns_cache_write_record(int fd, uint32_t *cache_number, struct list_h
 			goto errout;
 		}
 
-		struct dns_cache_data *cache_data = dns_cache->cache_data;
 		ret = write(fd, cache_data, sizeof(*cache_data) + cache_data->head.size);
 		if (ret != (int)sizeof(*cache_data) + cache_data->head.size) {
 			tlog(TLOG_ERROR, "write cache data failed, %s", strerror(errno));

+ 6 - 6
src/dns_server.c

@@ -1696,7 +1696,6 @@ static int _dns_cache_packet(struct dns_server_post_context *context)
 	}
 
 	/* if doing prefetch, update cache only */
-
 	struct dns_cache_key cache_key;
 	cache_key.dns_group_name = request->dns_group_name;
 	cache_key.domain = request->domain;
@@ -5162,7 +5161,12 @@ static int _dns_server_get_expired_ttl_reply(struct dns_cache *dns_cache)
 static int _dns_server_process_cache_packet(struct dns_request *request, struct dns_cache *dns_cache)
 {
 	int ret = -1;
-	struct dns_cache_packet *cache_packet = (struct dns_cache_packet *)dns_cache_get_data(dns_cache);
+	struct dns_cache_packet *cache_packet = NULL;
+	if (dns_cache->info.qtype != request->qtype) {
+		goto out;
+	}
+
+	cache_packet = (struct dns_cache_packet *)dns_cache_get_data(dns_cache);
 	if (cache_packet == NULL) {
 		goto out;
 	}
@@ -5172,10 +5176,6 @@ static int _dns_server_process_cache_packet(struct dns_request *request, struct
 		do_ipset = 1;
 	}
 
-	if (dns_cache->info.qtype != request->qtype) {
-		goto out;
-	}
-
 	struct dns_server_post_context context;
 	_dns_server_post_context_init(&context, request);
 	context.inpacket = cache_packet->data;

+ 5 - 1
src/lib/timer_wheel.c

@@ -167,17 +167,21 @@ void tw_add_timer(struct tw_base *base, struct tw_timer_list *timer)
 int tw_del_timer(struct tw_base *base, struct tw_timer_list *timer)
 {
 	int ret = 0;
+	int call_del = 0;
 
 	pthread_spin_lock(&base->lock);
 	{
 		if (timer_pending(timer)) {
 			ret = 1;
 			_tw_detach_timer(timer);
+			if (timer->del_function) {
+				call_del = 1;
+			}
 		}
 	}
 	pthread_spin_unlock(&base->lock);
 
-	if (ret == 1 && timer->del_function) {
+	if (call_del) {
 		timer->del_function(base, timer, timer->data);
 	}