Browse Source

test: add test framework

Nick Peng 2 years ago
parent
commit
81ab3f413a
15 changed files with 1075 additions and 8 deletions
  1. 10 1
      .github/workflows/c-cpp.yml
  2. 2 0
      .gitignore
  3. 7 0
      src/dns.h
  4. 16 1
      src/smartdns.c
  5. 4 4
      src/util.c
  6. 2 2
      src/util.h
  7. 46 0
      test/Makefile
  8. 17 0
      test/cases/test-mock-server.cc
  9. 34 0
      test/cases/test-tls-server.cc
  10. 301 0
      test/client.cc
  11. 106 0
      test/client.h
  12. 59 0
      test/include/utils.h
  13. 355 0
      test/server.cc
  14. 91 0
      test/server.h
  15. 25 0
      test/test.cc

+ 10 - 1
.github/workflows/c-cpp.yml

@@ -13,5 +13,14 @@ jobs:
 
     steps:
     - uses: actions/checkout@v2
+    - name: prepare
+      run: |
+        sudo apt update
+        sudo apt install libgtest-dev
     - name: make
-      run: make
+      run: |
+        make all -j4
+        make clean
+    - name: test
+      run: |
+        make -C test test -j8

+ 2 - 0
.gitignore

@@ -1,5 +1,7 @@
 .vscode
 *.o
+*.pem
 .DS_Store
 *.swp.
 systemd/smartdns.service
+test.bin

+ 7 - 0
src/dns.h

@@ -19,6 +19,10 @@
 #ifndef _DNS_HEAD_H
 #define _DNS_HEAD_H
 
+#ifdef __cplusplus
+extern "C" {
+#endif /*__cplusplus */
+
 #define DNS_RR_A_LEN 4
 #define DNS_RR_AAAA_LEN 16
 #define DNS_MAX_CNAME_LEN 256
@@ -310,4 +314,7 @@ struct dns_update_param {
 
 int dns_packet_update(unsigned char *data, int size, struct dns_update_param *param);
 
+#ifdef __cplusplus
+}
+#endif /*__cplusplus */
 #endif

+ 16 - 1
src/smartdns.c

@@ -646,7 +646,21 @@ static int _smartdns_init_pre(void)
 	return 0;
 }
 
+#ifdef TEST
+#define smartdns_test_notify(retval) smartdns_test_notify_func(fd_notify, retval)
+static void smartdns_test_notify_func(int fd_notify, uint64_t retval) {
+	/* notify parent kickoff */
+	if (fd_notify > 0) {
+		write(fd_notify, &retval, sizeof(retval));
+	}
+}
+
+int smartdns_main(int argc, char *argv[], int fd_notify);
+int smartdns_main(int argc, char *argv[], int fd_notify) 
+#else
+#define smartdns_test_notify(retval)
 int main(int argc, char *argv[])
+#endif
 {
 	int ret = 0;
 	int is_foreground = 0;
@@ -732,10 +746,11 @@ int main(int argc, char *argv[])
 	}
 
 	atexit(_smartdns_exit);
+	smartdns_test_notify(1);
 
 	return _smartdns_run();
 
 errout:
-
+	smartdns_test_notify(2);
 	return 1;
 }

+ 4 - 4
src/util.c

@@ -401,7 +401,7 @@ int check_is_ipaddr(const char *ip)
 	return -1;
 }
 
-int parse_uri(char *value, char *scheme, char *host, int *port, char *path)
+int parse_uri(const char *value, char *scheme, char *host, int *port, char *path)
 {
 	return parse_uri_ext(value, scheme, NULL, NULL, host, port, path);
 }
@@ -442,16 +442,16 @@ void urldecode(char *dst, const char *src)
 	*dst++ = '\0';
 }
 
-int parse_uri_ext(char *value, char *scheme, char *user, char *password, char *host, int *port, char *path)
+int parse_uri_ext(const char *value, char *scheme, char *user, char *password, char *host, int *port, char *path)
 {
 	char *scheme_end = NULL;
 	int field_len = 0;
-	char *process_ptr = value;
+	const char *process_ptr = value;
 	char user_pass_host_part[PATH_MAX];
 	char *user_password = NULL;
 	char *host_part = NULL;
 
-	char *host_end = NULL;
+	const char *host_end = NULL;
 
 	scheme_end = strstr(value, "://");
 	if (scheme_end) {

+ 2 - 2
src/util.h

@@ -69,9 +69,9 @@ int parse_ip(const char *value, char *ip, int *port);
 
 int check_is_ipaddr(const char *ip);
 
-int parse_uri(char *value, char *scheme, char *host, int *port, char *path);
+int parse_uri(const char *value, char *scheme, char *host, int *port, char *path);
 
-int parse_uri_ext(char *value, char *scheme, char *user, char *password, char *host, int *port, char *path);
+int parse_uri_ext(const char *value, char *scheme, char *user, char *password, char *host, int *port, char *path);
 
 void urldecode(char *dst, const char *src);
 

+ 46 - 0
test/Makefile

@@ -0,0 +1,46 @@
+
+# Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
+#
+# smartdns is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# smartdns is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+BIN=test.bin
+CFLAGS += -I../src -I../src/include
+CFLAGS += -DTEST
+CFLAGS += -g -Wall -Wstrict-prototypes -fno-omit-frame-pointer -Wstrict-aliasing -funwind-tables -Wmissing-prototypes -Wshadow -Wextra -Wno-unused-parameter -Wno-implicit-fallthrough
+
+CXXFLAGS += -g
+CXXFLAGS += -I./ -I../src -I../src/include
+
+SMARTDNS_OBJS = lib/rbtree.o lib/art.o lib/bitops.o lib/radix.o lib/conf.o lib/nftset.o
+SMARTDNS_OBJS += smartdns.o fast_ping.o dns_client.o dns_server.o dns.o util.o tlog.o dns_conf.o dns_cache.o http_parse.o proxy.o
+OBJS = $(addprefix ../src/, $(SMARTDNS_OBJS))
+
+TEST_SOURCES := $(wildcard *.cc) $(wildcard */*.cc) $(wildcard */*/*.cc)
+TEST_OBJECTS := $(patsubst %.cc, %.o, $(TEST_SOURCES))
+OBJS += $(TEST_OBJECTS)
+
+LDFLAGS += -lssl -lcrypto -lpthread -ldl -lgtest -lstdc++ -lm
+
+.PHONY: all clean test
+
+all: $(BIN)
+
+$(BIN) : $(OBJS)
+	$(CC) $(OBJS) -o $@ $(LDFLAGS)
+
+test: $(BIN)
+	./$(BIN)
+
+clean:
+	$(RM) $(OBJS) $(BIN)

+ 17 - 0
test/cases/test-mock-server.cc

@@ -0,0 +1,17 @@
+#include "client.h"
+#include "include/utils.h"
+#include "server.h"
+#include "gtest/gtest.h"
+
+TEST(server, mock)
+{
+	smartdns::MockServer server;
+	smartdns::Client client;
+	server.Start("udp://0.0.0.0:7053", [](struct smartdns::ServerRequestContext *request) {
+		request->response_data_len = 0;
+		return false;
+	});
+
+	ASSERT_TRUE(client.Query("example.com", 7053));
+	EXPECT_EQ(client.GetStatus(), "SERVFAIL");
+}

+ 34 - 0
test/cases/test-tls-server.cc

@@ -0,0 +1,34 @@
+#include "client.h"
+#include "include/utils.h"
+#include "server.h"
+#include "gtest/gtest.h"
+
+TEST(server, TLSServer)
+{
+	Defer
+	{
+		unlink("/tmp/smartdns-cert.pem");
+		unlink("/tmp/smartdns-key.pem");
+	};
+
+	smartdns::Server server_wrap;
+	smartdns::Server server;
+
+	server.Start(R"""(bind [::]:61053
+server-tls 127.0.0.1:60053 -no-check-certificate
+log-num 0
+log-console yes
+log-level debug
+cache-persist no)""");
+	server_wrap.Start(R"""(bind-tls [::]:60053
+address /example.com/1.2.3.4
+log-num 0
+log-console yes
+log-level debug
+cache-persist no)""");
+	smartdns::Client client;
+	ASSERT_TRUE(client.Query("example.com", 61053));
+	ASSERT_EQ(client.GetAnswerNum(), 1);
+	EXPECT_EQ(client.GetStatus(), "NOERROR");
+	EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
+}

+ 301 - 0
test/client.cc

@@ -0,0 +1,301 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "client.h"
+#include <iostream>
+#include <memory>
+#include <regex>
+#include <signal.h>
+#include <string>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <vector>
+
+namespace smartdns
+{
+
+std::vector<std::string> StringSplit(const std::string &s, const char delim)
+{
+	std::vector<std::string> ret;
+	std::string::size_type lastPos = s.find_first_not_of(delim, 0);
+	std::string::size_type pos = s.find_first_of(delim, lastPos);
+	while (std::string::npos != pos || std::string::npos != lastPos) {
+		ret.push_back(s.substr(lastPos, pos - lastPos));
+		lastPos = s.find_first_not_of(delim, pos);
+		pos = s.find_first_of(delim, lastPos);
+	}
+
+	return ret;
+}
+
+DNSRecord::DNSRecord() {}
+
+DNSRecord::~DNSRecord() {}
+
+bool DNSRecord::Parser(const std::string &line)
+{
+	std::vector<std::string> fields = StringSplit(line, '\t');
+	if (fields.size() < 3) {
+		std::cerr << "Invalid DNS record: " << line << ", size: " << fields.size() << std::endl;
+		return false;
+	}
+
+	if (fields.size() == 3) {
+		name_ = fields[0];
+		if (name_.size() > 1) {
+			name_.resize(name_.size() - 1);
+		}
+		class_ = fields[1];
+		type_ = fields[2];
+		return true;
+	}
+
+	name_ = fields[0];
+	if (name_.size() > 1) {
+		name_.resize(name_.size() - 1);
+	}
+	ttl_ = std::stoi(fields[1]);
+	class_ = fields[2];
+	type_ = fields[3];
+	data_ = fields[4];
+
+	for (int i = 5; i < fields.size(); i++) {
+		data_ += " " + fields[i];
+	}
+
+	return true;
+}
+
+std::string DNSRecord::GetName()
+{
+	return name_;
+}
+
+std::string DNSRecord::GetType()
+{
+	return type_;
+}
+
+std::string DNSRecord::GetClass()
+{
+	return class_;
+}
+
+int DNSRecord::GetTTL()
+{
+	return ttl_;
+}
+
+std::string DNSRecord::GetData()
+{
+	return data_;
+}
+
+Client::Client() {}
+
+bool Client::Query(const std::string &dig_cmds, int port, const std::string &ip)
+{
+	std::string cmd = "dig ";
+	if (port > 0) {
+		cmd += "-p " + std::to_string(port);
+	}
+
+	if (ip.length() > 0) {
+		cmd += " @" + ip;
+	} else {
+		cmd += " @127.0.0.1";
+	}
+
+	cmd += " " + dig_cmds;
+	cmd += " +tries=1";
+	FILE *fp = NULL;
+
+	fp = popen(cmd.c_str(), "r");
+	if (fp == NULL) {
+		return false;
+	}
+
+	std::shared_ptr<FILE> pipe(fp, pclose);
+	result_.clear();
+	char buffer[4096];
+	usleep(10000);
+	while (fgets(buffer, 4096, pipe.get())) {
+		result_ += buffer;
+	}
+
+	if (ParserResult() == false) {
+		Clear();
+	}
+
+	return true;
+}
+
+std::vector<DNSRecord> Client::GetQuery()
+{
+	return records_query_;
+}
+
+std::vector<DNSRecord> Client::GetAnswer()
+{
+	return records_answer_;
+}
+
+std::vector<DNSRecord> Client::GetAuthority()
+{
+	return records_authority_;
+}
+
+std::vector<DNSRecord> Client::GetAdditional()
+{
+	return records_additional_;
+}
+
+int Client::GetAnswerNum()
+{
+	return answer_num_;
+}
+
+std::string Client::GetStatus()
+{
+	return status_;
+}
+
+std::string Client::GetServer()
+{
+	return server_;
+}
+
+int Client::GetQueryTime()
+{
+	return query_time_;
+}
+
+int Client::GetMsgSize()
+{
+	return msg_size_;
+}
+
+std::string Client::GetFlags()
+{
+	return flags_;
+}
+
+std::string Client::GetResult()
+{
+	return result_;
+}
+
+void Client::Clear()
+{
+	result_.clear();
+	answer_num_ = 0;
+	status_.clear();
+	server_.clear();
+	query_time_ = 0;
+	msg_size_ = 0;
+	flags_.clear();
+	records_query_.clear();
+	records_answer_.clear();
+	records_authority_.clear();
+	records_additional_.clear();
+}
+
+void Client::PrintResult()
+{
+	std::cout << result_ << std::endl;
+}
+
+bool Client::ParserRecord(const std::string &record_str, std::vector<DNSRecord> &record)
+{
+	DNSRecord r;
+
+	if (r.Parser(record_str) == false) {
+		return false;
+	}
+
+	record.push_back(r);
+	return true;
+}
+
+bool Client::ParserResult()
+{
+	std::smatch match;
+
+	std::regex reg_goanswer(";; Got answer:");
+	if (std::regex_search(result_, match, reg_goanswer) == false) {
+		std::cout << "DIG FAILED:\n" << result_ << std::endl;
+		return false;
+	}
+
+	std::regex reg_answer_num(", ANSWER: ([0-9]+),");
+	if (std::regex_search(result_, match, reg_answer_num)) {
+		answer_num_ = std::stoi(match[1]);
+	}
+
+	std::regex reg_status(", status: ([A-Z]+),");
+	if (std::regex_search(result_, match, reg_status)) {
+		status_ = match[1];
+	}
+
+	std::regex reg_server(";; SERVER: ([0-9.]+)#");
+	if (std::regex_search(result_, match, reg_server)) {
+		server_ = match[1];
+	}
+
+	std::regex reg_querytime(";; Query time: ([0-9]+) msec");
+	if (std::regex_search(result_, match, reg_querytime)) {
+		query_time_ = std::stoi(match[1]);
+	}
+
+	std::regex reg_msg_size(";; MSG SIZE  rcvd: ([0-9]+)");
+	if (std::regex_search(result_, match, reg_msg_size)) {
+		msg_size_ = std::stoi(match[1]);
+	}
+
+	std::regex reg_flags(";; flags: ([a-z A-Z]+);");
+	if (std::regex_search(result_, match, reg_flags)) {
+		flags_ = match[1];
+	}
+
+	std::regex reg_question(";; QUESTION SECTION:\n.(.*\n)+?\n");
+	if (std::regex_search(result_, match, reg_question)) {
+		if (ParserRecord(match[1], records_query_) == false) {
+			return false;
+		}
+	}
+
+	std::regex reg_answer(";; ANSWER SECTION:\n(.*)\n");
+	if (std::regex_search(result_, match, reg_answer)) {
+		if (ParserRecord(match[1], records_answer_) == false) {
+			return false;
+		}
+	}
+
+	std::regex reg_addition(";; ADDITIONAL SECTION:\n(.*)\n");
+	if (std::regex_search(result_, match, reg_answer)) {
+		if (ParserRecord(match[1], records_additional_) == false) {
+			return false;
+		}
+	}
+
+	return true;
+}
+
+Client::~Client() {}
+
+} // namespace smartdns

+ 106 - 0
test/client.h

@@ -0,0 +1,106 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef _SMARTDNS_CLIENT_
+#define _SMARTDNS_CLIENT_
+
+#include <string>
+#include <unistd.h>
+#include <vector>
+
+namespace smartdns
+{
+
+class DNSRecord
+{
+  public:
+	DNSRecord();
+	virtual ~DNSRecord();
+
+	bool Parser(const std::string &line);
+
+	std::string GetName();
+
+	std::string GetType();
+
+	std::string GetClass();
+
+	int GetTTL();
+
+	std::string GetData();
+
+  private:
+	std::string name_;
+	std::string type_;
+	std::string class_;
+	int ttl_;
+	std::string data_;
+};
+
+class Client
+{
+  public:
+	Client();
+	virtual ~Client();
+	bool Query(const std::string &dig_cmds, int port = 0, const std::string &ip = "");
+
+	std::string GetResult();
+
+	std::vector<DNSRecord> GetQuery();
+
+	std::vector<DNSRecord> GetAnswer();
+
+	std::vector<DNSRecord> GetAuthority();
+
+	std::vector<DNSRecord> GetAdditional();
+
+	int GetAnswerNum();
+
+	std::string GetStatus();
+
+	std::string GetServer();
+
+	int GetQueryTime();
+
+	int GetMsgSize();
+
+	std::string GetFlags();
+
+	void Clear();
+
+	void PrintResult();
+
+  private:
+	bool ParserResult();
+	bool ParserRecord(const std::string &record_str, std::vector<DNSRecord> &record);
+	std::string result_;
+	int answer_num_{0};
+	std::string status_;
+	std::string server_;
+	int query_time_{0};
+	int msg_size_{0};
+	std::string flags_;
+
+	std::vector<DNSRecord> records_query_;
+	std::vector<DNSRecord> records_answer_;
+	std::vector<DNSRecord> records_authority_;
+	std::vector<DNSRecord> records_additional_;
+};
+
+} // namespace smartdns
+#endif // _SMARTDNS_CLIENT_

+ 59 - 0
test/include/utils.h

@@ -0,0 +1,59 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef _SMARTDNS_TEST_UTILS_
+#define _SMARTDNS_TEST_UTILS_
+
+#include <functional>
+
+namespace smartdns
+{
+
+class DeferGuard
+{
+  public:
+	template <class Callable>
+
+	DeferGuard(Callable &&fn) noexcept : fn_(std::forward<Callable>(fn))
+	{
+	}
+	DeferGuard(DeferGuard &&other) noexcept
+	{
+		fn_ = std::move(other.fn_);
+		other.fn_ = nullptr;
+	}
+
+	virtual ~DeferGuard()
+	{
+		if (fn_) {
+			fn_();
+		}
+	};
+	DeferGuard(const DeferGuard &) = delete;
+	void operator=(const DeferGuard &) = delete;
+
+  private:
+	std::function<void()> fn_;
+};
+
+#define SMARTDNS_CONCAT_(a, b) a##b
+#define SMARTDNS_CONCAT(a, b) SMARTDNS_CONCAT_(a, b)
+#define Defer ::smartdns::DeferGuard SMARTDNS_CONCAT(__defer__, __LINE__) = [&]()
+
+} // namespace smartdns
+#endif // _SMARTDNS_TEST_UTILS_

+ 355 - 0
test/server.cc

@@ -0,0 +1,355 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "server.h"
+#include "include/utils.h"
+#include "util.h"
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <netinet/in.h>
+#include <poll.h>
+#include <signal.h>
+#include <string>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+#include <vector>
+#include <fstream>
+
+namespace smartdns
+{
+
+extern "C" int smartdns_main(int argc, char *argv[], int fd_notify);
+
+MockServer::MockServer() {}
+
+MockServer::~MockServer()
+{
+	Stop();
+}
+
+bool MockServer::IsRunning()
+{
+	if (fd_ > 0) {
+		return true;
+	}
+
+	return false;
+}
+
+void MockServer::Stop()
+{
+	if (run_ == true) {
+		run_ = false;
+		if (thread_.joinable()) {
+			thread_.join();
+		}
+	}
+
+	if (fd_ > 0) {
+		close(fd_);
+		fd_;
+	}
+}
+
+void MockServer::Run()
+{
+	while (run_ == true) {
+		struct pollfd fds[1];
+		fds[0].fd = fd_;
+		fds[0].events = POLLIN;
+		fds[0].revents = 0;
+		int ret = poll(fds, 1, 100);
+		if (ret == 0) {
+			continue;
+		} else if (ret < 0) {
+			sleep(1);
+			continue;
+		}
+
+		if (fds[0].revents & POLLIN) {
+			struct sockaddr_storage from;
+			socklen_t addrlen = sizeof(from);
+			unsigned char in_buff[4096];
+			int len = recvfrom(fd_, in_buff, sizeof(in_buff), 0, (struct sockaddr *)&from, &addrlen);
+			if (len < 0) {
+				continue;
+			}
+
+			char packet_buff[4096];
+			unsigned char out_buff[4096];
+			memset(packet_buff, 0, sizeof(packet_buff));
+			struct dns_packet *packet = (struct dns_packet *)packet_buff;
+			struct ServerRequestContext request;
+			memset(&request, 0, sizeof(request));
+
+			int ret = dns_decode(packet, sizeof(packet_buff), in_buff, len);
+			if (ret == 0) {
+				request.packet = packet;
+				if (packet->head.qr == DNS_QR_QUERY) {
+					struct dns_rrs *rrs = NULL;
+					int rr_count = 0;
+					int qtype = 0;
+					int qclass = 0;
+					char domain[256];
+
+					rrs = dns_get_rrs_start(packet, DNS_RRS_QD, &rr_count);
+					for (int i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) {
+						ret = dns_get_domain(rrs, domain, sizeof(domain), &qtype, &qclass);
+						if (ret == 0) {
+							request.domain = domain;
+							request.qtype = (dns_type)qtype;
+							request.qclass = qclass;
+							break;
+						}
+					}
+				}
+			}
+
+			request.from = (struct sockaddr_storage *)&from;
+			request.fromlen = addrlen;
+			request.request_data = in_buff;
+			request.request_data_len = len;
+			request.response_data = out_buff;
+			request.response_data_len = 0;
+			request.response_data_max_len = sizeof(out_buff);
+
+			auto callback_ret = callback_(&request);
+			if (callback_ret == false) {
+				unsigned char out_packet_buff[4096];
+				struct dns_packet *out_packet = (struct dns_packet *)out_packet_buff;
+				struct dns_head head;
+				memset(&head, 0, sizeof(head));
+				head.id = packet->head.id;
+				head.qr = DNS_QR_ANSWER;
+				head.opcode = DNS_OP_QUERY;
+				head.aa = 0;
+				head.rd = 1;
+				head.ra = 0;
+				head.rcode = DNS_RC_SERVFAIL;
+
+				dns_packet_init(out_packet, sizeof(out_packet_buff), &head);
+				request.response_data_len =
+					dns_encode(request.response_data, request.response_data_max_len, out_packet);
+			}
+
+			sendto(fd_, request.response_data, request.response_data_len, MSG_NOSIGNAL, (struct sockaddr *)&from,
+				   addrlen);
+		}
+	}
+}
+
+bool MockServer::GetAddr(const std::string &host, const std::string port, int type, int protocol,
+						 struct sockaddr_storage *addr, socklen_t *addrlen)
+
+{
+	struct addrinfo hints;
+	struct addrinfo *result = NULL;
+
+	memset(&hints, 0, sizeof(hints));
+	hints.ai_family = AF_UNSPEC;
+	hints.ai_socktype = type;
+	hints.ai_protocol = protocol;
+	hints.ai_flags = AI_PASSIVE;
+	if (getaddrinfo(host.c_str(), port.c_str(), &hints, &result) != 0) {
+		goto errout;
+	}
+
+	memcpy(addr, result->ai_addr, result->ai_addrlen);
+	*addrlen = result->ai_addrlen;
+	return true;
+errout:
+	if (result) {
+		freeaddrinfo(result);
+	}
+	return NULL;
+}
+
+bool MockServer::Start(const std::string &url, ServerRequest callback)
+{
+	char c_scheme[256];
+	char c_host[256];
+	int port;
+	char c_path[256];
+	int fd;
+	struct sockaddr_storage addr;
+	socklen_t addrlen;
+
+	if (callback == nullptr) {
+		return false;
+	}
+
+	if (parse_uri(url.c_str(), c_scheme, c_host, &port, c_path) != 0) {
+		return false;
+	}
+
+	std::string scheme(c_scheme);
+	std::string host(c_host);
+	std::string path(c_path);
+
+	if (scheme != "udp") {
+		return false;
+	}
+
+	if (GetAddr(host, std::to_string(port), SOCK_DGRAM, IPPROTO_UDP, &addr, &addrlen) == false) {
+		return false;
+	}
+
+	fd = socket(addr.ss_family, SOCK_DGRAM | SOCK_CLOEXEC, 0);
+	if (fd < 0) {
+		return false;
+	}
+
+	if (bind(fd, (struct sockaddr *)&addr, addrlen) != 0) {
+		close(fd);
+		return false;
+	}
+
+	run_ = true;
+	thread_ = std::thread(&MockServer::Run, this);
+	fd_ = fd;
+	callback_ = callback;
+	return true;
+}
+
+Server::Server() {}
+
+bool Server::Start(const std::string &conf, enum CONF_TYPE type)
+{
+	int fds[2];
+	std::string conf_file;
+
+	fds[0] = 0;
+	fds[1] = 0;
+	Defer
+	{
+		if (fds[0] > 0) {
+			close(fds[0]);
+		}
+
+		if (fds[0] > 0) {
+			close(fds[1]);
+		}
+	};
+
+	if (type == CONF_TYPE_STRING) {
+		char filename[128];
+		strncpy(filename, "/tmp/smartdns_conf.XXXXXX", sizeof(filename));
+		int fd = mkstemp(filename);
+		if (fd < 0) {
+			return false;
+		}
+		Defer {
+			close(fd);
+		};
+
+		std::ofstream ofs(filename);
+		if (ofs.is_open() == false) {
+			return false;
+		}
+		ofs.write(conf.data(), conf.size());
+		ofs.flush();
+		ofs.close();
+		conf_file = filename;
+		clean_conf_file_ = true;
+	} else if (type == CONF_TYPE_FILE) {
+		conf_file = conf;
+	} else {
+		return false;
+	}
+
+	if (access(conf_file.c_str(), F_OK) != 0) {
+		return false;
+	}
+
+	conf_file_ = conf_file;
+
+	if (pipe2(fds, O_CLOEXEC | O_NONBLOCK) != 0) {
+		return false;
+	}
+
+	pid_t pid = fork();
+	if (pid == 0) {
+		std::vector<std::string> args = {
+			"smartdns", "-f", "-x", "-c", conf_file, "-p", "-",
+		};
+		char *argv[args.size() + 1];
+		for (size_t i = 0; i < args.size(); i++) {
+			argv[i] = (char *)args[i].c_str();
+		}
+
+		smartdns_main(args.size(), argv, fds[1]);
+		_exit(1);
+	} else if (pid < 0) {
+		return false;
+	}
+
+	struct pollfd pfd[1];
+	pfd[0].fd = fds[0];
+	pfd[0].events = POLLIN;
+
+	int ret = poll(pfd, 1, 10000);
+	if (ret == 0) {
+		kill(pid, SIGKILL);
+		return false;
+	}
+
+	pid_ = pid;
+	return pid > 0;
+}
+
+void Server::Stop(bool graceful)
+{
+	if (pid_ > 0) {
+		if (graceful) {
+			kill(pid_, SIGTERM);
+		} else {
+			kill(pid_, SIGKILL);
+		}
+	}
+
+	waitpid(pid_, NULL, 0);
+
+	pid_ = 0;
+	if (clean_conf_file_ == true) {
+		unlink(conf_file_.c_str());
+		conf_file_.clear();
+		clean_conf_file_ = false;
+	}
+}
+
+bool Server::IsRunning()
+{
+	if (pid_ <= 0) {
+		return false;
+	}
+
+	if (waitpid(pid_, NULL, WNOHANG) == 0) {
+		return true;
+	}
+
+	return kill(pid_, 0) == 0;
+}
+
+Server::~Server()
+{
+	Stop(false);
+}
+
+} // namespace smartdns

+ 91 - 0
test/server.h

@@ -0,0 +1,91 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef _SMARTDNS_SERVER_
+#define _SMARTDNS_SERVER_
+
+#include "dns.h"
+#include <functional>
+#include <string>
+#include <unistd.h>
+#include <sys/socket.h>
+#include <thread>
+
+namespace smartdns
+{
+
+class Server
+{
+  public:
+	enum CONF_TYPE {
+		CONF_TYPE_STRING,
+		CONF_TYPE_FILE,
+	};
+	Server();
+	virtual ~Server();
+
+	bool Start(const std::string &conf, enum CONF_TYPE type = CONF_TYPE_STRING);
+	void Stop(bool graceful = true);
+	bool IsRunning();
+
+  private:
+	pid_t pid_;
+	int fd_;
+	std::string conf_file_;
+	bool clean_conf_file_{false};
+};
+
+struct ServerRequestContext {
+	std::string domain;
+	dns_type qtype;
+	int qclass;
+	struct sockaddr_storage *from;
+	socklen_t fromlen;
+	struct dns_packet *packet;
+	uint8_t *request_data;
+	int request_data_len;
+	uint8_t *response_data;
+	int response_data_max_len;
+	int response_data_len;
+};
+
+using ServerRequest = std::function<bool(struct ServerRequestContext *request)>;
+
+class MockServer
+{
+  public:
+	MockServer();
+	virtual ~MockServer();
+
+	bool Start(const std::string &url, ServerRequest callback);
+	void Stop();
+	bool IsRunning();
+
+  private:
+	void Run();
+
+	bool GetAddr(const std::string &host, const std::string port, int type, int protocol, struct sockaddr_storage *addr,
+				 socklen_t *addrlen);
+	int fd_;
+	std::thread thread_;
+	bool run_;
+	ServerRequest callback_;
+};
+
+} // namespace smartdns
+#endif // _SMARTDNS_SERVER_

+ 25 - 0
test/test.cc

@@ -0,0 +1,25 @@
+/*************************************************************************
+ *
+ * Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
+ *
+ * smartdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * smartdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "gtest/gtest.h"
+
+int main(int argc, char **argv)
+{
+	::testing::InitGoogleTest(&argc, argv);
+	return RUN_ALL_TESTS();
+}