test-cache.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. /*************************************************************************
  2. *
  3. * Copyright (C) 2018-2024 Ruilin Peng (Nick) <[email protected]>.
  4. *
  5. * smartdns is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * smartdns is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. */
  18. #include "client.h"
  19. #include "dns.h"
  20. #include "include/utils.h"
  21. #include "server.h"
  22. #include "gtest/gtest.h"
  23. #include <fcntl.h>
  24. #include <fstream>
  25. #include <sys/stat.h>
  26. #include <sys/types.h>
  27. /* clang-format off */
  28. #include "dns_cache.h"
  29. /* clang-format on */
  30. class Cache : public ::testing::Test
  31. {
  32. protected:
  33. void SetUp() override {}
  34. void TearDown() override {}
  35. };
  36. TEST_F(Cache, min)
  37. {
  38. smartdns::MockServer server_upstream;
  39. smartdns::Server server;
  40. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  41. std::string domain = request->domain;
  42. if (request->domain.length() == 0) {
  43. return smartdns::SERVER_REQUEST_ERROR;
  44. }
  45. if (request->qtype == DNS_T_A) {
  46. unsigned char addr[4] = {1, 2, 3, 4};
  47. dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  48. } else if (request->qtype == DNS_T_AAAA) {
  49. unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  50. dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  51. } else {
  52. return smartdns::SERVER_REQUEST_ERROR;
  53. }
  54. request->response_packet->head.rcode = DNS_RC_NOERROR;
  55. return smartdns::SERVER_REQUEST_OK;
  56. });
  57. server.Start(R"""(bind [::]:60053
  58. server 127.0.0.1:61053
  59. cache-size 1
  60. rr-ttl-min 1
  61. speed-check-mode none
  62. response-mode fastest-response
  63. )""");
  64. smartdns::Client client;
  65. ASSERT_TRUE(client.Query("a.com", 60053));
  66. std::cout << client.GetResult() << std::endl;
  67. ASSERT_EQ(client.GetAnswerNum(), 1);
  68. EXPECT_EQ(client.GetStatus(), "NOERROR");
  69. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  70. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 1);
  71. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  72. }
  73. TEST_F(Cache, max_reply_ttl)
  74. {
  75. smartdns::MockServer server_upstream;
  76. smartdns::Server server;
  77. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  78. std::string domain = request->domain;
  79. if (request->domain.length() == 0) {
  80. return smartdns::SERVER_REQUEST_ERROR;
  81. }
  82. if (request->qtype == DNS_T_A) {
  83. unsigned char addr[4] = {1, 2, 3, 4};
  84. dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  85. } else if (request->qtype == DNS_T_AAAA) {
  86. unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  87. dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  88. } else {
  89. return smartdns::SERVER_REQUEST_ERROR;
  90. }
  91. request->response_packet->head.rcode = DNS_RC_NOERROR;
  92. return smartdns::SERVER_REQUEST_OK;
  93. });
  94. server.Start(R"""(bind [::]:60053
  95. server 127.0.0.1:61053
  96. cache-size 1
  97. rr-ttl-min 600
  98. rr-ttl-reply-max 5
  99. speed-check-mode none
  100. response-mode fastest-response
  101. )""");
  102. smartdns::Client client;
  103. ASSERT_TRUE(client.Query("a.com", 60053));
  104. std::cout << client.GetResult() << std::endl;
  105. ASSERT_EQ(client.GetAnswerNum(), 1);
  106. EXPECT_EQ(client.GetStatus(), "NOERROR");
  107. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  108. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 5);
  109. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  110. sleep(1);
  111. ASSERT_TRUE(client.Query("a.com", 60053));
  112. std::cout << client.GetResult() << std::endl;
  113. ASSERT_EQ(client.GetAnswerNum(), 1);
  114. EXPECT_EQ(client.GetStatus(), "NOERROR");
  115. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  116. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 5);
  117. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  118. }
  119. TEST_F(Cache, max_reply_ttl_expired)
  120. {
  121. smartdns::MockServer server_upstream;
  122. smartdns::Server server;
  123. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  124. std::string domain = request->domain;
  125. if (request->domain.length() == 0) {
  126. return smartdns::SERVER_REQUEST_ERROR;
  127. }
  128. if (request->qtype == DNS_T_A) {
  129. unsigned char addr[4] = {1, 2, 3, 4};
  130. dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  131. } else if (request->qtype == DNS_T_AAAA) {
  132. unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  133. dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  134. } else {
  135. return smartdns::SERVER_REQUEST_ERROR;
  136. }
  137. request->response_packet->head.rcode = DNS_RC_NOERROR;
  138. return smartdns::SERVER_REQUEST_OK;
  139. });
  140. server.Start(R"""(bind [::]:60053
  141. server 127.0.0.1:61053
  142. cache-size 1
  143. rr-ttl-min 600
  144. rr-ttl-reply-max 6
  145. )""");
  146. smartdns::Client client;
  147. ASSERT_TRUE(client.Query("a.com", 60053));
  148. std::cout << client.GetResult() << std::endl;
  149. ASSERT_EQ(client.GetAnswerNum(), 1);
  150. EXPECT_EQ(client.GetStatus(), "NOERROR");
  151. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  152. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  153. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  154. ASSERT_TRUE(client.Query("a.com", 60053));
  155. std::cout << client.GetResult() << std::endl;
  156. ASSERT_EQ(client.GetAnswerNum(), 1);
  157. EXPECT_EQ(client.GetStatus(), "NOERROR");
  158. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  159. EXPECT_GE(client.GetAnswer()[0].GetTTL(), 5);
  160. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  161. }
  162. TEST_F(Cache, prefetch)
  163. {
  164. smartdns::MockServer server_upstream;
  165. smartdns::MockServer server_upstream1;
  166. smartdns::MockServer server_upstream2;
  167. smartdns::Server server;
  168. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  169. if (request->qtype != DNS_T_A) {
  170. return smartdns::SERVER_REQUEST_SOA;
  171. }
  172. smartdns::MockServer::AddIP(request, request->domain.c_str(), "9.10.11.12", 611);
  173. return smartdns::SERVER_REQUEST_OK;
  174. });
  175. server_upstream1.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) {
  176. if (request->qtype != DNS_T_A) {
  177. return smartdns::SERVER_REQUEST_SOA;
  178. }
  179. smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
  180. return smartdns::SERVER_REQUEST_OK;
  181. });
  182. server_upstream2.Start("udp://0.0.0.0:63053", [](struct smartdns::ServerRequestContext *request) {
  183. if (request->qtype != DNS_T_A) {
  184. return smartdns::SERVER_REQUEST_SOA;
  185. }
  186. smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8", 611);
  187. return smartdns::SERVER_REQUEST_OK;
  188. });
  189. server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 100);
  190. server.MockPing(PING_TYPE_ICMP, "5.6.7.8", 60, 110);
  191. server.MockPing(PING_TYPE_ICMP, "9.10.11.12", 60, 110);
  192. server.Start(R"""(bind [::]:60053
  193. bind [::]:60153 -group g1
  194. server 127.0.0.1:61053
  195. server 127.0.0.1:62053 -group g1 -exclude-default-group
  196. server 127.0.0.1:63053 -group g2
  197. prefetch-domain yes
  198. rr-ttl-max 2
  199. serve-expired no
  200. )""");
  201. smartdns::Client client;
  202. ASSERT_TRUE(client.Query("a.com", 60053));
  203. std::cout << client.GetResult() << std::endl;
  204. ASSERT_EQ(client.GetAnswerNum(), 1);
  205. EXPECT_EQ(client.GetStatus(), "NOERROR");
  206. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  207. ASSERT_TRUE(client.Query("a.com", 60153));
  208. std::cout << client.GetResult() << std::endl;
  209. ASSERT_EQ(client.GetAnswerNum(), 1);
  210. EXPECT_EQ(client.GetStatus(), "NOERROR");
  211. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  212. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  213. sleep(1);
  214. ASSERT_TRUE(client.Query("a.com", 60053));
  215. std::cout << client.GetResult() << std::endl;
  216. ASSERT_EQ(client.GetAnswerNum(), 2);
  217. EXPECT_EQ(client.GetStatus(), "NOERROR");
  218. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  219. sleep(1);
  220. ASSERT_TRUE(client.Query("a.com", 60153));
  221. std::cout << client.GetResult() << std::endl;
  222. ASSERT_EQ(client.GetAnswerNum(), 1);
  223. EXPECT_EQ(client.GetStatus(), "NOERROR");
  224. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  225. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  226. }
  227. TEST_F(Cache, nocache)
  228. {
  229. smartdns::MockServer server_upstream;
  230. smartdns::Server server;
  231. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  232. std::string domain = request->domain;
  233. if (request->domain.length() == 0) {
  234. return smartdns::SERVER_REQUEST_ERROR;
  235. }
  236. usleep(15000);
  237. if (request->qtype == DNS_T_A) {
  238. unsigned char addr[4] = {1, 2, 3, 4};
  239. dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  240. } else if (request->qtype == DNS_T_AAAA) {
  241. unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  242. dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  243. } else {
  244. return smartdns::SERVER_REQUEST_ERROR;
  245. }
  246. request->response_packet->head.rcode = DNS_RC_NOERROR;
  247. return smartdns::SERVER_REQUEST_OK;
  248. });
  249. server.Start(R"""(bind [::]:60053
  250. server 127.0.0.1:61053
  251. cache-size 100
  252. rr-ttl-min 600
  253. rr-ttl-reply-max 5
  254. domain-rules /a.com/ --no-cache
  255. )""");
  256. smartdns::Client client;
  257. ASSERT_TRUE(client.Query("a.com", 60053));
  258. std::cout << client.GetResult() << std::endl;
  259. ASSERT_EQ(client.GetAnswerNum(), 1);
  260. EXPECT_EQ(client.GetStatus(), "NOERROR");
  261. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  262. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  263. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  264. ASSERT_TRUE(client.Query("a.com", 60053));
  265. EXPECT_GT(client.GetQueryTime(), 10);
  266. std::cout << client.GetResult() << std::endl;
  267. ASSERT_EQ(client.GetAnswerNum(), 1);
  268. EXPECT_EQ(client.GetStatus(), "NOERROR");
  269. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  270. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  271. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  272. }
  273. TEST_F(Cache, save_file)
  274. {
  275. smartdns::MockServer server_upstream;
  276. auto cache_file = "/tmp/smartdns_cache." + smartdns::GenerateRandomString(10);
  277. std::string conf = R"""(
  278. bind [::]:60053@lo
  279. server 127.0.0.1:62053
  280. cache-persist yes
  281. dualstack-ip-selection no
  282. )""";
  283. conf += "cache-file " + cache_file;
  284. Defer
  285. {
  286. unlink(cache_file.c_str());
  287. };
  288. server_upstream.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) {
  289. if (request->qtype == DNS_T_A) {
  290. smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
  291. return smartdns::SERVER_REQUEST_OK;
  292. }
  293. return smartdns::SERVER_REQUEST_SOA;
  294. });
  295. {
  296. smartdns::Server server;
  297. server.Start(conf);
  298. smartdns::Client client;
  299. ASSERT_TRUE(client.Query("a.com", 60053));
  300. std::cout << client.GetResult() << std::endl;
  301. ASSERT_EQ(client.GetAnswerNum(), 1);
  302. EXPECT_EQ(client.GetStatus(), "NOERROR");
  303. EXPECT_LT(client.GetQueryTime(), 100);
  304. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  305. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  306. server.Stop();
  307. usleep(200 * 1000);
  308. }
  309. ASSERT_EQ(access(cache_file.c_str(), F_OK), 0);
  310. std::fstream fs(cache_file, std::ios::in);
  311. struct dns_cache_file head;
  312. memset(&head, 0, sizeof(head));
  313. fs.read((char *)&head, sizeof(head));
  314. EXPECT_EQ(head.magic, MAGIC_NUMBER);
  315. EXPECT_EQ(head.cache_number, 1);
  316. }
  317. TEST_F(Cache, corrupt_file)
  318. {
  319. smartdns::MockServer server_upstream;
  320. auto cache_file = "/tmp/smartdns_cache." + smartdns::GenerateRandomString(10);
  321. std::string conf = R"""(
  322. bind [::]:60053@lo
  323. server 127.0.0.1:62053
  324. dualstack-ip-selection no
  325. cache-persist yes
  326. )""";
  327. conf += "cache-file " + cache_file;
  328. Defer
  329. {
  330. unlink(cache_file.c_str());
  331. };
  332. server_upstream.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) {
  333. if (request->qtype == DNS_T_A) {
  334. smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
  335. return smartdns::SERVER_REQUEST_OK;
  336. }
  337. return smartdns::SERVER_REQUEST_SOA;
  338. });
  339. {
  340. smartdns::Server server;
  341. server.Start(conf);
  342. smartdns::Client client;
  343. ASSERT_TRUE(client.Query("a.com", 60053));
  344. std::cout << client.GetResult() << std::endl;
  345. ASSERT_EQ(client.GetAnswerNum(), 1);
  346. EXPECT_EQ(client.GetStatus(), "NOERROR");
  347. EXPECT_LT(client.GetQueryTime(), 100);
  348. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  349. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  350. server.Stop();
  351. usleep(200 * 1000);
  352. }
  353. ASSERT_EQ(access(cache_file.c_str(), F_OK), 0);
  354. int fd = open(cache_file.c_str(), O_RDWR);
  355. ASSERT_NE(fd, -1);
  356. srandom(time(NULL));
  357. off_t file_size = lseek(fd, 0, SEEK_END);
  358. off_t offset = random() % (file_size - 300);
  359. std::cout << "try make corrupt at " << offset << ", file size: " << file_size << std::endl;
  360. lseek(fd, offset, SEEK_SET);
  361. for (int i = 0; i < 300; i++) {
  362. unsigned char c = random() % 256;
  363. write(fd, &c, 1);
  364. }
  365. close(fd);
  366. {
  367. smartdns::Server server;
  368. server.Start(conf);
  369. smartdns::Client client;
  370. ASSERT_TRUE(client.Query("a.com", 60053));
  371. std::cout << client.GetResult() << std::endl;
  372. ASSERT_EQ(client.GetAnswerNum(), 1);
  373. EXPECT_EQ(client.GetStatus(), "NOERROR");
  374. EXPECT_LT(client.GetQueryTime(), 100);
  375. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  376. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  377. server.Stop();
  378. usleep(200 * 1000);
  379. }
  380. }
  381. TEST_F(Cache, cname)
  382. {
  383. smartdns::MockServer server_upstream;
  384. smartdns::Server server;
  385. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  386. std::string domain = request->domain;
  387. std::string cname = "cname." + domain;
  388. if (request->qtype != DNS_T_A) {
  389. return smartdns::SERVER_REQUEST_SOA;
  390. }
  391. unsigned char addr[4] = {1, 2, 3, 4};
  392. dns_add_domain(request->response_packet, domain.c_str(), DNS_T_A, DNS_C_IN);
  393. dns_add_CNAME(request->response_packet, DNS_RRS_AN, domain.c_str(), 300, cname.c_str());
  394. dns_add_A(request->response_packet, DNS_RRS_AN, cname.c_str(), 300, addr);
  395. request->response_packet->head.rcode = DNS_RC_NOERROR;
  396. return smartdns::SERVER_REQUEST_OK;
  397. });
  398. server.Start(R"""(bind [::]:60053
  399. server 127.0.0.1:61053
  400. cache-size 100
  401. )""");
  402. smartdns::Client client;
  403. ASSERT_TRUE(client.Query("a.com A", 60053));
  404. std::cout << client.GetResult() << std::endl;
  405. ASSERT_EQ(client.GetAnswerNum(), 2);
  406. EXPECT_EQ(client.GetStatus(), "NOERROR");
  407. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  408. EXPECT_GE(client.GetAnswer()[0].GetTTL(), 3);
  409. EXPECT_EQ(client.GetAnswer()[0].GetData(), "cname.a.com.");
  410. EXPECT_EQ(client.GetAnswer()[1].GetName(), "cname.a.com");
  411. EXPECT_GE(client.GetAnswer()[1].GetTTL(), 3);
  412. EXPECT_EQ(client.GetAnswer()[1].GetData(), "1.2.3.4");
  413. ASSERT_TRUE(client.Query("cname.a.com A", 60053));
  414. std::cout << client.GetResult() << std::endl;
  415. ASSERT_EQ(client.GetAnswerNum(), 1);
  416. EXPECT_EQ(client.GetStatus(), "NOERROR");
  417. EXPECT_EQ(client.GetAnswer()[0].GetName(), "cname.a.com");
  418. EXPECT_GE(client.GetAnswer()[0].GetTTL(), 590);
  419. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  420. }