test-cache.cc 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. /*************************************************************************
  2. *
  3. * Copyright (C) 2018-2023 Ruilin Peng (Nick) <[email protected]>.
  4. *
  5. * smartdns is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * smartdns is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. */
  18. #include "client.h"
  19. #include "dns.h"
  20. #include "include/utils.h"
  21. #include "server.h"
  22. #include "gtest/gtest.h"
  23. #include <fcntl.h>
  24. #include <fstream>
  25. #include <sys/stat.h>
  26. #include <sys/types.h>
  27. /* clang-format off */
  28. #include "dns_cache.h"
  29. /* clang-format on */
  30. class Cache : public ::testing::Test
  31. {
  32. protected:
  33. void SetUp() override {}
  34. void TearDown() override {}
  35. };
  36. TEST_F(Cache, min)
  37. {
  38. smartdns::MockServer server_upstream;
  39. smartdns::Server server;
  40. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  41. std::string domain = request->domain;
  42. if (request->domain.length() == 0) {
  43. return smartdns::SERVER_REQUEST_ERROR;
  44. }
  45. if (request->qtype == DNS_T_A) {
  46. unsigned char addr[4] = {1, 2, 3, 4};
  47. dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  48. } else if (request->qtype == DNS_T_AAAA) {
  49. unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  50. dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  51. } else {
  52. return smartdns::SERVER_REQUEST_ERROR;
  53. }
  54. request->response_packet->head.rcode = DNS_RC_NOERROR;
  55. return smartdns::SERVER_REQUEST_OK;
  56. });
  57. server.Start(R"""(bind [::]:60053
  58. server 127.0.0.1:61053
  59. log-num 0
  60. cache-size 1
  61. rr-ttl-min 1
  62. speed-check-mode none
  63. response-mode fastest-response
  64. log-console yes
  65. log-level debug
  66. cache-persist no)""");
  67. smartdns::Client client;
  68. ASSERT_TRUE(client.Query("a.com", 60053));
  69. std::cout << client.GetResult() << std::endl;
  70. ASSERT_EQ(client.GetAnswerNum(), 1);
  71. EXPECT_EQ(client.GetStatus(), "NOERROR");
  72. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  73. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 1);
  74. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  75. }
  76. TEST_F(Cache, max_reply_ttl)
  77. {
  78. smartdns::MockServer server_upstream;
  79. smartdns::Server server;
  80. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  81. std::string domain = request->domain;
  82. if (request->domain.length() == 0) {
  83. return smartdns::SERVER_REQUEST_ERROR;
  84. }
  85. if (request->qtype == DNS_T_A) {
  86. unsigned char addr[4] = {1, 2, 3, 4};
  87. dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  88. } else if (request->qtype == DNS_T_AAAA) {
  89. unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  90. dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  91. } else {
  92. return smartdns::SERVER_REQUEST_ERROR;
  93. }
  94. request->response_packet->head.rcode = DNS_RC_NOERROR;
  95. return smartdns::SERVER_REQUEST_OK;
  96. });
  97. server.Start(R"""(bind [::]:60053
  98. server 127.0.0.1:61053
  99. log-num 0
  100. cache-size 1
  101. rr-ttl-min 600
  102. rr-ttl-reply-max 5
  103. speed-check-mode none
  104. response-mode fastest-response
  105. log-console yes
  106. log-level debug
  107. cache-persist no)""");
  108. smartdns::Client client;
  109. ASSERT_TRUE(client.Query("a.com", 60053));
  110. std::cout << client.GetResult() << std::endl;
  111. ASSERT_EQ(client.GetAnswerNum(), 1);
  112. EXPECT_EQ(client.GetStatus(), "NOERROR");
  113. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  114. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 5);
  115. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  116. sleep(1);
  117. ASSERT_TRUE(client.Query("a.com", 60053));
  118. std::cout << client.GetResult() << std::endl;
  119. ASSERT_EQ(client.GetAnswerNum(), 1);
  120. EXPECT_EQ(client.GetStatus(), "NOERROR");
  121. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  122. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 5);
  123. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  124. }
  125. TEST_F(Cache, max_reply_ttl_expired)
  126. {
  127. smartdns::MockServer server_upstream;
  128. smartdns::Server server;
  129. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  130. std::string domain = request->domain;
  131. if (request->domain.length() == 0) {
  132. return smartdns::SERVER_REQUEST_ERROR;
  133. }
  134. if (request->qtype == DNS_T_A) {
  135. unsigned char addr[4] = {1, 2, 3, 4};
  136. dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  137. } else if (request->qtype == DNS_T_AAAA) {
  138. unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  139. dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  140. } else {
  141. return smartdns::SERVER_REQUEST_ERROR;
  142. }
  143. request->response_packet->head.rcode = DNS_RC_NOERROR;
  144. return smartdns::SERVER_REQUEST_OK;
  145. });
  146. server.Start(R"""(bind [::]:60053
  147. server 127.0.0.1:61053
  148. log-num 0
  149. cache-size 1
  150. rr-ttl-min 600
  151. rr-ttl-reply-max 6
  152. log-console yes
  153. log-level debug
  154. cache-persist no)""");
  155. smartdns::Client client;
  156. ASSERT_TRUE(client.Query("a.com", 60053));
  157. std::cout << client.GetResult() << std::endl;
  158. ASSERT_EQ(client.GetAnswerNum(), 1);
  159. EXPECT_EQ(client.GetStatus(), "NOERROR");
  160. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  161. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  162. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  163. ASSERT_TRUE(client.Query("a.com", 60053));
  164. std::cout << client.GetResult() << std::endl;
  165. ASSERT_EQ(client.GetAnswerNum(), 1);
  166. EXPECT_EQ(client.GetStatus(), "NOERROR");
  167. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  168. EXPECT_GE(client.GetAnswer()[0].GetTTL(), 5);
  169. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  170. }
  171. TEST_F(Cache, nocache)
  172. {
  173. smartdns::MockServer server_upstream;
  174. smartdns::Server server;
  175. server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
  176. std::string domain = request->domain;
  177. if (request->domain.length() == 0) {
  178. return smartdns::SERVER_REQUEST_ERROR;
  179. }
  180. usleep(15000);
  181. if (request->qtype == DNS_T_A) {
  182. unsigned char addr[4] = {1, 2, 3, 4};
  183. dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  184. } else if (request->qtype == DNS_T_AAAA) {
  185. unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  186. dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
  187. } else {
  188. return smartdns::SERVER_REQUEST_ERROR;
  189. }
  190. request->response_packet->head.rcode = DNS_RC_NOERROR;
  191. return smartdns::SERVER_REQUEST_OK;
  192. });
  193. server.Start(R"""(bind [::]:60053
  194. server 127.0.0.1:61053
  195. log-num 0
  196. cache-size 100
  197. rr-ttl-min 600
  198. rr-ttl-reply-max 5
  199. log-console yes
  200. log-level debug
  201. domain-rules /a.com/ --no-cache
  202. cache-persist no)""");
  203. smartdns::Client client;
  204. ASSERT_TRUE(client.Query("a.com", 60053));
  205. std::cout << client.GetResult() << std::endl;
  206. ASSERT_EQ(client.GetAnswerNum(), 1);
  207. EXPECT_EQ(client.GetStatus(), "NOERROR");
  208. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  209. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  210. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  211. ASSERT_TRUE(client.Query("a.com", 60053));
  212. EXPECT_GT(client.GetQueryTime(), 10);
  213. std::cout << client.GetResult() << std::endl;
  214. ASSERT_EQ(client.GetAnswerNum(), 1);
  215. EXPECT_EQ(client.GetStatus(), "NOERROR");
  216. EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
  217. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  218. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  219. }
  220. TEST_F(Cache, save_file)
  221. {
  222. smartdns::MockServer server_upstream;
  223. auto cache_file = "/tmp/smartdns_cache." + smartdns::GenerateRandomString(10);
  224. std::string conf = R"""(
  225. bind [::]:60053@lo
  226. server 127.0.0.1:62053
  227. log-num 0
  228. log-console yes
  229. log-level debug
  230. cache-persist yes
  231. dualstack-ip-selection no
  232. )""";
  233. conf += "cache-file " + cache_file;
  234. Defer
  235. {
  236. unlink(cache_file.c_str());
  237. };
  238. server_upstream.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) {
  239. if (request->qtype == DNS_T_A) {
  240. smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
  241. return smartdns::SERVER_REQUEST_OK;
  242. }
  243. return smartdns::SERVER_REQUEST_SOA;
  244. });
  245. {
  246. smartdns::Server server;
  247. server.Start(conf);
  248. smartdns::Client client;
  249. ASSERT_TRUE(client.Query("a.com", 60053));
  250. std::cout << client.GetResult() << std::endl;
  251. ASSERT_EQ(client.GetAnswerNum(), 1);
  252. EXPECT_EQ(client.GetStatus(), "NOERROR");
  253. EXPECT_LT(client.GetQueryTime(), 100);
  254. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  255. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  256. server.Stop();
  257. usleep(200 * 1000);
  258. }
  259. ASSERT_EQ(access(cache_file.c_str(), F_OK), 0);
  260. std::fstream fs(cache_file, std::ios::in);
  261. struct dns_cache_file head;
  262. memset(&head, 0, sizeof(head));
  263. fs.read((char *)&head, sizeof(head));
  264. EXPECT_EQ(head.magic, MAGIC_NUMBER);
  265. EXPECT_EQ(head.cache_number, 1);
  266. }
  267. TEST_F(Cache, corrupt_file)
  268. {
  269. smartdns::MockServer server_upstream;
  270. auto cache_file = "/tmp/smartdns_cache." + smartdns::GenerateRandomString(10);
  271. std::string conf = R"""(
  272. bind [::]:60053@lo
  273. server 127.0.0.1:62053
  274. log-num 0
  275. log-console yes
  276. log-level debug
  277. dualstack-ip-selection no
  278. cache-persist yes
  279. )""";
  280. conf += "cache-file " + cache_file;
  281. Defer
  282. {
  283. unlink(cache_file.c_str());
  284. };
  285. server_upstream.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) {
  286. if (request->qtype == DNS_T_A) {
  287. smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
  288. return smartdns::SERVER_REQUEST_OK;
  289. }
  290. return smartdns::SERVER_REQUEST_SOA;
  291. });
  292. {
  293. smartdns::Server server;
  294. server.Start(conf);
  295. smartdns::Client client;
  296. ASSERT_TRUE(client.Query("a.com", 60053));
  297. std::cout << client.GetResult() << std::endl;
  298. ASSERT_EQ(client.GetAnswerNum(), 1);
  299. EXPECT_EQ(client.GetStatus(), "NOERROR");
  300. EXPECT_LT(client.GetQueryTime(), 100);
  301. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  302. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  303. server.Stop();
  304. usleep(200 * 1000);
  305. }
  306. ASSERT_EQ(access(cache_file.c_str(), F_OK), 0);
  307. int fd = open(cache_file.c_str(), O_RDWR);
  308. ASSERT_NE(fd, -1);
  309. srandom(time(NULL));
  310. off_t file_size = lseek(fd, 0, SEEK_END);
  311. off_t offset = random() % (file_size - 300);
  312. std::cout << "try make corrupt at " << offset << ", file size: " << file_size << std::endl;
  313. lseek(fd, offset, SEEK_SET);
  314. for (int i = 0; i < 300; i++) {
  315. unsigned char c = random() % 256;
  316. write(fd, &c, 1);
  317. }
  318. close(fd);
  319. {
  320. smartdns::Server server;
  321. server.Start(conf);
  322. smartdns::Client client;
  323. ASSERT_TRUE(client.Query("a.com", 60053));
  324. std::cout << client.GetResult() << std::endl;
  325. ASSERT_EQ(client.GetAnswerNum(), 1);
  326. EXPECT_EQ(client.GetStatus(), "NOERROR");
  327. EXPECT_LT(client.GetQueryTime(), 100);
  328. EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
  329. EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
  330. server.Stop();
  331. usleep(200 * 1000);
  332. }
  333. }