Просмотр исходного кода

feature: domain-rule, conf-file support group, add group-begin and group-end

Nick Peng 1 год назад
Родитель
Сommit
707a6dbcae
9 измененных файлов с 536 добавлено и 33 удалено
  1. 5 2
      etc/smartdns/smartdns.conf
  2. 336 25
      src/dns_conf.c
  3. 5 2
      src/dns_conf.h
  4. 12 1
      src/dns_server.c
  5. 2 2
      src/lib/conf.c
  6. 38 1
      test/cases/test-client-rule.cc
  7. 97 0
      test/cases/test-group.cc
  8. 2 0
      test/include/utils.h
  9. 39 0
      test/utils.cc

+ 5 - 2
etc/smartdns/smartdns.conf

@@ -13,9 +13,10 @@
 #   user nobody
 #
 
-# Include another configuration options
-# conf-file [file]
+# Include another configuration options, if -group is specified, only include the rules to specified group.
+# conf-file [file] [-group group-name]
 # conf-file blacklist-ip.conf
+# conf-file whitelist-ip.conf -group office
 # conf-file *.conf
 
 # dns server bind ip and port, default dns server port is 53, support binding multi ip and port
@@ -46,6 +47,7 @@
 #   -force-aaaa-soa: force AAAA query return SOA.
 #   -force-https-soa: force HTTPS query return SOA.
 #   -no-serve-expired: no serve expired.
+#   -no-rules: skip all rules.
 #   -ipset ipsetname: use ipset rule.
 #   -nftset nftsetname: use nftset rule.
 # example: 
@@ -334,6 +336,7 @@ log-level info
 #   [-p] -ipset [ipset|-]: same as ipset option
 #   [-t] -nftset [nftset|-]: same as nftset option
 #   [-d] -dualstack-ip-selection [yes|no]: same as dualstack-ip-selection option
+#   [-g|-group group-name]: set domain-rules to group.
 #   -no-serve-expired: ignore expired domain
 #   -delete: delete domain rule
 #   -no-ip-alias: ignore ip alias

+ 336 - 25
src/dns_conf.c

@@ -143,6 +143,15 @@ int dns_conf_audit_console;
 int dns_conf_audit_syslog;
 
 /* address rules */
+struct dns_conf_group_info {
+	struct list_head list;
+	const char *group_name;
+	struct dns_conf_doamin_rule_group *domain_rule;
+};
+struct dns_conf_group_info *dns_conf_current_group_info;
+struct dns_conf_group_info *dns_conf_default_group_info;
+static LIST_HEAD(dns_conf_group_info_list);
+
 struct dns_conf_domain_rule dns_conf_domain_rule;
 struct dns_conf_address_rule dns_conf_address_rule;
 struct dns_conf_client_rule dns_conf_client_rule;
@@ -199,6 +208,8 @@ static int _conf_domain_rule_address(char *domain, const char *domain_address);
 static struct dns_domain_rule *_config_domain_rule_get(const char *domain);
 typedef int (*set_rule_add_func)(const char *value, void *priv);
 static int _config_ip_rule_set_each(const char *ip_set, set_rule_add_func callback, void *priv);
+static struct dns_conf_doamin_rule_group *_config_domain_rule_group_get(const char *group_name);
+static struct dns_conf_doamin_rule_group *_config_domain_rule_group_new(const char *group_name);
 
 static void *_new_dns_rule_ext(enum domain_rule domain_rule, int ext_size)
 {
@@ -444,6 +455,136 @@ struct dns_proxy_names *dns_server_get_proxy_nams(const char *proxyname)
 	return NULL;
 }
 
+static struct dns_conf_group_info *_config_current_group(void)
+{
+	return dns_conf_current_group_info;
+}
+
+static void _config_current_group_pop(void)
+{
+	struct dns_conf_group_info *group_info = NULL;
+
+	group_info = list_last_entry(&dns_conf_group_info_list, struct dns_conf_group_info, list);
+	if (group_info == NULL) {
+		return;
+	}
+
+	if (group_info == dns_conf_default_group_info) {
+		dns_conf_current_group_info = dns_conf_default_group_info;
+		return;
+	}
+
+	list_del(&group_info->list);
+	free(group_info);
+
+	group_info = list_last_entry(&dns_conf_group_info_list, struct dns_conf_group_info, list);
+	if (group_info == NULL) {
+		dns_conf_current_group_info = NULL;
+		return;
+	}
+
+	dns_conf_current_group_info = group_info;
+}
+
+static int _config_current_group_push(const char *group_name)
+{
+	struct dns_conf_group_info *group_info = NULL;
+	struct dns_conf_doamin_rule_group *domain_rule = NULL;
+
+	group_info = malloc(sizeof(*group_info));
+	if (group_info == NULL) {
+		goto errout;
+	}
+
+	if (dns_conf_default_group_info != NULL) {
+		group_name = _dns_conf_get_group_name(group_name);
+		if (group_name == NULL) {
+			goto errout;
+		}
+	}
+
+	memset(group_info, 0, sizeof(*group_info));
+	INIT_LIST_HEAD(&group_info->list);
+	list_add_tail(&group_info->list, &dns_conf_group_info_list);
+
+	domain_rule = _config_domain_rule_group_get(group_name);
+	if (domain_rule == NULL) {
+		domain_rule = _config_domain_rule_group_new(group_name);
+		if (domain_rule == NULL) {
+			goto errout;
+		}
+	}
+
+	group_info->group_name = group_name;
+	group_info->domain_rule = domain_rule;
+
+	dns_conf_current_group_info = group_info;
+	if (dns_conf_default_group_info == NULL) {
+		dns_conf_default_group_info = group_info;
+	}
+
+	return 0;
+
+errout:
+	if (group_info) {
+		free(group_info);
+	}
+	return -1;
+}
+
+static int _config_group_begin(void *data, int argc, char *argv[])
+{
+	const char *group_name = NULL;
+	if (argc < 2) {
+		return -1;
+	}
+
+	group_name = argv[1];
+	if (group_name[0] == '\0') {
+		group_name = NULL;
+	}
+
+	if (_config_current_group_push(group_name) != 0) {
+		return -1;
+	}
+
+	return 0;
+}
+
+static int _config_current_group_push_default(void)
+{
+	return _config_current_group_push(NULL);
+}
+
+static int _config_current_group_pop_to(struct dns_conf_group_info *group_info)
+{
+	while (dns_conf_current_group_info != NULL && dns_conf_current_group_info != group_info) {
+		_config_current_group_pop();
+	}
+
+	return 0;
+}
+
+static int _config_current_group_pop_all(void)
+{
+	while (dns_conf_current_group_info != NULL && dns_conf_current_group_info != dns_conf_default_group_info) {
+		_config_current_group_pop();
+	}
+
+	list_del(&dns_conf_default_group_info->list);
+	free(dns_conf_default_group_info);
+	dns_conf_default_group_info = NULL;
+	dns_conf_current_group_info = NULL;
+
+	return 0;
+}
+
+static int _config_group_end(void *data, int argc, char *argv[])
+{
+	_config_current_group_pop();
+	return 0;
+}
+
 /* create and get dns server group */
 static struct dns_proxy_names *_dns_conf_get_proxy(const char *proxy_name)
 {
@@ -559,6 +700,7 @@ static int _config_server(int argc, char *argv[], dns_server_type_t type, int de
 	char host_ip[DNS_MAX_IPLEN] = {0};
 	int no_tls_host_name = 0;
 	int no_tls_host_verify = 0;
+	const char *group_name = NULL;
 
 	int ttl = 0;
 	/* clang-format off */
@@ -636,6 +778,11 @@ static int _config_server(int argc, char *argv[], dns_server_type_t type, int de
 		port = default_port;
 	}
 
+	/* get current group */
+	if (_config_current_group()) {
+		group_name = _config_current_group()->group_name;
+	}
+
 	/* process extra options */
 	optind = 1;
 	while (1) {
@@ -654,10 +801,7 @@ static int _config_server(int argc, char *argv[], dns_server_type_t type, int de
 			break;
 		}
 		case 'g': {
-			if (_dns_conf_get_group_set(optarg, server) != 0) {
-				tlog(TLOG_ERROR, "add group failed.");
-				goto errout;
-			}
+			group_name = optarg;
 			break;
 		}
 		case 'p': {
@@ -786,6 +930,13 @@ static int _config_server(int argc, char *argv[], dns_server_type_t type, int de
 		}
 	}
 
+	if (group_name) {
+		if (_dns_conf_get_group_set(group_name, server) != 0) {
+			tlog(TLOG_ERROR, "add group failed.");
+			goto errout;
+		}
+	}
+
 	dns_conf_server_num++;
 	tlog(TLOG_DEBUG, "add server %s, flag: %X, ttl: %d", ip, result_flag, ttl);
 
@@ -852,6 +1003,67 @@ static int _config_domain_iter_free(void *data, const unsigned char *key, uint32
 	return _config_domain_rule_free(domain_rule);
 }
 
+static struct dns_conf_doamin_rule_group *_config_domain_rule_group_get(const char *group_name)
+{
+	uint32_t key = 0;
+	struct dns_conf_doamin_rule_group *domain_rule_group = NULL;
+	if (group_name == NULL) {
+		group_name = "";
+	}
+
+	key = hash_string(group_name);
+	hash_for_each_possible(dns_conf_domain_rule.group, domain_rule_group, node, key)
+	{
+		if (strncmp(domain_rule_group->group_name, group_name, DNS_GROUP_NAME_LEN) == 0) {
+			return domain_rule_group;
+		}
+	}
+
+	return NULL;
+}
+
+struct dns_conf_doamin_rule_group *dns_server_get_domain_rule_group(const char *group_name, int no_fallback_default)
+{
+	struct dns_conf_doamin_rule_group *domain_rule_group = _config_domain_rule_group_get(group_name);
+	if (domain_rule_group) {
+		return domain_rule_group;
+	}
+
+	if (no_fallback_default) {
+		return NULL;
+	}
+
+	return dns_conf_domain_rule.default_rule;
+}
+
+static struct dns_conf_doamin_rule_group *_config_domain_rule_group_new(const char *group_name)
+{
+	struct dns_conf_doamin_rule_group *domain_rule_group = NULL;
+	uint32_t key = 0;
+
+	domain_rule_group = malloc(sizeof(*domain_rule_group));
+	if (domain_rule_group == NULL) {
+		return NULL;
+	}
+
+	memset(domain_rule_group, 0, sizeof(*domain_rule_group));
+	domain_rule_group->group_name = group_name;
+	INIT_HLIST_NODE(&domain_rule_group->node);
+	art_tree_init(&domain_rule_group->tree);
+	key = hash_string(group_name);
+	hash_add(dns_conf_domain_rule.group, &domain_rule_group->node, key);
+
+	return domain_rule_group;
+}
+
+static void _config_domain_rule_remove(struct dns_conf_doamin_rule_group *domain_rule_group)
+{
+	hlist_del_init(&domain_rule_group->node);
+	art_iter(&domain_rule_group->tree, _config_domain_iter_free, NULL);
+	art_tree_destroy(&domain_rule_group->tree);
+	free(domain_rule_group);
+}
+
 static void _config_domain_destroy(void)
 {
 	struct dns_conf_doamin_rule_group *group;
@@ -860,14 +1072,10 @@ static void _config_domain_destroy(void)
 
 	hash_for_each_safe(dns_conf_domain_rule.group, i, tmp, group, node)
 	{
-		hlist_del_init(&group->node);
-		art_iter(&group->rule, _config_domain_iter_free, NULL);
-		art_tree_destroy(&group->rule);
-		free(group);
+		_config_domain_rule_remove(group);
 	}
 
-	art_iter(&dns_conf_domain_rule.default_rule, _config_domain_iter_free, NULL);
-	art_tree_destroy(&dns_conf_domain_rule.default_rule);
+	dns_conf_domain_rule.default_rule = NULL;
 }
 
 static int _config_set_rule_each_from_list(const char *file, set_rule_add_func callback, void *priv)
@@ -1018,7 +1226,7 @@ static __attribute__((unused)) struct dns_domain_rule *_config_domain_rule_get(c
 		return NULL;
 	}
 
-	return art_search(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len);
+	return art_search(&_config_current_group()->domain_rule->tree, (unsigned char *)domain_key, len);
 }
 
 static int _config_domain_rule_add(const char *domain, enum domain_rule type, void *rule)
@@ -1050,7 +1258,7 @@ static int _config_domain_rule_add(const char *domain, enum domain_rule type, vo
 	}
 
 	/* Get existing or create domain rule */
-	domain_rule = art_search(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len);
+	domain_rule = art_search(&_config_current_group()->domain_rule->tree, (unsigned char *)domain_key, len);
 	if (domain_rule == NULL) {
 		add_domain_rule = malloc(sizeof(*add_domain_rule));
 		if (add_domain_rule == NULL) {
@@ -1074,7 +1282,7 @@ static int _config_domain_rule_add(const char *domain, enum domain_rule type, vo
 	/* update domain rule */
 	if (add_domain_rule) {
 		old_domain_rule =
-			art_insert(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len, add_domain_rule);
+			art_insert(&_config_current_group()->domain_rule->tree, (unsigned char *)domain_key, len, add_domain_rule);
 		if (old_domain_rule) {
 			_config_domain_rule_free(old_domain_rule);
 		}
@@ -1112,7 +1320,7 @@ static int _config_domain_rule_delete(const char *domain)
 	}
 
 	/* delete existing rules */
-	void *rule = art_delete(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len);
+	void *rule = art_delete(&_config_current_group()->domain_rule->tree, (unsigned char *)domain_key, len);
 	if (rule) {
 		_config_domain_rule_free(rule);
 	}
@@ -1155,7 +1363,7 @@ static int _config_domain_rule_flag_set(const char *domain, unsigned int flag, u
 	}
 
 	/* Get existing or create domain rule */
-	domain_rule = art_search(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len);
+	domain_rule = art_search(&_config_current_group()->domain_rule->tree, (unsigned char *)domain_key, len);
 	if (domain_rule == NULL) {
 		add_domain_rule = malloc(sizeof(*add_domain_rule));
 		if (add_domain_rule == NULL) {
@@ -1186,7 +1394,7 @@ static int _config_domain_rule_flag_set(const char *domain, unsigned int flag, u
 	/* update domain rule */
 	if (add_domain_rule) {
 		old_domain_rule =
-			art_insert(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len, add_domain_rule);
+			art_insert(&_config_current_group()->domain_rule->tree, (unsigned char *)domain_key, len, add_domain_rule);
 		if (old_domain_rule) {
 			_config_domain_rule_free(old_domain_rule);
 		}
@@ -2362,6 +2570,7 @@ static int _config_bind_ip(int argc, char *argv[], DNS_BIND_TYPE type)
 		{"no-dualstack-selection", no_argument, NULL, 'D'},
 		{"no-ip-alias", no_argument, NULL, 'a'},
 		{"force-aaaa-soa", no_argument, NULL, 'F'},
+		{"no-rules", no_argument, NULL, 252},
 		{"no-serve-expired", no_argument, NULL, 253},
 		{"force-https-soa", no_argument, NULL, 254},
 		{"ipset", required_argument, NULL, 255},
@@ -2398,6 +2607,10 @@ static int _config_bind_ip(int argc, char *argv[], DNS_BIND_TYPE type)
 	bind_ip->type = type;
 	bind_ip->flags = 0;
 	safe_strncpy(bind_ip->ip, ip, DNS_MAX_IPLEN);
+	/* get current group */
+	if (_config_current_group()) {
+		group = _config_current_group()->group_name;
+	}
 
 	/* process extra options */
 	optind = 1;
@@ -2453,6 +2666,10 @@ static int _config_bind_ip(int argc, char *argv[], DNS_BIND_TYPE type)
 			server_flag |= BIND_FLAG_FORCE_AAAA_SOA;
 			break;
 		}
+		case 252: {
+			server_flag |= BIND_FLAG_NO_RULES;
+			break;
+		}
 		case 253: {
 			server_flag |= BIND_FLAG_NO_SERVE_EXPIRED;
 			break;
@@ -3982,6 +4199,8 @@ static int _conf_domain_rules(void *data, int argc, char *argv[])
 	int rr_ttl = 0;
 	int rr_ttl_min = 0;
 	int rr_ttl_max = 0;
+	const char *group = NULL;
+	char group_name[DNS_MAX_CONF_CNAME_LEN];
 
 	/* clang-format off */
 	static struct option long_options[] = {
@@ -3991,7 +4210,7 @@ static int _conf_domain_rules(void *data, int argc, char *argv[])
 		{"ipset", required_argument, NULL, 'p'},
 		{"nftset", required_argument, NULL, 't'},
 		{"nameserver", required_argument, NULL, 'n'},
-		{"group", required_argument, NULL, 'n'},
+		{"group", required_argument, NULL, 'g'},
 		{"dualstack-ip-selection", required_argument, NULL, 'd'},
 		{"cname", required_argument, NULL, 'A'},
 		{"rr-ttl", required_argument, NULL, 251},
@@ -4024,10 +4243,22 @@ static int _conf_domain_rules(void *data, int argc, char *argv[])
 		}
 	}
 
+	for (int i = 2; i < argc - 1; i++) {
+		if (strncmp(argv[i], "-g", sizeof("-g")) == 0 || strncmp(argv[i], "--group", sizeof("--group")) == 0) {
+			safe_strncpy(group_name, argv[i + 1], DNS_MAX_CONF_CNAME_LEN);
+			group = group_name;
+			break;
+		}
+	}
+
+	if (group != NULL) {
+		_config_current_group_push(group);
+	}
+
 	/* process extra options */
 	optind = 1;
 	while (1) {
-		opt = getopt_long_only(argc, argv, "c:a:p:t:n:d:A:r:", long_options, NULL);
+		opt = getopt_long_only(argc, argv, "c:a:p:t:n:d:A:r:g:", long_options, NULL);
 		if (opt == -1) {
 			break;
 		}
@@ -4130,6 +4361,9 @@ static int _conf_domain_rules(void *data, int argc, char *argv[])
 
 			break;
 		}
+		case 'g': {
+			break;
+		}
 		case 251: {
 			rr_ttl = atoi(optarg);
 			break;
@@ -4188,8 +4422,15 @@ static int _conf_domain_rules(void *data, int argc, char *argv[])
 		}
 	}
 
+	if (group != NULL) {
+		_config_current_group_pop();
+	}
+
 	return 0;
 errout:
+	if (group != NULL) {
+		_config_current_group_pop();
+	}
 	return -1;
 }
 
@@ -4693,6 +4934,7 @@ static int _config_client_rules(void *data, int argc, char *argv[])
 	int opt = 0;
 	const char *client = argv[1];
 	unsigned int server_flag = 0;
+	const char *group = NULL;
 
 	/* clang-format off */
 	static struct option long_options[] = {
@@ -4707,6 +4949,7 @@ static int _config_client_rules(void *data, int argc, char *argv[])
 		{"no-dualstack-selection", no_argument, NULL, 'D'},
 		{"no-ip-alias", no_argument, NULL, 'a'},
 		{"force-aaaa-soa", no_argument, NULL, 'F'},
+		{"no-rules", no_argument, NULL, 252},
 		{"no-serve-expired", no_argument, NULL, 253},
 		{"force-https-soa", no_argument, NULL, 254},
 		{NULL, no_argument, NULL, 0}
@@ -4718,6 +4961,11 @@ static int _config_client_rules(void *data, int argc, char *argv[])
 		goto errout;
 	}
 
+	/* get current group */
+	if (_config_current_group()) {
+		group = _config_current_group()->group_name;
+	}
+
 	/* process extra options */
 	optind = 1;
 	while (1) {
@@ -4728,11 +4976,7 @@ static int _config_client_rules(void *data, int argc, char *argv[])
 
 		switch (opt) {
 		case 'g': {
-			const char *group = optarg;
-			if (_config_client_rule_group_add(client, group) != 0) {
-				tlog(TLOG_ERROR, "add group rule failed.");
-				goto errout;
-			}
+			group = optarg;
 			break;
 		}
 		case 'A': {
@@ -4775,6 +5019,10 @@ static int _config_client_rules(void *data, int argc, char *argv[])
 			server_flag |= BIND_FLAG_FORCE_AAAA_SOA;
 			break;
 		}
+		case 252: {
+			server_flag |= BIND_FLAG_NO_RULES;
+			break;
+		}
 		case 253: {
 			server_flag |= BIND_FLAG_NO_SERVE_EXPIRED;
 			break;
@@ -4783,6 +5031,17 @@ static int _config_client_rules(void *data, int argc, char *argv[])
 			server_flag |= BIND_FLAG_FORCE_HTTPS_SOA;
 			break;
 		}
+		default:
+			tlog(TLOG_WARN, "unknown client-rules option: %s at '%s:%d'.", argv[optind - 1], conf_get_conf_file(),
+				 conf_get_current_lineno());
+			break;
+		}
+	}
+
+	if (group != NULL) {
+		if (_config_client_rule_group_add(client, group) != 0) {
+			tlog(TLOG_ERROR, "add group rule failed.");
+			goto errout;
 		}
 	}
 
@@ -4995,6 +5254,8 @@ static struct config_item _config_item[] = {
 	CONF_CUSTOM("ddns-domain", _conf_ddns_domain, NULL),
 	CONF_CUSTOM("dnsmasq-lease-file", _conf_dhcp_lease_dnsmasq_file, NULL),
 	CONF_CUSTOM("hosts-file", _conf_hosts_file, NULL),
+	CONF_CUSTOM("group-begin", _config_group_begin, NULL),
+	CONF_CUSTOM("group-end", _config_group_end, NULL),
 	CONF_CUSTOM("client-rules", _config_client_rules, NULL),
 	CONF_STRING("ca-file", (char *)&dns_conf_ca_file, DNS_MAX_PATH),
 	CONF_STRING("ca-path", (char *)&dns_conf_ca_path, DNS_MAX_PATH),
@@ -5098,6 +5359,11 @@ static int _config_additional_file_callback(const char *file, void *priv)
 int config_additional_file(void *data, int argc, char *argv[])
 {
 	const char *conf_pattern = NULL;
+	int opt = 0;
+	const char *group_name = NULL;
+	int ret = 0;
+	struct dns_conf_group_info *last_group_info;
+
 	if (argc < 1) {
 		return -1;
 	}
@@ -5107,7 +5373,44 @@ int config_additional_file(void *data, int argc, char *argv[])
 		return -1;
 	}
 
-	return _config_foreach_file(conf_pattern, _config_additional_file_callback, NULL);
+	/* clang-format off */
+	static struct option long_options[] = {
+		{"group", required_argument, NULL, 'g'},
+		{NULL, no_argument, NULL, 0}
+	};
+	/* clang-format on */
+
+	/* process extra options */
+	optind = 1;
+	while (1) {
+		opt = getopt_long_only(argc, argv, "g:", long_options, NULL);
+		if (opt == -1) {
+			break;
+		}
+
+		switch (opt) {
+		case 'g': {
+			group_name = optarg;
+			break;
+		}
+		}
+	}
+
+	last_group_info = _config_current_group();
+	if (group_name != NULL) {
+		ret = _config_current_group_push(group_name);
+		if (ret != 0) {
+			tlog(TLOG_ERROR, "begin group '%s' failed.", group_name);
+			return -1;
+		}
+	}
+
+	ret = _config_foreach_file(conf_pattern, _config_additional_file_callback, NULL);
+	if (group_name != NULL) {
+		_config_current_group_pop_to(last_group_info);
+	}
+
+	return ret;
 }
 
 const char *dns_conf_get_cache_dir(void)
@@ -5129,8 +5432,12 @@ static int _dns_server_load_conf_init(void)
 		return -1;
 	}
 
-	art_tree_init(&dns_conf_domain_rule.default_rule);
 	hash_init(dns_conf_domain_rule.group);
+	dns_conf_domain_rule.default_rule = _config_domain_rule_group_new("");
+	if (dns_conf_domain_rule.default_rule == NULL) {
+		tlog(TLOG_WARN, "init default domain rule failed.");
+		return -1;
+	}
 
 	hash_init(dns_ipset_table.ipset);
 	hash_init(dns_nftset_table.nftset);
@@ -5147,6 +5454,8 @@ static int _dns_server_load_conf_init(void)
 	hash_init(dns_ip_set_name_table.names);
 	hash_init(dns_conf_srv_record_table.srv);
 
+	_config_current_group_push_default();
+
 	return 0;
 }
 
@@ -5362,6 +5671,8 @@ static int _dns_conf_load_post(void)
 
 	_config_file_hash_table_destroy();
 
+	_config_current_group_pop_all();
+
 	if (dns_conf_log_syslog == 0 && dns_conf_audit_syslog == 0) {
 		closelog();
 	}

+ 5 - 2
src/dns_conf.h

@@ -157,6 +157,7 @@ typedef enum {
 #define BIND_FLAG_NO_PREFETCH (1 << 12)
 #define BIND_FLAG_FORCE_HTTPS_SOA (1 << 13)
 #define BIND_FLAG_NO_SERVE_EXPIRED (1 << 14)
+#define BIND_FLAG_NO_RULES (1 << 15)
 
 enum response_mode_type {
 	DNS_RESPONSE_MODE_FIRST_PING_IP = 0,
@@ -280,12 +281,12 @@ struct dns_response_mode_rule {
 
 struct dns_conf_doamin_rule_group {
 	struct hlist_node node;
-	art_tree rule;
+	art_tree tree;
 	const char *group_name;
 };
 
 struct dns_conf_domain_rule {
-	art_tree default_rule;
+	struct dns_conf_doamin_rule_group *default_rule;
 	DECLARE_HASHTABLE(group, 8);
 };
 
@@ -661,6 +662,8 @@ struct dns_proxy_names *dns_server_get_proxy_nams(const char *proxyname);
 
 struct dns_srv_records *dns_server_get_srv_record(const char *domain);
 
+struct dns_conf_doamin_rule_group *dns_server_get_domain_rule_group(const char *group_name, int no_fallback_default);
+
 extern int config_additional_file(void *data, int argc, char *argv[]);
 
 const char *dns_conf_get_cache_dir(void);

+ 12 - 1
src/dns_server.c

@@ -4631,6 +4631,8 @@ static void _dns_server_get_domain_rule_by_domain(struct dns_request *request, c
 	unsigned char matched_key[DNS_MAX_CNAME_LEN];
 	struct rule_walk_args walk_args;
 	int i = 0;
+	struct dns_conf_doamin_rule_group *domain_rule_group = NULL;
+	int no_fallback_default_rule = 0;
 
 	if (request->skip_domain_rule != 0) {
 		return;
@@ -4650,8 +4652,17 @@ static void _dns_server_get_domain_rule_by_domain(struct dns_request *request, c
 	domain_len++;
 	domain_key[domain_len] = 0;
 
+	if (_dns_server_has_bind_flag(request, BIND_FLAG_NO_RULES) == 0) {
+		no_fallback_default_rule = 1;
+	}
+
+	domain_rule_group = dns_server_get_domain_rule_group(request->dns_group_name, no_fallback_default_rule);
+	if (domain_rule_group == NULL) {
+		return;
+	}
+
 	/* find domain rule */
-	art_substring_walk(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, domain_len, _dns_server_get_rules,
+	art_substring_walk(&domain_rule_group->tree, (unsigned char *)domain_key, domain_len, _dns_server_get_rules,
 					   &walk_args);
 	if (likely(dns_conf_log_level > TLOG_DEBUG)) {
 		return;

+ 2 - 2
src/lib/conf.c

@@ -17,6 +17,7 @@
  */
 
 #include "conf.h"
+#include <errno.h>
 #include <getopt.h>
 #include <libgen.h>
 #include <linux/limits.h>
@@ -24,7 +25,6 @@
 #include <stdlib.h>
 #include <string.h>
 #include <unistd.h>
-#include <errno.h>
 
 static const char *current_conf_file = NULL;
 static int current_conf_lineno = 0;
@@ -429,7 +429,7 @@ static int load_conf_file(const char *file, struct config_item *items, conf_erro
 		}
 
 		/* if field format is not key = value, error */
-		if (filed_num != 2) {
+		if (filed_num != 2 && filed_num != 1) {
 			handler(file, line_no, CONF_RET_BADCONF);
 			goto errout;
 		}

+ 38 - 1
test/cases/test-client-rule.cc

@@ -29,7 +29,7 @@ class ClientRule : public ::testing::Test
 	virtual void TearDown() {}
 };
 
-TEST_F(ClientRule, bogus_nxdomain)
+TEST_F(ClientRule, rule)
 {
 	smartdns::MockServer server_upstream;
 	smartdns::MockServer server_upstream2;
@@ -63,3 +63,40 @@ client-rules 127.0.0.1 -g client
 	EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com");
 	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
 }
+
+TEST_F(ClientRule, group_begin_group_end)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::MockServer server_upstream2;
+	smartdns::Server server;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype != DNS_T_A) {
+			return smartdns::SERVER_REQUEST_SOA;
+		}
+
+		smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+	server_upstream2.Start("udp://0.0.0.0:62053",
+						   [](struct smartdns::ServerRequestContext *request) { return smartdns::SERVER_REQUEST_SOA; });
+
+	/* this ip will be discard, but is reachable */
+	server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
+
+	server.Start(R"""(bind [::]:60053
+server udp://127.0.0.1:62053
+group-begin client
+server udp://127.0.0.1:61053 -e 
+client-rules 127.0.0.1 
+group-end
+)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("b.com", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+}

+ 97 - 0
test/cases/test-group.cc

@@ -0,0 +1,97 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "client.h"
+#include "dns.h"
+#include "include/utils.h"
+#include "server.h"
+#include "gtest/gtest.h"
+
+class Group : public ::testing::Test
+{
+  protected:
+	virtual void SetUp() {}
+	virtual void TearDown() {}
+};
+
+TEST_F(Group, conf_file)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::MockServer server_upstream2;
+	smartdns::Server server;
+	std::string file = "/tmp/smartdns_conf_file" + smartdns::GenerateRandomString(5) + ".conf";
+	std::ofstream ofs(file);
+	ASSERT_TRUE(ofs.is_open());
+	Defer
+	{
+		ofs.close();
+		unlink(file.c_str());
+	};
+
+	ofs << R"""(
+server udp://127.0.0.1:61053 -e
+client-rules 127.0.0.1
+address /a.com/1.1.1.1
+domain-rules /b.com/ -address 4.5.6.7
+)""";
+	ofs.flush();
+
+	server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype != DNS_T_A) {
+			return smartdns::SERVER_REQUEST_SOA;
+		}
+
+		smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+	server_upstream2.Start("udp://0.0.0.0:62053",
+						   [](struct smartdns::ServerRequestContext *request) { return smartdns::SERVER_REQUEST_SOA; });
+
+	/* this ip will be discard, but is reachable */
+	server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
+
+	server.Start(R"""(bind [::]:60053
+conf-file /tmp/smartdns_conf_file*.conf -g client
+server udp://127.0.0.1:61053
+)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("a.com", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.1.1.1");
+
+	ASSERT_TRUE(client.Query("b.com", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "4.5.6.7");
+
+	auto ipaddr = smartdns::GetAvailableIPAddresses();
+	if (ipaddr.size() > 0) {
+		ASSERT_TRUE(client.Query("b.com", 60053, ipaddr[0]));
+		std::cout << client.GetResult() << std::endl;
+		ASSERT_EQ(client.GetAnswerNum(), 1);
+		EXPECT_EQ(client.GetStatus(), "NOERROR");
+		EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com");
+		EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+	}
+}

+ 2 - 0
test/include/utils.h

@@ -107,5 +107,7 @@ std::string GenerateRandomString(int len);
 
 int ParserArg(const std::string &cmd, std::vector<std::string> &args);
 
+std::vector<std::string> GetAvailableIPAddresses();
+
 } // namespace smartdns
 #endif // _SMARTDNS_TEST_UTILS_

+ 39 - 0
test/utils.cc

@@ -1,8 +1,13 @@
 #include "include/utils.h"
+#include <arpa/inet.h>
+#include <ifaddrs.h>
+#include <netinet/in.h>
 #include <signal.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
+#include <sys/socket.h>
+#include <sys/types.h>
 #include <sys/wait.h>
 #include <unistd.h>
 
@@ -266,4 +271,38 @@ int ParserArg(const std::string &cmd, std::vector<std::string> &args)
 	return 0;
 }
 
+std::vector<std::string> GetAvailableIPAddresses()
+{
+	std::vector<std::string> ipAddresses;
+
+	struct ifaddrs *ifAddrStruct = nullptr;
+	struct ifaddrs *ifa = nullptr;
+	void *tmpAddrPtr = nullptr;
+
+	getifaddrs(&ifAddrStruct);
+
+	for (ifa = ifAddrStruct; ifa != nullptr; ifa = ifa->ifa_next) {
+		if (!ifa->ifa_addr) {
+			continue;
+		}
+
+		if (ifa->ifa_addr->sa_family == AF_INET) { // IPv4 address
+			tmpAddrPtr = &((struct sockaddr_in *)ifa->ifa_addr)->sin_addr;
+			char addressBuffer[INET_ADDRSTRLEN];
+			inet_ntop(AF_INET, tmpAddrPtr, addressBuffer, INET_ADDRSTRLEN);
+			std::string ipAddress(addressBuffer);
+
+			if (!ipAddress.empty() && ipAddress.substr(0, 4) != "127.") {
+				ipAddresses.push_back(ipAddress);
+			}
+		}
+	}
+
+	if (ifAddrStruct != nullptr) {
+		freeifaddrs(ifAddrStruct);
+	}
+
+	return ipAddresses;
+}
+
 } // namespace smartdns