Browse Source

feature: add client-rules option.

Nick Peng 1 year ago
parent
commit
fbc3d0e58f
5 changed files with 587 additions and 15 deletions
  1. 4 0
      etc/smartdns/smartdns.conf
  2. 406 12
      src/dns_conf.c
  3. 43 2
      src/dns_conf.h
  4. 69 1
      src/dns_server.c
  5. 65 0
      test/cases/test-client-rule.cc

+ 4 - 0
etc/smartdns/smartdns.conf

@@ -374,3 +374,7 @@ log-level info
 # bogus-nxdomain ip-set:ip-list
 # ip-alias ip-set:ip-list 1.2.3.4
 # ip-alias ip-set:ip-list ip-set:ip-map-list
+
+# set client rules
+# client-rules ip-cidr [-group [group]] [-no-rule-addr] [-no-rule-nameserver] [-no-rule-ipset] [-no-speed-check] [-no-cache] [-no-rule-soa] [-no-dualstack-selection]
+# client-rules option is same as bind option, please see bind option for detail.

+ 406 - 12
src/dns_conf.c

@@ -143,8 +143,9 @@ int dns_conf_audit_console;
 int dns_conf_audit_syslog;
 
 /* address rules */
-art_tree dns_conf_domain_rule;
+struct dns_conf_domain_rule dns_conf_domain_rule;
 struct dns_conf_address_rule dns_conf_address_rule;
+struct dns_conf_client_rule dns_conf_client_rule;
 
 /* dual-stack selection */
 int dns_conf_dualstack_ip_selection = 1;
@@ -196,6 +197,8 @@ static int _conf_client_subnet(char *subnet, struct dns_edns_client_subnet *ipv4
 							   struct dns_edns_client_subnet *ipv6_ecs);
 static int _conf_domain_rule_address(char *domain, const char *domain_address);
 static struct dns_domain_rule *_config_domain_rule_get(const char *domain);
+typedef int (*set_rule_add_func)(const char *value, void *priv);
+static int _config_ip_rule_set_each(const char *ip_set, set_rule_add_func callback, void *priv);
 
 static void *_new_dns_rule_ext(enum domain_rule domain_rule, int ext_size)
 {
@@ -851,11 +854,22 @@ static int _config_domain_iter_free(void *data, const unsigned char *key, uint32
 
 static void _config_domain_destroy(void)
 {
-	art_iter(&dns_conf_domain_rule, _config_domain_iter_free, NULL);
-	art_tree_destroy(&dns_conf_domain_rule);
+	struct dns_conf_doamin_rule_group *group;
+	struct hlist_node *tmp = NULL;
+	unsigned long i = 0;
+
+	hash_for_each_safe(dns_conf_domain_rule.group, i, tmp, group, node)
+	{
+		hlist_del_init(&group->node);
+		art_iter(&group->rule, _config_domain_iter_free, NULL);
+		art_tree_destroy(&group->rule);
+		free(group);
+	}
+
+	art_iter(&dns_conf_domain_rule.default_rule, _config_domain_iter_free, NULL);
+	art_tree_destroy(&dns_conf_domain_rule.default_rule);
 }
 
-typedef int (*set_rule_add_func)(const char *value, void *priv);
 static int _config_set_rule_each_from_list(const char *file, set_rule_add_func callback, void *priv)
 {
 	FILE *fp = NULL;
@@ -1004,7 +1018,7 @@ static __attribute__((unused)) struct dns_domain_rule *_config_domain_rule_get(c
 		return NULL;
 	}
 
-	return art_search(&dns_conf_domain_rule, (unsigned char *)domain_key, len);
+	return art_search(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len);
 }
 
 static int _config_domain_rule_add(const char *domain, enum domain_rule type, void *rule)
@@ -1036,7 +1050,7 @@ static int _config_domain_rule_add(const char *domain, enum domain_rule type, vo
 	}
 
 	/* Get existing or create domain rule */
-	domain_rule = art_search(&dns_conf_domain_rule, (unsigned char *)domain_key, len);
+	domain_rule = art_search(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len);
 	if (domain_rule == NULL) {
 		add_domain_rule = malloc(sizeof(*add_domain_rule));
 		if (add_domain_rule == NULL) {
@@ -1059,7 +1073,8 @@ static int _config_domain_rule_add(const char *domain, enum domain_rule type, vo
 
 	/* update domain rule */
 	if (add_domain_rule) {
-		old_domain_rule = art_insert(&dns_conf_domain_rule, (unsigned char *)domain_key, len, add_domain_rule);
+		old_domain_rule =
+			art_insert(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len, add_domain_rule);
 		if (old_domain_rule) {
 			_config_domain_rule_free(old_domain_rule);
 		}
@@ -1097,7 +1112,7 @@ static int _config_domain_rule_delete(const char *domain)
 	}
 
 	/* delete existing rules */
-	void *rule = art_delete(&dns_conf_domain_rule, (unsigned char *)domain_key, len);
+	void *rule = art_delete(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len);
 	if (rule) {
 		_config_domain_rule_free(rule);
 	}
@@ -1140,7 +1155,7 @@ static int _config_domain_rule_flag_set(const char *domain, unsigned int flag, u
 	}
 
 	/* Get existing or create domain rule */
-	domain_rule = art_search(&dns_conf_domain_rule, (unsigned char *)domain_key, len);
+	domain_rule = art_search(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len);
 	if (domain_rule == NULL) {
 		add_domain_rule = malloc(sizeof(*add_domain_rule));
 		if (add_domain_rule == NULL) {
@@ -1170,7 +1185,8 @@ static int _config_domain_rule_flag_set(const char *domain, unsigned int flag, u
 
 	/* update domain rule */
 	if (add_domain_rule) {
-		old_domain_rule = art_insert(&dns_conf_domain_rule, (unsigned char *)domain_key, len, add_domain_rule);
+		old_domain_rule =
+			art_insert(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len, add_domain_rule);
 		if (old_domain_rule) {
 			_config_domain_rule_free(old_domain_rule);
 		}
@@ -2805,6 +2821,224 @@ static void _dns_ip_rule_put(struct dns_ip_rule *rule)
 	}
 }
 
+static radix_node_t *_create_client_rules_node(const char *addr)
+{
+	radix_node_t *node = NULL;
+	void *p = NULL;
+	prefix_t prefix;
+	const char *errmsg = NULL;
+
+	p = prefix_pton(addr, -1, &prefix, &errmsg);
+	if (p == NULL) {
+		return NULL;
+	}
+
+	node = radix_lookup(dns_conf_client_rule.rule, &prefix);
+	return node;
+}
+
+static void *_new_dns_client_rule_ext(enum client_rule client_rule, int ext_size)
+{
+	struct dns_client_rule *rule;
+	int size = 0;
+
+	if (client_rule >= CLIENT_RULE_MAX) {
+		return NULL;
+	}
+
+	switch (client_rule) {
+	case CLIENT_RULE_FLAGS:
+		size = sizeof(struct client_rule_flags);
+		break;
+	case CLIENT_RULE_GROUP:
+		size = sizeof(struct client_rule_group);
+		break;
+	default:
+		return NULL;
+	}
+
+	size += ext_size;
+	rule = malloc(size);
+	if (!rule) {
+		return NULL;
+	}
+	memset(rule, 0, size);
+	rule->rule = client_rule;
+	atomic_set(&rule->refcnt, 1);
+	return rule;
+}
+
+static void *_new_dns_client_rule(enum client_rule client_rule)
+{
+	return _new_dns_client_rule_ext(client_rule, 0);
+}
+
+static void _dns_client_rule_get(struct dns_client_rule *rule)
+{
+	atomic_inc(&rule->refcnt);
+}
+
+static void _dns_client_rule_put(struct dns_client_rule *rule)
+{
+	int refcount = atomic_dec_return(&rule->refcnt);
+	if (refcount > 0) {
+		return;
+	}
+
+	free(rule);
+}
+
+static int _config_client_rules_free(struct dns_client_rules *client_rules)
+{
+	int i = 0;
+
+	if (client_rules == NULL) {
+		return 0;
+	}
+
+	for (i = 0; i < CLIENT_RULE_MAX; i++) {
+		if (client_rules->rules[i] == NULL) {
+			continue;
+		}
+
+		_dns_client_rule_put(client_rules->rules[i]);
+		client_rules->rules[i] = NULL;
+	}
+
+	free(client_rules);
+	return 0;
+}
+
+static int _config_client_rule_flag_set(const char *ip_cidr, unsigned int flag, unsigned int is_clear);
+static int _config_client_rule_flag_callback(const char *ip_cidr, void *priv)
+{
+	struct dns_set_rule_flags_callback_args *args = (struct dns_set_rule_flags_callback_args *)priv;
+	return _config_client_rule_flag_set(ip_cidr, args->flags, args->is_clear_flag);
+}
+
+static int _config_client_rule_flag_set(const char *ip_cidr, unsigned int flag, unsigned int is_clear)
+{
+	struct dns_client_rules *client_rules = NULL;
+	struct dns_client_rules *add_client_rules = NULL;
+	struct client_rule_flags *client_rule_flags = NULL;
+	radix_node_t *node = NULL;
+
+	if (strncmp(ip_cidr, "ip-set:", sizeof("ip-set:") - 1) == 0) {
+		struct dns_set_rule_flags_callback_args args;
+		args.flags = flag;
+		args.is_clear_flag = is_clear;
+		return _config_ip_rule_set_each(ip_cidr + sizeof("ip-set:") - 1, _config_client_rule_flag_callback, &args);
+	}
+
+	/* Get existing or create domain rule */
+	node = _create_client_rules_node(ip_cidr);
+	if (node == NULL) {
+		tlog(TLOG_ERROR, "create addr node failed.");
+		goto errout;
+	}
+
+	client_rules = node->data;
+	if (client_rules == NULL) {
+		add_client_rules = malloc(sizeof(*add_client_rules));
+		if (add_client_rules == NULL) {
+			goto errout;
+		}
+		memset(add_client_rules, 0, sizeof(*add_client_rules));
+		client_rules = add_client_rules;
+		node->data = client_rules;
+	}
+
+	/* add new rule to domain */
+	if (client_rules->rules[CLIENT_RULE_FLAGS] == NULL) {
+		client_rule_flags = _new_dns_client_rule(CLIENT_RULE_FLAGS);
+		client_rule_flags->flags = 0;
+		client_rules->rules[CLIENT_RULE_FLAGS] = &client_rule_flags->head;
+	}
+
+	client_rule_flags = container_of(client_rules->rules[CLIENT_RULE_FLAGS], struct client_rule_flags, head);
+	if (is_clear == false) {
+		client_rule_flags->flags |= flag;
+	} else {
+		client_rule_flags->flags &= ~flag;
+	}
+	client_rule_flags->is_flag_set |= flag;
+
+	return 0;
+errout:
+	if (add_client_rules) {
+		free(add_client_rules);
+	}
+
+	tlog(TLOG_ERROR, "set ip %s flags failed", ip_cidr);
+
+	return 0;
+}
+
+static int _config_client_rule_add(const char *ip_cidr, enum client_rule type, void *rule);
+static int _config_client_rule_add_callback(const char *ip_cidr, void *priv)
+{
+	struct dns_set_rule_add_callback_args *args = (struct dns_set_rule_add_callback_args *)priv;
+	return _config_client_rule_add(ip_cidr, args->type, args->rule);
+}
+
+static int _config_client_rule_add(const char *ip_cidr, enum client_rule type, void *rule)
+{
+	struct dns_client_rules *client_rules = NULL;
+	struct dns_client_rules *add_client_rules = NULL;
+	radix_node_t *node = NULL;
+
+	if (ip_cidr == NULL) {
+		goto errout;
+	}
+
+	if (type >= CLIENT_RULE_MAX) {
+		goto errout;
+	}
+
+	if (strncmp(ip_cidr, "ip-set:", sizeof("ip-set:") - 1) == 0) {
+		struct dns_set_rule_add_callback_args args;
+		args.type = type;
+		args.rule = rule;
+		return _config_ip_rule_set_each(ip_cidr + sizeof("ip-set:") - 1, _config_client_rule_add_callback, &args);
+	}
+
+	/* Get existing or create domain rule */
+	node = _create_client_rules_node(ip_cidr);
+	if (node == NULL) {
+		tlog(TLOG_ERROR, "create addr node failed.");
+		goto errout;
+	}
+
+	client_rules = node->data;
+	if (client_rules == NULL) {
+		add_client_rules = malloc(sizeof(*add_client_rules));
+		if (add_client_rules == NULL) {
+			goto errout;
+		}
+		memset(add_client_rules, 0, sizeof(*add_client_rules));
+		client_rules = add_client_rules;
+		node->data = client_rules;
+	}
+
+	/* add new rule to domain */
+	if (client_rules->rules[type]) {
+		_dns_client_rule_put(client_rules->rules[type]);
+		client_rules->rules[type] = NULL;
+	}
+
+	client_rules->rules[type] = rule;
+	_dns_client_rule_get(rule);
+
+	return 0;
+errout:
+	if (add_client_rules) {
+		free(add_client_rules);
+	}
+
+	tlog(TLOG_ERROR, "add client %s rule failed", ip_cidr);
+	return -1;
+}
+
 static int _config_qtype_soa(void *data, int argc, char *argv[])
 {
 	int i = 0;
@@ -3638,6 +3872,22 @@ static void _config_ip_iter_free(radix_node_t *node, void *cbctx)
 	node->data = NULL;
 }
 
+static void _config_client_rule_iter_free_cb(radix_node_t *node, void *cbctx)
+{
+	struct dns_client_rules *client_rules = NULL;
+	if (node == NULL) {
+		return;
+	}
+
+	if (node->data == NULL) {
+		return;
+	}
+
+	client_rules = node->data;
+	_config_client_rules_free(client_rules);
+	node->data = NULL;
+}
+
 static void _config_ip_set_name_table_destroy(void)
 {
 	struct dns_ip_set_name_list *set_name_list = NULL;
@@ -4408,6 +4658,146 @@ static void _config_host_table_destroy(int only_dynamic)
 	dns_hosts_record_num = 0;
 }
 
+static int _config_client_rule_group_add(const char *client, const char *group_name)
+{
+	struct client_rule_group *client_rule = NULL;
+	const char *group = NULL;
+
+	client_rule = _new_dns_client_rule(CLIENT_RULE_GROUP);
+	if (client_rule == NULL) {
+		goto errout;
+	}
+
+	group = _dns_conf_get_group_name(group_name);
+	if (group == NULL) {
+		goto errout;
+	}
+
+	client_rule->group_name = group;
+	if (_config_client_rule_add(client, CLIENT_RULE_GROUP, client_rule) != 0) {
+		goto errout;
+	}
+
+	_dns_client_rule_put(&client_rule->head);
+
+	return 0;
+errout:
+	if (client_rule != NULL) {
+		_dns_client_rule_put(&client_rule->head);
+	}
+	return -1;
+}
+
+static int _config_client_rules(void *data, int argc, char *argv[])
+{
+	int opt = 0;
+	const char *client = argv[1];
+	unsigned int server_flag = 0;
+
+	/* clang-format off */
+	static struct option long_options[] = {
+		{"group", required_argument, NULL, 'g'},
+		{"no-rule-addr", no_argument, NULL, 'A'},   
+		{"no-rule-nameserver", no_argument, NULL, 'N'},   
+		{"no-rule-ipset", no_argument, NULL, 'I'},   
+		{"no-rule-sni-proxy", no_argument, NULL, 'P'},   
+		{"no-rule-soa", no_argument, NULL, 'O'},
+		{"no-speed-check", no_argument, NULL, 'S'},  
+		{"no-cache", no_argument, NULL, 'C'},  
+		{"no-dualstack-selection", no_argument, NULL, 'D'},
+		{"no-ip-alias", no_argument, NULL, 'a'},
+		{"force-aaaa-soa", no_argument, NULL, 'F'},
+		{"no-serve-expired", no_argument, NULL, 253},
+		{"force-https-soa", no_argument, NULL, 254},
+		{NULL, no_argument, NULL, 0}
+	};
+	/* clang-format on */
+
+	if (argc <= 1) {
+		tlog(TLOG_ERROR, "invalid parameter.");
+		goto errout;
+	}
+
+	/* process extra options */
+	optind = 1;
+	while (1) {
+		opt = getopt_long_only(argc, argv, "g:", long_options, NULL);
+		if (opt == -1) {
+			break;
+		}
+
+		switch (opt) {
+		case 'g': {
+			const char *group = optarg;
+			if (_config_client_rule_group_add(client, group) != 0) {
+				tlog(TLOG_ERROR, "add group rule failed.");
+				goto errout;
+			}
+			break;
+		}
+		case 'A': {
+			server_flag |= BIND_FLAG_NO_RULE_ADDR;
+			break;
+		}
+		case 'a': {
+			server_flag |= BIND_FLAG_NO_IP_ALIAS;
+			break;
+		}
+		case 'N': {
+			server_flag |= BIND_FLAG_NO_RULE_NAMESERVER;
+			break;
+		}
+		case 'I': {
+			server_flag |= BIND_FLAG_NO_RULE_IPSET;
+			break;
+		}
+		case 'P': {
+			server_flag |= BIND_FLAG_NO_RULE_SNIPROXY;
+			break;
+		}
+		case 'S': {
+			server_flag |= BIND_FLAG_NO_SPEED_CHECK;
+			break;
+		}
+		case 'C': {
+			server_flag |= BIND_FLAG_NO_CACHE;
+			break;
+		}
+		case 'O': {
+			server_flag |= BIND_FLAG_NO_RULE_SOA;
+			break;
+		}
+		case 'D': {
+			server_flag |= BIND_FLAG_NO_DUALSTACK_SELECTION;
+			break;
+		}
+		case 'F': {
+			server_flag |= BIND_FLAG_FORCE_AAAA_SOA;
+			break;
+		}
+		case 253: {
+			server_flag |= BIND_FLAG_NO_SERVE_EXPIRED;
+			break;
+		}
+		case 254: {
+			server_flag |= BIND_FLAG_FORCE_HTTPS_SOA;
+			break;
+		}
+		}
+	}
+
+	if (server_flag != 0) {
+		if (_config_client_rule_flag_set(client, server_flag, 0) != 0) {
+			tlog(TLOG_ERROR, "set client rule flags failed.");
+			goto errout;
+		}
+	}
+
+	return 0;
+errout:
+	return -1;
+}
+
 int dns_server_check_update_hosts(void)
 {
 	struct stat statbuf;
@@ -4605,6 +4995,7 @@ static struct config_item _config_item[] = {
 	CONF_CUSTOM("ddns-domain", _conf_ddns_domain, NULL),
 	CONF_CUSTOM("dnsmasq-lease-file", _conf_dhcp_lease_dnsmasq_file, NULL),
 	CONF_CUSTOM("hosts-file", _conf_hosts_file, NULL),
+	CONF_CUSTOM("client-rules", _config_client_rules, NULL),
 	CONF_STRING("ca-file", (char *)&dns_conf_ca_file, DNS_MAX_PATH),
 	CONF_STRING("ca-path", (char *)&dns_conf_ca_path, DNS_MAX_PATH),
 	CONF_STRING("user", (char *)&dns_conf_user, sizeof(dns_conf_user)),
@@ -4732,12 +5123,14 @@ static int _dns_server_load_conf_init(void)
 {
 	dns_conf_address_rule.ipv4 = New_Radix();
 	dns_conf_address_rule.ipv6 = New_Radix();
-	if (dns_conf_address_rule.ipv4 == NULL || dns_conf_address_rule.ipv6 == NULL) {
+	dns_conf_client_rule.rule = New_Radix();
+	if (dns_conf_address_rule.ipv4 == NULL || dns_conf_address_rule.ipv6 == NULL || dns_conf_client_rule.rule == NULL) {
 		tlog(TLOG_WARN, "init radix tree failed.");
 		return -1;
 	}
 
-	art_tree_init(&dns_conf_domain_rule);
+	art_tree_init(&dns_conf_domain_rule.default_rule);
+	hash_init(dns_conf_domain_rule.group);
 
 	hash_init(dns_ipset_table.ipset);
 	hash_init(dns_nftset_table.nftset);
@@ -4790,6 +5183,7 @@ static void _config_ip_rules_destroy(void)
 {
 	Destroy_Radix(dns_conf_address_rule.ipv4, _config_ip_iter_free, NULL);
 	Destroy_Radix(dns_conf_address_rule.ipv6, _config_ip_iter_free, NULL);
+	Destroy_Radix(dns_conf_client_rule.rule, _config_client_rule_iter_free_cb, NULL);
 }
 
 void dns_server_load_exit(void)

+ 43 - 2
src/dns_conf.h

@@ -94,6 +94,12 @@ enum ip_rule {
 	IP_RULE_MAX,
 };
 
+enum client_rule {
+	CLIENT_RULE_FLAGS = 0,
+	CLIENT_RULE_GROUP,
+	CLIENT_RULE_MAX,
+};
+
 typedef enum {
 	DNS_BIND_TYPE_UDP,
 	DNS_BIND_TYPE_TCP,
@@ -240,7 +246,6 @@ extern struct dns_nftset_names dns_conf_nftset_no_speed;
 extern struct dns_nftset_names dns_conf_nftset;
 
 struct dns_domain_rule {
-	struct dns_rule head;
 	unsigned char sub_rule_only : 1;
 	unsigned char root_rule_only : 1;
 	struct dns_rule *rules[DOMAIN_RULE_MAX];
@@ -273,6 +278,17 @@ struct dns_response_mode_rule {
 	enum response_mode_type mode;
 };
 
+struct dns_conf_doamin_rule_group {
+	struct hlist_node node;
+	art_tree rule;
+	const char *group_name;
+};
+
+struct dns_conf_domain_rule {
+	art_tree default_rule;
+	DECLARE_HASHTABLE(group, 8);
+};
+
 struct dns_group_table {
 	DECLARE_HASHTABLE(group, 8);
 };
@@ -393,6 +409,30 @@ struct dns_conf_address_rule {
 	radix_tree_t *ipv6;
 };
 
+struct dns_client_rule {
+	atomic_t refcnt;
+	enum client_rule rule;
+};
+
+struct client_rule_flags {
+	struct dns_client_rule head;
+	unsigned int flags;
+	unsigned int is_flag_set;
+};
+
+struct client_rule_group {
+	struct dns_client_rule head;
+	const char *group_name;
+};
+
+struct dns_client_rules {
+	struct dns_client_rule *rules[CLIENT_RULE_MAX];
+};
+
+struct dns_conf_client_rule {
+	radix_tree_t *rule;
+};
+
 struct nftset_ipset_rules {
 	struct dns_ipset_rule *ipset;
 	struct dns_ipset_rule *ipset_ip;
@@ -572,8 +612,9 @@ extern int dns_conf_audit_console;
 extern int dns_conf_audit_syslog;
 
 extern char dns_conf_server_name[DNS_MAX_SERVER_NAME_LEN];
-extern art_tree dns_conf_domain_rule;
+extern struct dns_conf_domain_rule dns_conf_domain_rule;
 extern struct dns_conf_address_rule dns_conf_address_rule;
+extern struct dns_conf_client_rule dns_conf_client_rule;
 
 extern int dns_conf_dualstack_ip_selection;
 extern int dns_conf_dualstack_ip_allow_force_AAAA;

+ 69 - 1
src/dns_server.c

@@ -3027,6 +3027,48 @@ static int _dns_server_check_speed(struct dns_request *request, char *ip)
 	return -1;
 }
 
+static struct dns_client_rules *_dns_server_get_client_rules(struct sockaddr_storage *addr, socklen_t addr_len)
+{
+	prefix_t prefix;
+	radix_node_t *node = NULL;
+	uint8_t *netaddr = NULL;
+	int netaddr_len = 0;
+
+	switch (addr->ss_family) {
+	case AF_INET: {
+		struct sockaddr_in *addr_in = NULL;
+		addr_in = (struct sockaddr_in *)addr;
+		netaddr = (unsigned char *)&(addr_in->sin_addr.s_addr);
+		netaddr_len = 4;
+	} break;
+	case AF_INET6: {
+		struct sockaddr_in6 *addr_in6 = NULL;
+		addr_in6 = (struct sockaddr_in6 *)addr;
+		if (IN6_IS_ADDR_V4MAPPED(&addr_in6->sin6_addr)) {
+			netaddr = addr_in6->sin6_addr.s6_addr + 12;
+			netaddr_len = 4;
+		} else {
+			netaddr = addr_in6->sin6_addr.s6_addr;
+			netaddr_len = 16;
+		}
+	} break;
+	default:
+		return NULL;
+		break;
+	}
+
+	if (prefix_from_blob(netaddr, netaddr_len, netaddr_len * 8, &prefix) == NULL) {
+		return NULL;
+	}
+
+	node = radix_search_best(dns_conf_client_rule.rule, &prefix);
+	if (node == NULL) {
+		return NULL;
+	}
+
+	return node->data;
+}
+
 static struct dns_ip_rules *_dns_server_ip_rule_get(struct dns_request *request, unsigned char *addr, int addr_len,
 													dns_type_t addr_type)
 {
@@ -4609,7 +4651,7 @@ static void _dns_server_get_domain_rule_by_domain(struct dns_request *request, c
 	domain_key[domain_len] = 0;
 
 	/* find domain rule */
-	art_substring_walk(&dns_conf_domain_rule, (unsigned char *)domain_key, domain_len, _dns_server_get_rules,
+	art_substring_walk(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, domain_len, _dns_server_get_rules,
 					   &walk_args);
 	if (likely(dns_conf_log_level > TLOG_DEBUG)) {
 		return;
@@ -5543,6 +5585,29 @@ static void _dns_server_request_set_client(struct dns_request *request, struct d
 	_dns_server_conn_get(conn);
 }
 
+static void _dns_server_request_set_client_rules(struct dns_request *request, struct dns_client_rules *client_rule)
+{
+	if (client_rule == NULL) {
+		return;
+	}
+
+	tlog(TLOG_DEBUG, "match client rule.\n");
+
+	if (client_rule->rules[CLIENT_RULE_GROUP]) {
+		struct client_rule_group *group = (struct client_rule_group *)client_rule->rules[CLIENT_RULE_GROUP];
+		if (group && group->group_name[0] != '\0') {
+			safe_strncpy(request->dns_group_name, group->group_name, sizeof(request->dns_group_name));
+		}
+	}
+
+	if (client_rule->rules[CLIENT_RULE_FLAGS]) {
+		struct client_rule_flags *flags = (struct client_rule_flags *)client_rule->rules[CLIENT_RULE_FLAGS];
+		if (flags) {
+			request->server_flags = flags->flags;
+		}
+	}
+}
+
 static void _dns_server_request_set_id(struct dns_request *request, unsigned short id)
 {
 	request->id = id;
@@ -6096,6 +6161,7 @@ static int _dns_server_recv(struct dns_server_conn_head *conn, unsigned char *in
 	char name[DNS_MAX_CNAME_LEN];
 	struct dns_packet *packet = (struct dns_packet *)packet_buff;
 	struct dns_request *request = NULL;
+	struct dns_client_rules *client_rules = NULL;
 
 	/* decode packet */
 	tlog(TLOG_DEBUG, "recv query packet from %s, len = %d, type = %d",
@@ -6116,6 +6182,7 @@ static int _dns_server_recv(struct dns_server_conn_head *conn, unsigned char *in
 		 packet->head.qdcount, packet->head.ancount, packet->head.nscount, packet->head.nrcount, inpacket_len,
 		 packet->head.id, packet->head.tc, packet->head.rd, packet->head.ra, packet->head.rcode);
 
+	client_rules = _dns_server_get_client_rules(from, from_len);
 	request = _dns_server_new_request();
 	if (request == NULL) {
 		tlog(TLOG_ERROR, "malloc failed.\n");
@@ -6124,6 +6191,7 @@ static int _dns_server_recv(struct dns_server_conn_head *conn, unsigned char *in
 
 	memcpy(&request->localaddr, local, local_len);
 	_dns_server_request_set_client(request, conn);
+	_dns_server_request_set_client_rules(request, client_rules);
 	_dns_server_request_set_client_addr(request, from, from_len);
 	_dns_server_request_set_id(request, packet->head.id);
 

+ 65 - 0
test/cases/test-client-rule.cc

@@ -0,0 +1,65 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "client.h"
+#include "dns.h"
+#include "include/utils.h"
+#include "server.h"
+#include "gtest/gtest.h"
+
+class ClientRule : public ::testing::Test
+{
+  protected:
+	virtual void SetUp() {}
+	virtual void TearDown() {}
+};
+
+TEST_F(ClientRule, bogus_nxdomain)
+{
+	smartdns::MockServer server_upstream;
+	smartdns::MockServer server_upstream2;
+	smartdns::Server server;
+
+	server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
+		if (request->qtype != DNS_T_A) {
+			return smartdns::SERVER_REQUEST_SOA;
+		}
+
+		smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+	server_upstream2.Start("udp://0.0.0.0:62053",
+						   [](struct smartdns::ServerRequestContext *request) { return smartdns::SERVER_REQUEST_SOA; });
+
+	/* this ip will be discard, but is reachable */
+	server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
+
+	server.Start(R"""(bind [::]:60053
+server udp://127.0.0.1:61053 -g client -e 
+server udp://127.0.0.1:62053
+client-rules 127.0.0.1 -g client
+)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("b.com", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+}