Browse Source

minor optimize

Nick Peng 6 years ago
parent
commit
87f04571b1
14 changed files with 111 additions and 64 deletions
  1. 7 0
      src/dns_cache.h
  2. 7 9
      src/dns_client.c
  3. 8 0
      src/dns_client.h
  4. 1 1
      src/dns_conf.c
  5. 7 1
      src/dns_conf.h
  6. 54 47
      src/dns_server.c
  7. 2 1
      src/fast_ping.c
  8. 1 0
      src/include/stringutil.h
  9. 4 0
      src/lib/art.c
  10. 1 1
      src/lib/radix.c
  11. 0 1
      src/smartdns.c
  12. 2 1
      src/tlog.c
  13. 10 2
      src/util.c
  14. 7 0
      src/util.h

+ 7 - 0
src/dns_cache.h

@@ -9,6 +9,10 @@
 #include <stdlib.h>
 #include <time.h>
 
+#ifdef __cpluscplus
+extern "C" {
+#endif
+
 #define DNS_CACHE_TTL_MIN 30
 
 struct dns_cache {
@@ -59,4 +63,7 @@ int dns_cache_get_ttl(struct dns_cache *dns_cache);
 
 void dns_cache_destroy(void);
 
+#ifdef __cpluscplus
+}
+#endif
 #endif // !_SMARTDNS_CACHE_H

+ 7 - 9
src/dns_client.c

@@ -639,7 +639,6 @@ static char *_dns_client_server_get_tls_host_check(struct dns_server_info *serve
 		struct client_dns_server_flag_tls *flag_tls = &server_info->flags.tls;
 		tls_host_check = flag_tls->tls_host_check;
 	} break;
-		break;
 	case DNS_SERVER_TCP:
 		break;
 	default:
@@ -673,7 +672,6 @@ static char *_dns_client_server_get_spki(struct dns_server_info *server_info, in
 		spki = flag_tls->spki;
 		*spki_len = flag_tls->spi_len;
 	} break;
-		break;
 	case DNS_SERVER_TCP:
 		break;
 	default:
@@ -728,7 +726,6 @@ static int _dns_client_server_add(char *server_ip, char *server_host, int port,
 		spki_data_len = flag_tls->spi_len;
 		sock_type = SOCK_STREAM;
 	} break;
-		break;
 	case DNS_SERVER_TCP:
 		sock_type = SOCK_STREAM;
 		break;
@@ -1273,10 +1270,12 @@ static int _dns_client_recv(struct dns_server_info *server_info, unsigned char *
 
 	/* get query reference */
 	query = _dns_client_get_request(packet->head.id, domain);
-	if (query == NULL || (query && has_opt == 0 && server_info->flags.result_flag & DNSSERVER_FLAG_CHECK_EDNS)) {
-		if (query) {
-			_dns_client_query_release(query);
-		}
+	if (query == NULL) {
+		return 0;
+	}
+
+	if (has_opt == 0 && server_info->flags.result_flag & DNSSERVER_FLAG_CHECK_EDNS) {
+		_dns_client_query_release(query);
 		return 0;
 	}
 
@@ -2494,7 +2493,6 @@ errout_del_list:
 	query = NULL;
 errout:
 	if (query) {
-		tlog(TLOG_ERROR, "release %p", query);
 		free(query);
 	}
 	return -1;
@@ -2618,9 +2616,9 @@ static void _dns_client_add_pending_servers(void)
 		/* if has no bootstrap DNS, just call getaddrinfo to get address */
 		if (dns_client_has_bootstrap_dns == 0) {
 			if (_dns_client_add_pendings(pending, pending->host) != 0) {
+				pthread_mutex_unlock(&pending_server_mutex);
 				tlog(TLOG_ERROR, "add pending DNS server %s failed", pending->host);
 				exit(1);
-				pthread_mutex_unlock(&pending_server_mutex);
 				return;
 			}
 			list_del_init(&pending->list);

+ 8 - 0
src/dns_client.h

@@ -2,6 +2,11 @@
 #define _SMART_DNS_CLIENT_H
 
 #include "dns.h"
+
+#ifdef __cpluscplus
+extern "C" {
+#endif
+
 #define DNS_SERVER_SPKI_LEN 64
 #define DNS_SERVER_GROUP_DEFAULT "default"
 
@@ -87,4 +92,7 @@ int dns_client_remove_group(char *group_name);
 
 int dns_server_num(void);
 
+#ifdef __cpluscplus
+}
+#endif
 #endif

+ 1 - 1
src/dns_conf.c

@@ -213,7 +213,7 @@ static int _config_server(int argc, char *argv[], dns_server_type_t type, int de
 		safe_strncpy(server->hostname, server->server, sizeof(server->hostname));
 		safe_strncpy(server->httphost, server->server, sizeof(server->httphost));
 		if (server->path[0] == 0) {
-			strcpy(server->path, "/");
+			safe_strncpy(server->path, "/", sizeof(server->path));
 		}
 	} else {
 		/* parse ip, port from ip */

+ 7 - 1
src/dns_conf.h

@@ -10,6 +10,10 @@
 #include "list.h"
 #include "radix.h"
 
+#ifdef __cpluscplus
+extern "C" {
+#endif
+
 #define DNS_MAX_BIND_IP 16
 #define DNS_MAX_SERVERS 64
 #define DNS_MAX_SERVER_NAME_LEN 128
@@ -218,5 +222,7 @@ void dns_server_load_exit(void);
 int dns_server_load_conf(const char *file);
 
 extern int config_addtional_file(void *data, int argc, char *argv[]);
-
+#ifdef __cpluscplus
+}
+#endif
 #endif // !_DNS_CONF

+ 54 - 47
src/dns_server.c

@@ -873,36 +873,6 @@ static void _dns_server_select_possible_ipaddress(struct dns_request *request)
 	}
 }
 
-static struct dns_request *_dns_server_new_request(void)
-{
-	struct dns_request *request = NULL;
-
-	request = malloc(sizeof(*request));
-	if (request == NULL) {
-		tlog(TLOG_ERROR, "malloc failed.\n");
-		goto errout;
-	}
-
-	memset(request, 0, sizeof(*request));
-	pthread_mutex_init(&request->ip_map_lock, NULL);
-	atomic_set(&request->adblock, 0);
-	atomic_set(&request->soa_num, 0);
-	atomic_set(&request->refcnt, 0);
-	request->ping_ttl_v4 = -1;
-	request->ping_ttl_v6 = -1;
-	request->prefetch = 0;
-	request->rcode = DNS_RC_SERVFAIL;
-	request->conn = NULL;
-	request->result_callback = NULL;
-	request->check_order_list = &dns_conf_check_order;
-	INIT_LIST_HEAD(&request->list);
-	hash_init(request->ip_map);
-
-	return request;
-errout:
-	return NULL;
-}
-
 static void _dns_server_delete_request(struct dns_request *request)
 {
 	if (request->conn) {
@@ -913,7 +883,7 @@ static void _dns_server_delete_request(struct dns_request *request)
 	free(request);
 }
 
-static void _dns_server_request_release(struct dns_request *request)
+static void _dns_server_request_release_complete(struct dns_request *request, int do_complete)
 {
 	struct dns_ip_address *addr_map;
 	struct hlist_node *tmp;
@@ -932,10 +902,12 @@ static void _dns_server_request_release(struct dns_request *request)
 	list_del_init(&request->list);
 	pthread_mutex_unlock(&server.request_list_lock);
 
-	/* Select max hit ip address, and return to client */
-	_dns_server_select_possible_ipaddress(request);
+	if (do_complete) {
+		/* Select max hit ip address, and return to client */
+		_dns_server_select_possible_ipaddress(request);
+		_dns_server_request_complete(request);
+	}
 
-	_dns_server_request_complete(request);
 	hash_for_each_safe(request->ip_map, bucket, tmp, addr_map, node)
 	{
 		hash_del(&addr_map->node);
@@ -945,6 +917,11 @@ static void _dns_server_request_release(struct dns_request *request)
 	_dns_server_delete_request(request);
 }
 
+static void _dns_server_request_release(struct dns_request *request)
+{
+	_dns_server_request_release_complete(request, 1);
+}
+
 static void _dns_server_request_get(struct dns_request *request)
 {
 	if (atomic_inc_return(&request->refcnt) <= 0) {
@@ -953,6 +930,37 @@ static void _dns_server_request_get(struct dns_request *request)
 	}
 }
 
+static struct dns_request *_dns_server_new_request(void)
+{
+	struct dns_request *request = NULL;
+
+	request = malloc(sizeof(*request));
+	if (request == NULL) {
+		tlog(TLOG_ERROR, "malloc failed.\n");
+		goto errout;
+	}
+
+	memset(request, 0, sizeof(*request));
+	pthread_mutex_init(&request->ip_map_lock, NULL);
+	atomic_set(&request->adblock, 0);
+	atomic_set(&request->soa_num, 0);
+	atomic_set(&request->refcnt, 0);
+	request->ping_ttl_v4 = -1;
+	request->ping_ttl_v6 = -1;
+	request->prefetch = 0;
+	request->rcode = DNS_RC_SERVFAIL;
+	request->conn = NULL;
+	request->result_callback = NULL;
+	request->check_order_list = &dns_conf_check_order;
+	INIT_LIST_HEAD(&request->list);
+	hash_init(request->ip_map);
+	_dns_server_request_get(request);
+
+	return request;
+errout:
+	return NULL;
+}
+
 static void _dns_server_ping_result(struct ping_host_struct *ping_host, const char *host, FAST_PING_RESULT result, struct sockaddr *addr, socklen_t addr_len,
 									int seqno, int ttl, struct timeval *tv, void *userptr)
 {
@@ -2057,40 +2065,37 @@ static int _dns_server_do_query(struct dns_request *request, const char *domain,
 		}
 	}
 
+	// Get reference for server thread
 	_dns_server_request_get(request);
 	pthread_mutex_lock(&server.request_list_lock);
 	list_add_tail(&request->list, &server.request_list);
 	pthread_mutex_unlock(&server.request_list_lock);
-
-	_dns_server_request_get(request);
 	request->send_tick = get_tick_count();
 
 	/* When the dual stack ip preference is enabled, both A and AAAA records are requested. */
 	if (qtype == DNS_T_AAAA && _dns_server_is_dualstack_selection(request)) {
+		// Get reference for AAAA query
 		_dns_server_request_get(request);
 		request->request_wait++;
 		if (dns_client_query(request->domain, DNS_T_A, dns_server_resolve_callback, request, group_name) != 0) {
-			_dns_server_request_release(request);
 			request->request_wait--;
+			_dns_server_request_release(request);
 		}
 	}
 
+	// Get reference for DNS query
 	request->request_wait++;
+	_dns_server_request_get(request);
 	if (dns_client_query(request->domain, qtype, dns_server_resolve_callback, request, group_name) != 0) {
+		request->request_wait--;
 		_dns_server_request_release(request);
 		tlog(TLOG_ERROR, "send dns request failed.");
 		goto errout;
 	}
 
-	return 0;
 clean_exit:
-	if (request) {
-		_dns_server_delete_request(request);
-	}
-
 	return 0;
 errout:
-
 	_dns_server_request_remove(request);
 	request = NULL;
 	return ret;
@@ -2172,12 +2177,12 @@ static int _dns_server_recv(struct dns_server_conn_head *conn, unsigned char *in
 		tlog(TLOG_ERROR, "do query %s failed.\n", domain);
 		goto errout;
 	}
-
+	_dns_server_request_release_complete(request, 0);
 	return ret;
 errout:
 	if (request) {
 		ret = _dns_server_forward_request(inpacket, inpacket_len);
-		_dns_server_delete_request(request);
+		_dns_server_request_release(request);
 	}
 
 	return ret;
@@ -2201,10 +2206,11 @@ static int _dns_server_prefetch_request(char *domain, dns_type_t qtype)
 		goto errout;
 	}
 
+	_dns_server_request_release_complete(request, 0);
 	return ret;
 errout:
 	if (request) {
-		_dns_server_delete_request(request);
+		_dns_server_request_release(request);
 	}
 
 	return ret;
@@ -2228,10 +2234,11 @@ int dns_server_query(char *domain, int qtype, dns_result_callback callback, void
 		goto errout;
 	}
 
+	_dns_server_request_release_complete(request, 0);
 	return ret;
 errout:
 	if (request) {
-		_dns_server_delete_request(request);
+		_dns_server_request_release(request);
 	}
 
 	return ret;

+ 2 - 1
src/fast_ping.c

@@ -1470,9 +1470,10 @@ static void _fast_ping_period_run(void)
 	struct hlist_node *tmp = NULL;
 	int i = 0;
 	struct timeval now;
+	struct timezone tz;
 	struct timeval interval;
 	int64_t millisecond;
-	gettimeofday(&now, NULL);
+	gettimeofday(&now, &tz);
 	LIST_HEAD(action);
 
 	pthread_mutex_lock(&ping.map_lock);

+ 1 - 0
src/include/stringutil.h

@@ -20,4 +20,5 @@ static inline char *safe_strncpy(char *dest, const char *src, size_t n)
 	return ret;
 }
 
+
 #endif

+ 4 - 0
src/lib/art.c

@@ -390,6 +390,10 @@ art_leaf* art_maximum(art_tree *t) {
 
 static art_leaf* make_leaf(const unsigned char *key, int key_len, void *value) {
     art_leaf *l = (art_leaf*)calloc(1, sizeof(art_leaf)+key_len+1);
+    if (l == NULL) {
+		return NULL;
+	}
+    
     l->value = value;
     l->key_len = key_len;
     memcpy(l->key, key, key_len);

+ 1 - 1
src/lib/radix.c

@@ -281,7 +281,7 @@ static radix_node_t
 *radix_search_best2(radix_tree_t *radix, prefix_t *prefix, int inclusive)
 {
 	radix_node_t *node;
-	radix_node_t *stack[RADIX_MAXBITS + 1];
+	radix_node_t *stack[RADIX_MAXBITS + 1] = {0};
 	unsigned char *addr;
 	unsigned int bitlen;
 	int cnt = 0;

+ 0 - 1
src/smartdns.c

@@ -165,7 +165,6 @@ static int _smartdns_add_servers(void)
 			safe_strncpy(flag_tls->hostname, dns_conf_servers[i].hostname, sizeof(flag_tls->hostname));
 			safe_strncpy(flag_tls->tls_host_check, dns_conf_servers[i].tls_host_check, sizeof(flag_tls->tls_host_check));
 		} break;
-			break;
 		case DNS_SERVER_TCP:
 			break;
 		default:

+ 2 - 1
src/tlog.c

@@ -453,6 +453,7 @@ static int _tlog_vprintf(struct tlog_log *log, vprint_callback print_callback, v
         return -1;
     } else if (len >= TLOG_MAX_LINE_LEN) {
 		strncpy(buff, "[LOG TOO LONG, DISCARD]\n", sizeof(buff));
+        buff[sizeof(buff) - 1] = '\0';
 		len = strnlen(buff, sizeof(buff));
 	}
 
@@ -1528,7 +1529,7 @@ tlog_log *tlog_open(const char *logfile, int maxlogsize, int maxlogcount, int bu
     strncpy(log->logname, basename(log_file), sizeof(log->logname));
     log->logname[sizeof(log->logname) - 1] = '\0';
     if (log->nocompress) {
-        strncpy(log->suffix, TLOG_SUFFIX_LOG, sizeof(sizeof(log->suffix)));
+        strncpy(log->suffix, TLOG_SUFFIX_LOG, sizeof(log->suffix));
     } else {
         strncpy(log->suffix, TLOG_SUFFIX_GZ, sizeof(log->suffix));
     }

+ 10 - 2
src/util.c

@@ -65,6 +65,7 @@ struct ipset_netlink_msg {
 };
 
 static int ipset_fd;
+static int pidfile_fd;
 
 unsigned long get_tick_count(void)
 {
@@ -125,6 +126,7 @@ int getaddr_by_host(char *host, struct sockaddr *addr, socklen_t *addr_len)
 		result->ai_addrlen = *addr_len;
 	}
 
+	addr->sa_family = result->ai_family;
 	memcpy(addr, result->ai_addr, result->ai_addrlen);
 	*addr_len = result->ai_addrlen;
 
@@ -354,7 +356,7 @@ int parse_uri(char *value, char *scheme, char *host, int *port, char *path)
 	process_ptr += field_len;
 
 	if (path) {
-		strcpy(path, process_ptr);
+		strncpy(path, process_ptr, PATH_MAX);
 	} 
 	return 0;
 }
@@ -608,6 +610,12 @@ int create_pid_file(const char *pid_file)
 		goto errout;
 	}
 
+	if (pidfile_fd > 0) {
+		close(pidfile_fd);
+	}
+
+	pidfile_fd = fd;
+
 	return 0;
 errout:
 	if (fd > 0) {
@@ -859,7 +867,7 @@ void get_compiled_time(struct tm *tm)
 	int hour, min, sec;
     static const char *month_names = "JanFebMarAprMayJunJulAugSepOctNovDec";
 
-    sscanf(__DATE__, "%s %d %d", s_month, &day, &year);
+    sscanf(__DATE__, "%5s %d %d", s_month, &day, &year);
     month = (strstr(month_names, s_month) - month_names) / 3;
 	sscanf(__TIME__, "%d:%d:%d", &hour, &min, &sec);
     tm->tm_year = year - 1900;

+ 7 - 0
src/util.h

@@ -7,6 +7,10 @@
 #include <time.h>
 #include "stringutil.h"
 
+#ifdef __cplusplus
+extern "C" {
+#endif /*__cplusplus */
+
 #define PORT_NOT_DEFINED -1
 #define MAX_IP_LEN 64
 
@@ -61,4 +65,7 @@ int parse_tls_header(const char *data, size_t data_len, char *hostname, const ch
 
 void get_compiled_time(struct tm *tm);
 
+#ifdef __cplusplus
+}
+#endif /*__cplusplus */
 #endif