Browse Source

feature: add local-domain options.

Nick Peng 3 months ago
parent
commit
3fba0f28ec

+ 3 - 0
etc/smartdns/smartdns.conf

@@ -365,6 +365,9 @@ log-level info
 # set ddns domain
 # ddns-domain domain
 
+# set local domain
+# local-domain domain
+
 # lookup local network hostname or ip address from mdns
 # mdns-lookup [yes|no]
 # mdns-lookup no

+ 6 - 0
package/openwrt/files/etc/init.d/smartdns

@@ -755,6 +755,12 @@ load_service()
 	config_get dns64 "$section" "dns64" ""
 	[ -z "$dns64" ] || conf_append "dns64" "$dns64"
 
+	config_get ddns_domain "$section" "ddns_domain" ""
+	[ -z "$ddns_domain" ] || conf_append "ddns-domain" "$ddns_domain"
+
+	config_get local_domain "$section" "local_domain" ""
+	[ -z "$local_domain" ] || conf_append "local-domain" "$local_domain"
+
 	config_get_bool mdns_lookup "$section" "mdns_lookup" "0"
 	[ "$mdns_lookup" = "1" ] && conf_append "mdns-lookup" "yes"
 

+ 2 - 0
src/dns_conf/dns_conf.c

@@ -40,6 +40,7 @@
 #include "ip_rule.h"
 #include "ip_set.h"
 #include "ipset.h"
+#include "local_domain.h"
 #include "nameserver.h"
 #include "nftset.h"
 #include "plugin.h"
@@ -215,6 +216,7 @@ static struct config_item _config_item[] = {
 	CONF_CUSTOM("domain-rules", _config_domain_rules, NULL),
 	CONF_CUSTOM("domain-set", _config_domain_set, NULL),
 	CONF_CUSTOM("ddns-domain", _config_ddns_domain, NULL),
+	CONF_CUSTOM("local-domain", _config_local_domain, NULL),
 	CONF_CUSTOM("dnsmasq-lease-file", _conf_dhcp_lease_dnsmasq_file, NULL),
 	CONF_CUSTOM("hosts-file", _config_hosts_file, NULL),
 	CONF_CUSTOM("group-begin", _config_group_begin, NULL),

+ 39 - 0
src/dns_conf/domain_rule.c

@@ -370,6 +370,45 @@ errout:
 	return 0;
 }
 
+int _config_domain_rule_remove(const char *domain, enum domain_rule type)
+{
+	char domain_key[DNS_MAX_CONF_CNAME_LEN];
+	int len = 0;
+	int sub_rule_only = 0;
+	int root_rule_only = 0;
+
+	if (type < 0 || type >= DOMAIN_RULE_MAX) {
+		tlog(TLOG_ERROR, "invalid domain rule type %d", type);
+		return -1;
+	}
+
+	if (strncmp(domain, "domain-set:", sizeof("domain-set:") - 1) == 0) {
+		return _config_domain_rule_set_each(domain + sizeof("domain-set:") - 1, _config_domain_rule_delete_callback,
+											NULL);
+	}
+
+	if (_config_setup_domain_key(domain, domain_key, sizeof(domain_key), &len, &root_rule_only, &sub_rule_only) != 0) {
+		tlog(TLOG_ERROR, "setup domain key failed for %s", domain);
+		return -1;
+	}
+
+	struct dns_domain_rule *domain_rule = art_search(&_config_current_rule_group()->domain_rule.tree,
+													   (unsigned char *)domain_key, len);
+	if (domain_rule == NULL) {
+		tlog(TLOG_ERROR, "domain %s not found", domain);
+		return -1;		
+	}
+
+	if (domain_rule->rules[type] == NULL) {
+		return 0;
+	}
+
+	_dns_rule_put(domain_rule->rules[type]);
+	domain_rule->rules[type] = NULL;
+	
+	return 0;
+}
+
 int _config_domain_rule_add(const char *domain, enum domain_rule type, void *rule)
 {
 	struct dns_domain_rule *domain_rule = NULL;

+ 1 - 0
src/dns_conf/domain_rule.h

@@ -34,6 +34,7 @@ void _dns_rule_get(struct dns_rule *rule);
 void _dns_rule_put(struct dns_rule *rule);
 
 int _config_domain_rule_add(const char *domain, enum domain_rule type, void *rule);
+int _config_domain_rule_remove(const char *domain, enum domain_rule type);
 int _config_domain_rule_flag_set(const char *domain, unsigned int flag, unsigned int is_clear);
 int _config_domain_rules(void *data, int argc, char *argv[]);
 int _config_domain_rule_delete(const char *domain);

+ 46 - 1
src/dns_conf/host_file.c

@@ -17,6 +17,7 @@
  */
 
 #include "host_file.h"
+#include "local_domain.h"
 #include "ptr.h"
 #include "set_file.h"
 #include "smartdns/lib/stringutil.h"
@@ -179,6 +180,46 @@ errout:
 	return NULL;
 }
 
+static int _conf_host_expand_local_domain(struct dns_hosts *host)
+{
+	struct dns_hosts *host_expand = NULL;
+	const char *local_domain = dns_conf_get_local_domain();
+	char domain[DNS_MAX_CNAME_LEN] = {0};
+	int ret;
+
+	if (local_domain == NULL || local_domain[0] == '\0') {
+		return 0;
+	}
+
+	if (strstr(host->domain, ".") != NULL) {
+		// already has domain, skip
+		return 0;
+	}
+
+	ret = snprintf(domain, sizeof(domain), "%s.%s", host->domain, local_domain);
+	if (ret < 0 || ret >= (int)sizeof(domain)) {
+		tlog(TLOG_WARN, "expand host %s with local domain %s failed, too long.", host->domain, local_domain);
+		return -1;
+	}
+
+	host_expand = _dns_conf_get_hosts(domain, host->dns_type);
+	if (host_expand == NULL) {
+		goto errout;
+	}
+
+	host_expand->is_soa = host->is_soa;
+	host_expand->is_dynamic = host->is_dynamic;
+	host_expand->host_type = host->host_type;
+	memcpy(host_expand->ipv6_addr, host->ipv6_addr, DNS_RR_AAAA_LEN);
+
+	dns_hosts_record_num++;
+
+	return 0;
+
+errout:
+	return -1;
+}
+
 int _conf_host_add(const char *hostname, const char *ip, dns_hosts_type host_type, int is_dynamic)
 {
 	struct dns_hosts *host = NULL;
@@ -250,7 +291,11 @@ int _conf_host_add(const char *hostname, const char *ip, dns_hosts_type host_typ
 		goto errout;
 	}
 
-	dns_hosts_record_num++;
+	dns_hosts_record_num += 2;
+
+	_conf_host_expand_local_domain(host);
+	_conf_host_expand_local_domain(host_other);
+
 	return 0;
 
 errout:

+ 52 - 0
src/dns_conf/local_domain.c

@@ -0,0 +1,52 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2025 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "local_domain.h"
+#include "domain_rule.h"
+#include "nameserver.h"
+#include "smartdns/lib/stringutil.h"
+
+static char local_domain[DNS_MAX_CNAME_LEN] = {0};
+
+const char *dns_conf_get_local_domain(void)
+{
+	return local_domain;
+}
+
+int _config_local_domain(void *data, int argc, char *argv[])
+{
+	if (argc <= 1) {
+		tlog(TLOG_ERROR, "invalid parameter.");
+		return -1;
+	}
+
+	const char *domain = argv[1];
+
+	if (local_domain[0] != '\0') {
+		_config_domain_rule_remove(local_domain, DOMAIN_RULE_NAMESERVER);
+        local_domain[0] = '\0';
+	}
+
+    if (domain[0] == '\0' || strncmp(domain, "-", sizeof("-")) == 0) {
+        return 0;
+    }
+
+	safe_strncpy(local_domain, domain, sizeof(local_domain));
+	_conf_domain_rule_nameserver(local_domain, DNS_SERVER_GROUP_MDNS);
+	return 0;
+}

+ 36 - 0
src/dns_conf/local_domain.h

@@ -0,0 +1,36 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2025 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef _DNS_CONF_LOCAL_DOMAIN_H_
+#define _DNS_CONF_LOCAL_DOMAIN_H_
+
+#include "dns_conf.h"
+#include "smartdns/dns_conf.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif /*__cplusplus */
+
+const char *dns_conf_get_local_domain(void);
+
+int _config_local_domain(void *data, int argc, char *argv[]);
+
+#ifdef __cplusplus
+}
+#endif /*__cplusplus */
+#endif

+ 5 - 0
src/utils/neighbors.c

@@ -48,7 +48,11 @@ int netlink_get_neighbors(int family,
 	int ret = 0;
 	int send_count = 0;
 
+	memset(buf, 0, sizeof(buf));
+	memset(&sa, 0, sizeof(sa));
 	memset(&msg, 0, sizeof(msg));
+	
+	sa.nl_family = AF_NETLINK;
 	msg.msg_name = &sa;
 	msg.msg_namelen = sizeof(sa);
 	msg.msg_iov = &iov;
@@ -62,6 +66,7 @@ int netlink_get_neighbors(int family,
 	nlh->nlmsg_pid = getpid();
 
 	ndm = NLMSG_DATA(nlh);
+	memset(ndm, 0, sizeof(struct ndmsg));
 	ndm->ndm_family = family;
 
 	while (1) {

+ 159 - 0
test/cases/test-local-domain.cc

@@ -0,0 +1,159 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2025 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "client.h"
+#include "smartdns/dns.h"
+#include "smartdns/dns_client.h"
+#include "include/utils.h"
+#include "server.h"
+#include "gtest/gtest.h"
+#include <fstream>
+
+class LocalDomain : public ::testing::Test
+{
+  protected:
+	virtual void SetUp() {}
+	virtual void TearDown() {}
+};
+
+TEST(LocalDomain, query)
+{
+	smartdns::MockServer server_upstream1;
+	smartdns::MockServer server_upstream2;
+	smartdns::Server server;
+	smartdns::TempFile hosts_file;
+
+	std::string listen_url = "udp://";
+	listen_url += DNS_MDNS_IP;
+	listen_url += ":" + std::to_string(DNS_MDNS_PORT);
+
+	server_upstream1.Start(listen_url.c_str(), [](struct smartdns::ServerRequestContext *request) {
+		std::string domain = request->domain;
+		if (request->domain.length() == 0) {
+			return smartdns::SERVER_REQUEST_ERROR;
+		}
+
+		if (request->qtype == DNS_T_A) {
+			unsigned char addr[][4] = {{1, 2, 3, 4}};
+			dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]);
+		} else if (request->qtype == DNS_T_AAAA) {
+			unsigned char addr[][16] = {{1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}};
+			dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]);
+		} else {
+			return smartdns::SERVER_REQUEST_ERROR;
+		}
+
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+	server_upstream2.Start("udp://0.0.0.0:61053",
+						   [](struct smartdns::ServerRequestContext *request) { return smartdns::SERVER_REQUEST_SOA; });
+
+	server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 100);
+	server.MockPing(PING_TYPE_ICMP, "102:304:500::1", 60, 100);
+	hosts_file.Write("1.2.3.1 pc\n");
+	hosts_file.Write("1.2.3.2 phone\n");
+	hosts_file.Write("1.2.3.3 router\n");
+
+	
+	std::string conf = R"""(bind [::]:60053
+server 127.0.0.1:61053
+dualstack-ip-selection no
+local-domain lan
+# mdns-lookup yes
+)""";
+	conf += "hosts-file " + hosts_file.GetPath() + "\n";
+	conf += "\n";
+	server.Start(conf);
+	smartdns::Client client;
+
+	ASSERT_TRUE(client.Query("b.com A", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 0);
+	EXPECT_EQ(client.GetStatus(), "NXDOMAIN");
+
+	ASSERT_TRUE(client.Query("pc A", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "pc");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.1");
+
+	ASSERT_TRUE(client.Query("phone.lan A", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "phone.lan");
+	EXPECT_GT(client.GetAnswer()[0].GetTTL(), 59);
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.2");
+
+	ASSERT_TRUE(client.Query("router.lan AAAA", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 0);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+}
+
+TEST(LocalDomain, ptr)
+{
+	smartdns::MockServer server_upstream1;
+	smartdns::MockServer server_upstream2;
+	smartdns::Server server;
+	smartdns::TempFile hosts_file;
+
+	std::string listen_url = "udp://";
+	listen_url += DNS_MDNS_IP;
+	listen_url += ":" + std::to_string(DNS_MDNS_PORT);
+
+	server_upstream1.Start(listen_url.c_str(), [](struct smartdns::ServerRequestContext *request) {
+		std::string domain = request->domain;
+		if (request->domain.length() == 0) {
+			return smartdns::SERVER_REQUEST_ERROR;
+		}
+
+		if (request->qtype != DNS_T_PTR) {
+			return smartdns::SERVER_REQUEST_SOA;
+		}
+
+		dns_add_PTR(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, "host.local");
+
+		return smartdns::SERVER_REQUEST_OK;
+	});
+
+	server_upstream2.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
+		return smartdns::SERVER_REQUEST_ERROR;
+	});
+	hosts_file.Write("1.2.3.1 pc\n");
+	hosts_file.Write("1.2.3.2 phone\n");
+	hosts_file.Write("1.2.3.3 router\n");
+
+	std::string conf = R"""(bind [::]:60053
+server 127.0.0.1:61053
+dualstack-ip-selection no
+local-domain lan
+)""";
+	conf += "hosts-file " + hosts_file.GetPath() + "\n";
+	conf += "\n";
+	server.Start(conf);
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("-x 1.2.3.1", 60053));
+	std::cout << client.GetResult() << std::endl;
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetName(), "1.3.2.1.in-addr.arpa");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "pc.");
+}