Kaynağa Gözat

smartdns: follow sysv daemon initialize steps

Nick Peng 2 yıl önce
ebeveyn
işleme
d59c148a28
5 değiştirilmiş dosya ile 232 ekleme ve 13 silme
  1. 2 2
      src/dns_conf.c
  2. 1 1
      src/lib/conf.c
  3. 41 8
      src/smartdns.c
  4. 180 2
      src/util.c
  5. 8 0
      src/util.h

+ 2 - 2
src/dns_conf.c

@@ -763,7 +763,7 @@ static int _config_domain_rule_each_from_list(const char *file, domain_set_rule_
 	line_no = 0;
 	while (fgets(line, MAX_LINE_LEN, fp)) {
 		line_no++;
-		filed_num = sscanf(line, "%256s", domain);
+		filed_num = sscanf(line, "%255s", domain);
 		if (filed_num <= 0) {
 			continue;
 		}
@@ -3198,7 +3198,7 @@ static int _conf_dhcp_lease_dnsmasq_add(const char *file)
 	line_no = 0;
 	while (fgets(line, MAX_LINE_LEN, fp)) {
 		line_no++;
-		filed_num = sscanf(line, "%*s %*s %64s %256s %*s", ip, hostname);
+		filed_num = sscanf(line, "%*s %*s %63s %255s %*s", ip, hostname);
 		if (filed_num <= 0) {
 			continue;
 		}

+ 1 - 1
src/lib/conf.c

@@ -377,7 +377,7 @@ static int load_conf_file(const char *file, struct config_item *items, conf_erro
 		}
 		line_len = 0;
 
-		filed_num = sscanf(line, "%63s %8192[^\r\n]s", key, value);
+		filed_num = sscanf(line, "%63s %8191[^\r\n]s", key, value);
 		if (filed_num <= 0) {
 			continue;
 		}

+ 41 - 8
src/smartdns.c

@@ -459,6 +459,7 @@ static int _smartdns_init(void)
 	int ret = 0;
 	const char *logfile = _smartdns_log_path();
 	int i = 0;
+	char logdir[PATH_MAX] = {0};
 
 	ret = tlog_init(logfile, dns_conf_log_size, dns_conf_log_num, 0, 0);
 	if (ret != 0) {
@@ -466,7 +467,8 @@ static int _smartdns_init(void)
 		goto errout;
 	}
 
-	if (verbose_screen != 0 || dns_conf_log_console != 0) {
+	safe_strncpy(logdir, _smartdns_log_path(), PATH_MAX);
+	if (verbose_screen != 0 || dns_conf_log_console != 0 || access(dir_name(logdir), W_OK) != 0) {
 		tlog_setlogscreen(1);
 	}
 
@@ -736,9 +738,11 @@ static void smartdns_test_notify_func(int fd_notify, uint64_t retval)
 	}
 }
 
+#define smartdns_close_allfds() close_all_fd(fd_notify);
 int smartdns_main(int argc, char *argv[], int fd_notify)
 #else
 #define smartdns_test_notify(retval)
+#define smartdns_close_allfds() close_all_fd(-1);
 int main(int argc, char *argv[])
 #endif
 {
@@ -750,6 +754,7 @@ int main(int argc, char *argv[])
 	int signal_ignore = 0;
 	sigset_t empty_sigblock;
 	struct stat sb;
+	int daemon_ret = 0;
 
 	safe_strncpy(config_file, SMARTDNS_CONF_FILE, MAX_LINE_LEN);
 
@@ -762,6 +767,7 @@ int main(int argc, char *argv[])
 	/* patch for Asus router:  unblock all signal*/
 	sigemptyset(&empty_sigblock);
 	sigprocmask(SIG_SETMASK, &empty_sigblock, NULL);
+	smartdns_close_allfds();
 
 	while ((opt = getopt(argc, argv, "fhc:p:SvxN:")) != -1) {
 		switch (opt) {
@@ -769,10 +775,14 @@ int main(int argc, char *argv[])
 			is_foreground = 1;
 			break;
 		case 'c':
-			snprintf(config_file, sizeof(config_file), "%s", optarg);
+			if (full_path(config_file, sizeof(config_file), optarg) != 0) {
+				snprintf(config_file, sizeof(config_file), "%s", optarg);
+			}
 			break;
 		case 'p':
-			snprintf(pid_file, sizeof(pid_file), "%s", optarg);
+			if (strncmp(optarg, "-", 2) == 0 || full_path(pid_file, sizeof(pid_file), optarg) != 0) {
+				snprintf(pid_file, sizeof(pid_file), "%s", optarg);
+			}
 			break;
 		case 'S':
 			signal_ignore = 1;
@@ -794,16 +804,27 @@ int main(int argc, char *argv[])
 		}
 	}
 
-	if (dns_server_load_conf(config_file) != 0) {
+	ret = dns_server_load_conf(config_file);
+	if (ret != 0) {
 		fprintf(stderr, "load config failed.\n");
 		goto errout;
 	}
 
 	if (is_foreground == 0) {
-		if (daemon(0, 0) < 0) {
-			fprintf(stderr, "run daemon process failed, %s\n", strerror(errno));
+		daemon_ret = run_daemon();
+		if (daemon_ret < 0) {
+			char buff[4096];
+			char *log_path = realpath(_smartdns_log_path(), buff);
+
+			if (log_path != NULL && access(log_path, F_OK) == 0 && daemon_ret != -2) {
+				fprintf(stderr, "run daemon failed, please check log at %s\n", log_path);
+			}
 			return 1;
 		}
+
+		if (daemon_ret == 0) {
+			return 0;
+		}
 	}
 
 	if (signal_ignore == 0) {
@@ -811,6 +832,7 @@ int main(int argc, char *argv[])
 	}
 
 	if (strncmp(pid_file, "-", 2) != 0 && create_pid_file(pid_file) != 0) {
+		ret = -2;
 		goto errout;
 	}
 
@@ -818,9 +840,10 @@ int main(int argc, char *argv[])
 	signal(SIGINT, _sig_exit);
 	signal(SIGTERM, _sig_exit);
 
-	if (_smartdns_init_pre() != 0) {
+	ret = _smartdns_init_pre();
+	if (ret != 0) {
 		fprintf(stderr, "init failed.\n");
-		return 1;
+		goto errout;
 	}
 
 	drop_root_privilege();
@@ -831,11 +854,21 @@ int main(int argc, char *argv[])
 		goto errout;
 	}
 
+	if (daemon_ret > 0) {
+		ret = daemon_kickoff(daemon_ret, 0);
+		if (ret != 0) {
+			goto errout;
+		}
+	}
+
 	smartdns_test_notify(1);
 	ret = _smartdns_run();
 	_smartdns_exit();
 	return ret;
 errout:
+	if (daemon_ret > 0) {
+		daemon_kickoff(daemon_ret, ret);
+	}
 	smartdns_test_notify(2);
 	return 1;
 }

+ 180 - 2
src/util.c

@@ -25,6 +25,7 @@
 #include "util.h"
 #include <arpa/inet.h>
 #include <ctype.h>
+#include <dirent.h>
 #include <dlfcn.h>
 #include <errno.h>
 #include <fcntl.h>
@@ -38,11 +39,13 @@
 #include <openssl/crypto.h>
 #include <openssl/ssl.h>
 #include <openssl/x509v3.h>
+#include <poll.h>
 #include <pthread.h>
 #include <signal.h>
 #include <stdlib.h>
 #include <string.h>
 #include <sys/prctl.h>
+#include <sys/resource.h>
 #include <sys/stat.h>
 #include <sys/statvfs.h>
 #include <sys/sysinfo.h>
@@ -806,7 +809,11 @@ int create_pid_file(const char *pid_file)
 	}
 
 	if (lockf(fd, F_TLOCK, 0) < 0) {
-		fprintf(stderr, "Server is already running.\n");
+		memset(buff, 0, TMP_BUFF_LEN_32);
+		if (read(fd, buff, TMP_BUFF_LEN_32) <= 0) {
+			buff[0] = '\0';
+		}
+		fprintf(stderr, "Server is already running, pid is %s", buff);
 		goto errout;
 	}
 
@@ -831,6 +838,27 @@ errout:
 	return -1;
 }
 
+int full_path(char *normalized_path, int normalized_path_len, const char *path)
+{
+	const char *p = path;
+
+	if (path == NULL || normalized_path == NULL) {
+		return -1;
+	}
+
+	while (*p == ' ') {
+		p++;
+	}
+
+	if (*p == '\0' || *p == '/') {
+		return -1;
+	}
+
+	char buf[PATH_MAX];
+	snprintf(normalized_path, normalized_path_len, "%s/%s", getcwd(buf, sizeof(buf)), path);
+	return 0;
+}
+
 int generate_cert_key(const char *key_path, const char *cert_path, const char *san, int days)
 {
 	int ret = -1;
@@ -1479,6 +1507,156 @@ out:
 	return ret;
 }
 
+static void _close_all_fd_by_res(void)
+{
+	struct rlimit lim;
+	int maxfd = 0;
+	int i = 0;
+
+	getrlimit(RLIMIT_NOFILE, &lim);
+
+	maxfd = lim.rlim_cur;
+	if (maxfd > 4096) {
+		maxfd = 4096;
+	}
+
+	for (i = 3; i < maxfd; i++) {
+		close(i);
+	}
+}
+
+void close_all_fd(int keepfd)
+{
+	DIR *dirp;
+	int dir_fd = -1;
+	struct dirent *dentp;
+
+	dirp = opendir("/proc/self/fd");
+	if (dirp == NULL) {
+		goto errout;
+	}
+
+	dir_fd = dirfd(dirp);
+
+	while ((dentp = readdir(dirp)) != NULL) {
+		int fd = atol(dentp->d_name);
+		if (fd < 0) {
+			continue;
+		}
+
+		if (fd == dir_fd || fd == STDIN_FILENO || fd == STDOUT_FILENO || fd == STDERR_FILENO || fd == keepfd) {
+			continue;
+		}
+		close(fd);
+	}
+
+	closedir(dirp);
+	return;
+errout:
+	if (dirp) {
+		closedir(dirp);
+	}
+	_close_all_fd_by_res();
+	return;
+}
+
+int daemon_kickoff(int fd, int status)
+{
+	if (fd <= 0) {
+		return -1;
+	}
+
+	int ret = write(fd, &status, sizeof(status));
+	if (ret != sizeof(status)) {
+		return -1;
+	}
+
+	int fd_null = open("/dev/null", O_RDWR);
+	if (fd_null < 0) {
+		fprintf(stderr, "open /dev/null failed, %s\n", strerror(errno));
+		return -1;
+	}
+
+	dup2(fd_null, STDIN_FILENO);
+	dup2(fd_null, STDOUT_FILENO);
+	dup2(fd_null, STDERR_FILENO);
+
+	if (fd_null > 2) {
+		close(fd_null);
+	}
+
+	close(fd);
+
+	return 0;
+}
+
+int run_daemon()
+{
+	pid_t pid = 0;
+	int fds[2] = {0};
+
+	if (pipe(fds) != 0) {
+		fprintf(stderr, "run daemon process failed, pipe failed, %s\n", strerror(errno));
+		return -1;
+	}
+
+	pid = fork();
+	if (pid < 0) {
+		fprintf(stderr, "run daemon process failed, fork failed, %s\n", strerror(errno));
+		close(fds[0]);
+		close(fds[1]);
+		return -1;
+	} else if (pid > 0) {
+		struct pollfd pfd;
+		int ret = 0;
+		int status = 0;
+
+		close(fds[1]);
+
+		pfd.fd = fds[0];
+		pfd.events = POLLIN;
+		pfd.revents = 0;
+
+		ret = poll(&pfd, 1, 1000);
+		if (ret <= 0) {
+			fprintf(stderr, "run daemon process failed, wait child timeout\n");
+			goto errout;
+		}
+
+		if (!(pfd.revents & POLLIN)) {
+			goto errout;
+		}
+
+		ret = read(fds[0], &status, sizeof(status));
+		if (ret != sizeof(status)) {
+			goto errout;
+		}
+
+		return status;
+	}
+
+	setsid();
+
+	pid = fork();
+	if (pid < 0) {
+		fprintf(stderr, "double fork failed, %s\n", strerror(errno));
+		_exit(1);
+	} else if (pid > 0) {
+		_exit(0);
+	}
+
+	umask(0);
+	if (chdir("/") != 0) {
+		goto errout;
+	}
+	close(fds[0]);
+	return fds[1];
+
+errout:
+	kill(pid, SIGKILL);
+	return -1;
+}
+
 #ifdef DEBUG
 struct _dns_read_packet_info {
 	int data_len;
@@ -1604,7 +1782,7 @@ static int _dns_debug_display(struct dns_packet *packet)
 				int ret = 0;
 
 				ret = dns_get_HTTPS_svcparm_start(rrs, &p, name, DNS_MAX_CNAME_LEN, &ttl, &priority, target,
-												DNS_MAX_CNAME_LEN);
+												  DNS_MAX_CNAME_LEN);
 				if (ret != 0) {
 					printf("get HTTPS svcparm failed\n");
 					break;

+ 8 - 0
src/util.h

@@ -105,6 +105,8 @@ int generate_cert_key(const char *key_path, const char *cert_path, const char *s
 
 int create_pid_file(const char *pid_file);
 
+int full_path(char *normalized_path, int normalized_path_len, const char *path);
+
 /* Parse a TLS packet for the Server Name Indication extension in the client
  * hello handshake, returning the first server name found (pointer to static
  * array)
@@ -138,6 +140,12 @@ uint64_t get_free_space(const char *path);
 
 void print_stack(void);
 
+void close_all_fd(int keepfd);
+
+int run_daemon(void);
+
+int daemon_kickoff(int fd, int status);
+
 int write_file(const char *filename, void *data, int data_len);
 
 int dns_packet_save(const char *dir, const char *type, const char *from, const void *packet, int packet_len);