test-cache.cc 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. /*************************************************************************
  2. *
  3. * Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
  4. *
  5. * smartdns is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * smartdns is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. */
  18. #include "client.h"
  19. #include "dns.h"
  20. #include "include/utils.h"
  21. #include "server.h"
  22. #include "gtest/gtest.h"
  23. #include <fcntl.h>
  24. #include <fstream>
  25. #include <sys/stat.h>
  26. #include <sys/types.h>
  27. /* clang-format off */
  28. #include "dns_cache.h"
  29. /* clang-format on */
  30. class Cache : public ::testing::Test
  31. {
  32. protected:
  33. void SetUp() override {}
  34. void TearDown() override {}
  35. };
  36. TEST_F(Cache, min)
  37. {
  38. smartdns::MockServer server_upstream;
  39. smartdns::Server server;
  40. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  41. std::string domain = request->domain;
  42. if (request->domain.length() == 0) {
  43. return smartdns::SERVER_REQUEST_ERROR;
  44. }
  45. if (request->qtype == DNS_T_A) {
  46. unsigned char addr[4] = {1, 2, 3, 4};
  47. dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  48. } else if (request->qtype == DNS_T_AAAA) {
  49. unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  50. dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  51. } else {
  52. return smartdns::SERVER_REQUEST_ERROR;
  53. }
  54. request->response_packet->head.rcode = DNS_RC_NOERROR;
  55. return smartdns::SERVER_REQUEST_OK;
  56. });
  57. server.Start(R"""(bind [::]:60053
  58. server 127.0.0.1:61053
  59. log-num 0
  60. cache-size 1
  61. rr-ttl-min 1
  62. speed-check-mode none
  63. response-mode fastest-response
  64. log-console yes
  65. log-level debug
  66. cache-persist no)""");
  67. smartdns::Client client;
  68. ASSERT_TRUE(client.Query("a.com", 60053));
  69. std::cout << client.GetResult() << std::endl;
  70. ASSERT_EQ(client.GetAnswerNum(), 1);
  71. EXPECT_EQ(client.GetStatus(), "NOERROR");
  72. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  73. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 1);
  74. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  75. }
  76. TEST_F(Cache, max_reply_ttl)
  77. {
  78. smartdns::MockServer server_upstream;
  79. smartdns::Server server;
  80. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  81. std::string domain = request->domain;
  82. if (request->domain.length() == 0) {
  83. return smartdns::SERVER_REQUEST_ERROR;
  84. }
  85. if (request->qtype == DNS_T_A) {
  86. unsigned char addr[4] = {1, 2, 3, 4};
  87. dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  88. } else if (request->qtype == DNS_T_AAAA) {
  89. unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  90. dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  91. } else {
  92. return smartdns::SERVER_REQUEST_ERROR;
  93. }
  94. request->response_packet->head.rcode = DNS_RC_NOERROR;
  95. return smartdns::SERVER_REQUEST_OK;
  96. });
  97. server.Start(R"""(bind [::]:60053
  98. server 127.0.0.1:61053
  99. log-num 0
  100. cache-size 1
  101. rr-ttl-min 600
  102. rr-ttl-reply-max 5
  103. speed-check-mode none
  104. response-mode fastest-response
  105. log-console yes
  106. log-level debug
  107. cache-persist no)""");
  108. smartdns::Client client;
  109. ASSERT_TRUE(client.Query("a.com", 60053));
  110. std::cout << client.GetResult() << std::endl;
  111. ASSERT_EQ(client.GetAnswerNum(), 1);
  112. EXPECT_EQ(client.GetStatus(), "NOERROR");
  113. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  114. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 5);
  115. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  116. }
  117. TEST_F(Cache, max_reply_ttl_expired)
  118. {
  119. smartdns::MockServer server_upstream;
  120. smartdns::Server server;
  121. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  122. std::string domain = request->domain;
  123. if (request->domain.length() == 0) {
  124. return smartdns::SERVER_REQUEST_ERROR;
  125. }
  126. if (request->qtype == DNS_T_A) {
  127. unsigned char addr[4] = {1, 2, 3, 4};
  128. dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  129. } else if (request->qtype == DNS_T_AAAA) {
  130. unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  131. dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  132. } else {
  133. return smartdns::SERVER_REQUEST_ERROR;
  134. }
  135. request->response_packet->head.rcode = DNS_RC_NOERROR;
  136. return smartdns::SERVER_REQUEST_OK;
  137. });
  138. server.Start(R"""(bind [::]:60053
  139. server 127.0.0.1:61053
  140. log-num 0
  141. cache-size 1
  142. rr-ttl-min 600
  143. rr-ttl-reply-max 6
  144. log-console yes
  145. log-level debug
  146. cache-persist no)""");
  147. smartdns::Client client;
  148. ASSERT_TRUE(client.Query("a.com", 60053));
  149. std::cout << client.GetResult() << std::endl;
  150. ASSERT_EQ(client.GetAnswerNum(), 1);
  151. EXPECT_EQ(client.GetStatus(), "NOERROR");
  152. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  153. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  154. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  155. ASSERT_TRUE(client.Query("a.com", 60053));
  156. std::cout << client.GetResult() << std::endl;
  157. ASSERT_EQ(client.GetAnswerNum(), 1);
  158. EXPECT_EQ(client.GetStatus(), "NOERROR");
  159. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  160. EXPECT_GE(client.GetAnswer()[0].GetTTL(), 5);
  161. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  162. }
  163. TEST_F(Cache, nocache)
  164. {
  165. smartdns::MockServer server_upstream;
  166. smartdns::Server server;
  167. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  168. std::string domain = request->domain;
  169. if (request->domain.length() == 0) {
  170. return smartdns::SERVER_REQUEST_ERROR;
  171. }
  172. usleep(15000);
  173. if (request->qtype == DNS_T_A) {
  174. unsigned char addr[4] = {1, 2, 3, 4};
  175. dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  176. } else if (request->qtype == DNS_T_AAAA) {
  177. unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  178. dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  179. } else {
  180. return smartdns::SERVER_REQUEST_ERROR;
  181. }
  182. request->response_packet->head.rcode = DNS_RC_NOERROR;
  183. return smartdns::SERVER_REQUEST_OK;
  184. });
  185. server.Start(R"""(bind [::]:60053
  186. server 127.0.0.1:61053
  187. log-num 0
  188. cache-size 100
  189. rr-ttl-min 600
  190. rr-ttl-reply-max 5
  191. log-console yes
  192. log-level debug
  193. domain-rules /a.com/ --no-cache
  194. cache-persist no)""");
  195. smartdns::Client client;
  196. ASSERT_TRUE(client.Query("a.com", 60053));
  197. std::cout << client.GetResult() << std::endl;
  198. ASSERT_EQ(client.GetAnswerNum(), 1);
  199. EXPECT_EQ(client.GetStatus(), "NOERROR");
  200. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  201. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  202. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  203. ASSERT_TRUE(client.Query("a.com", 60053));
  204. EXPECT_GT(client.GetQueryTime(), 10);
  205. std::cout << client.GetResult() << std::endl;
  206. ASSERT_EQ(client.GetAnswerNum(), 1);
  207. EXPECT_EQ(client.GetStatus(), "NOERROR");
  208. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  209. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  210. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  211. }
  212. TEST_F(Cache, save_file)
  213. {
  214. smartdns::MockServer server_upstream;
  215. auto cache_file = "/tmp/smartdns_cache." + smartdns::GenerateRandomString(10);
  216. std::string conf = R"""(
  217. bind [::]:60053@lo
  218. server 127.0.0.1:62053
  219. log-num 0
  220. log-console yes
  221. log-level debug
  222. cache-persist yes
  223. dualstack-ip-selection no
  224. )""";
  225. conf += "cache-file " + cache_file;
  226. Defer
  227. {
  228. unlink(cache_file.c_str());
  229. };
  230. server_upstream.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) {
  231. if (request->qtype == DNS_T_A) {
  232. smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
  233. return smartdns::SERVER_REQUEST_OK;
  234. }
  235. return smartdns::SERVER_REQUEST_SOA;
  236. });
  237. {
  238. smartdns::Server server;
  239. server.Start(conf);
  240. smartdns::Client client;
  241. ASSERT_TRUE(client.Query("a.com", 60053));
  242. std::cout << client.GetResult() << std::endl;
  243. ASSERT_EQ(client.GetAnswerNum(), 1);
  244. EXPECT_EQ(client.GetStatus(), "NOERROR");
  245. EXPECT_LT(client.GetQueryTime(), 100);
  246. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  247. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  248. server.Stop();
  249. usleep(200 * 1000);
  250. }
  251. ASSERT_EQ(access(cache_file.c_str(), F_OK), 0);
  252. std::fstream fs(cache_file, std::ios::in);
  253. struct dns_cache_file head;
  254. memset(&head, 0, sizeof(head));
  255. fs.read((char *)&head, sizeof(head));
  256. EXPECT_EQ(head.magic, MAGIC_NUMBER);
  257. EXPECT_EQ(head.cache_number, 1);
  258. }
  259. TEST_F(Cache, corrupt_file)
  260. {
  261. smartdns::MockServer server_upstream;
  262. auto cache_file = "/tmp/smartdns_cache." + smartdns::GenerateRandomString(10);
  263. std::string conf = R"""(
  264. bind [::]:60053@lo
  265. server 127.0.0.1:62053
  266. log-num 0
  267. log-console yes
  268. log-level debug
  269. dualstack-ip-selection no
  270. cache-persist yes
  271. )""";
  272. conf += "cache-file " + cache_file;
  273. Defer
  274. {
  275. unlink(cache_file.c_str());
  276. };
  277. server_upstream.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) {
  278. if (request->qtype == DNS_T_A) {
  279. smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
  280. return smartdns::SERVER_REQUEST_OK;
  281. }
  282. return smartdns::SERVER_REQUEST_SOA;
  283. });
  284. {
  285. smartdns::Server server;
  286. server.Start(conf);
  287. smartdns::Client client;
  288. ASSERT_TRUE(client.Query("a.com", 60053));
  289. std::cout << client.GetResult() << std::endl;
  290. ASSERT_EQ(client.GetAnswerNum(), 1);
  291. EXPECT_EQ(client.GetStatus(), "NOERROR");
  292. EXPECT_LT(client.GetQueryTime(), 100);
  293. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  294. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  295. server.Stop();
  296. usleep(200 * 1000);
  297. }
  298. ASSERT_EQ(access(cache_file.c_str(), F_OK), 0);
  299. int fd = open(cache_file.c_str(), O_RDWR);
  300. ASSERT_NE(fd, -1);
  301. srandom(time(NULL));
  302. off_t file_size = lseek(fd, 0, SEEK_END);
  303. off_t offset = random() % (file_size - 300);
  304. std::cout << "try make corrupt at " << offset << ", file size: " << file_size << std::endl;
  305. lseek(fd, offset, SEEK_SET);
  306. for (int i = 0; i < 300; i++) {
  307. unsigned char c = random() % 256;
  308. write(fd, &c, 1);
  309. }
  310. close(fd);
  311. {
  312. smartdns::Server server;
  313. server.Start(conf);
  314. smartdns::Client client;
  315. ASSERT_TRUE(client.Query("a.com", 60053));
  316. std::cout << client.GetResult() << std::endl;
  317. ASSERT_EQ(client.GetAnswerNum(), 1);
  318. EXPECT_EQ(client.GetStatus(), "NOERROR");
  319. EXPECT_LT(client.GetQueryTime(), 100);
  320. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  321. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  322. server.Stop();
  323. usleep(200 * 1000);
  324. }
  325. }