test-stress.cc 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
#include "server.h"
#include "smartdns/dns.h"
#include "gtest/gtest.h"

#include <atomic>
#include <cerrno>
#include <chrono>
#include <climits>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <random>
#include <string>
#include <thread>
#include <vector>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>
  14. // Helper function to get environment variable with default value
  15. int get_env_int(const char* name, int default_value) {
  16. const char* value = std::getenv(name);
  17. if (value) {
  18. return std::atoi(value);
  19. }
  20. return default_value;
  21. }
  22. // Simple UDP DNS query function
  23. bool udp_dns_query(const std::string& domain, int port) {
  24. int sock = socket(AF_INET, SOCK_DGRAM, 0);
  25. if (sock < 0) return false;
  26. struct sockaddr_in addr;
  27. memset(&addr, 0, sizeof(addr));
  28. addr.sin_family = AF_INET;
  29. addr.sin_port = htons(port);
  30. inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr);
  31. // Build simple DNS query for A record
  32. unsigned char query[256];
  33. memset(query, 0, sizeof(query));
  34. // Random ID
  35. query[0] = rand() % 256;
  36. query[1] = rand() % 256;
  37. // Flags: recursion desired
  38. query[2] = 0x01;
  39. query[3] = 0x00;
  40. // QDCOUNT = 1
  41. query[4] = 0x00;
  42. query[5] = 0x01;
  43. // Encode domain name
  44. int pos = 12;
  45. size_t start = 0;
  46. size_t dot_pos = domain.find('.');
  47. while (dot_pos != std::string::npos) {
  48. std::string label = domain.substr(start, dot_pos - start);
  49. query[pos++] = label.size();
  50. memcpy(&query[pos], label.c_str(), label.size());
  51. pos += label.size();
  52. start = dot_pos + 1;
  53. dot_pos = domain.find('.', start);
  54. }
  55. std::string label = domain.substr(start);
  56. query[pos++] = label.size();
  57. memcpy(&query[pos], label.c_str(), label.size());
  58. pos += label.size();
  59. query[pos++] = 0; // null terminator
  60. // QTYPE: A (1)
  61. query[pos++] = 0x00;
  62. query[pos++] = 0x01;
  63. // QCLASS: IN (1)
  64. query[pos++] = 0x00;
  65. query[pos++] = 0x01;
  66. // Send query
  67. if (sendto(sock, query, pos, 0, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
  68. close(sock);
  69. return false;
  70. }
  71. // Receive response
  72. unsigned char response[512];
  73. socklen_t addr_len = sizeof(addr);
  74. int recv_len = recvfrom(sock, response, sizeof(response), 0, (struct sockaddr*)&addr, &addr_len);
  75. close(sock);
  76. if (recv_len < 12) return false;
  77. // Check RCODE (last 4 bits of flags)
  78. if ((response[3] & 0x0F) == 0) return true; // NOERROR
  79. return false;
  80. }
  81. // Protocol stress test configuration
  82. struct ProtocolConfig {
  83. std::string name;
  84. std::string bind_config;
  85. std::string server_config;
  86. std::string upstream_bind_config;
  87. };
  88. class Stress : public ::testing::TestWithParam<ProtocolConfig> {
  89. protected:
  90. void SetUp() override {
  91. // Common setup if needed
  92. }
  93. void TearDown() override {
  94. // Common cleanup if needed
  95. }
  96. };
  97. // Define protocol configurations
  98. const ProtocolConfig protocols[] = {
  99. {
  100. "UDP",
  101. "bind [::]:61053",
  102. "server udp://127.0.0.1:60053",
  103. "bind [::]:60053"
  104. },
  105. {
  106. "TCP",
  107. "bind [::]:61053",
  108. "server tcp://127.0.0.1:60053",
  109. "bind-tcp [::]:60053"
  110. },
  111. {
  112. "TLS",
  113. "bind [::]:61053",
  114. "server tls://127.0.0.1:60053 -no-check-certificate",
  115. "bind-tls [::]:60053"
  116. },
  117. {
  118. "HTTP2",
  119. "bind [::]:61053",
  120. "server https://127.0.0.1:60053/dns-query -no-check-certificate -alpn h2",
  121. "bind-https [::]:60053 -alpn h2"
  122. },
  123. {
  124. "HTTP1_1",
  125. "bind [::]:61053",
  126. "server https://127.0.0.1:60053/dns-query -no-check-certificate -alpn http/1.1",
  127. "bind-https [::]:60053 -alpn http/1.1"
  128. }
  129. };
  130. // Test stress for each protocol: 100 clients, each making 100 queries
  131. TEST_P(Stress, Query) {
  132. const auto& config = GetParam();
  133. smartdns::Server upstream_server;
  134. smartdns::Server main_server;
  135. // Start upstream server (second layer) that returns fixed IP and mocks ping
  136. upstream_server.Start(config.upstream_bind_config + R"""(
  137. address /test.com/192.168.1.100
  138. address /example.com/192.168.1.101
  139. address /domain.com/192.168.1.102
  140. )""");
  141. // Mock ping responses for the IPs
  142. main_server.MockPing(PING_TYPE_ICMP, "192.168.1.100", 60, 10);
  143. main_server.MockPing(PING_TYPE_ICMP, "192.168.1.101", 60, 5);
  144. main_server.MockPing(PING_TYPE_ICMP, "192.168.1.102", 60, 20);
  145. // Start main server that forwards to upstream via specified protocol
  146. main_server.Start(config.bind_config + "\n" + config.server_config + R"""(
  147. cache-size 0
  148. speed-check-mode ping
  149. )""");
  150. // Wait for servers to be ready
  151. std::this_thread::sleep_for(std::chrono::milliseconds(500));
  152. std::vector<std::thread> client_threads;
  153. std::atomic<int> total_queries{0};
  154. std::atomic<int> success_count{0};
  155. std::atomic<int> failure_count{0};
  156. std::atomic<bool> stop_all_tasks{false}; // Flag to control all tasks exit
  157. const int num_clients = get_env_int("SMARTDNS_STRESS_CLIENTS", 1);
  158. const int queries_per_client = get_env_int("SMARTDNS_STRESS_QUERIES", 200);
  159. auto start_time = std::chrono::steady_clock::now();
  160. // Launch 100 client threads, each making 100 queries
  161. for (int client_id = 0; client_id < num_clients; client_id++) {
  162. client_threads.emplace_back([client_id, &total_queries, &success_count, &failure_count, &stop_all_tasks, queries_per_client]() {
  163. for (int query_id = 0; query_id < queries_per_client; query_id++) {
  164. // Check if stop flag is set, terminate all tasks
  165. if (stop_all_tasks.load()) {
  166. return;
  167. }
  168. std::string domain;
  169. // Rotate through different domains to test various responses
  170. switch (query_id % 3) {
  171. case 0:
  172. domain = "test.com";
  173. break;
  174. case 1:
  175. domain = "example.com";
  176. break;
  177. case 2:
  178. domain = "domain.com";
  179. break;
  180. }
  181. total_queries++;
  182. if (udp_dns_query(domain, 61053)) {
  183. success_count++;
  184. } else {
  185. failure_count++;
  186. stop_all_tasks.store(true); // Set flag to stop all tasks
  187. return;
  188. }
  189. }
  190. });
  191. }
  192. // Wait for all client threads to complete
  193. for (auto& t : client_threads) {
  194. t.join();
  195. }
  196. auto end_time = std::chrono::steady_clock::now();
  197. auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
  198. int expected_total = num_clients * queries_per_client;
  199. double qps = (expected_total * 1000.0) / duration.count();
  200. std::cout << config.name << " Stress Test Results:" << std::endl;
  201. std::cout << " Total Queries: " << total_queries.load() << " (expected: " << expected_total << ")" << std::endl;
  202. std::cout << " Success: " << success_count.load() << std::endl;
  203. std::cout << " Failure: " << failure_count.load() << std::endl;
  204. std::cout << " Duration: " << duration.count() << "ms" << std::endl;
  205. std::cout << " QPS: " << qps << std::endl;
  206. double success_rate = total_queries.load() > 0 ? (success_count.load() * 100.0 / total_queries.load()) : 0.0;
  207. std::cout << " Success Rate: " << success_rate << "%" << std::endl;
  208. // Assertions
  209. EXPECT_FALSE(stop_all_tasks.load()); // No failures should occur, all tasks should complete
  210. EXPECT_EQ(total_queries.load(), expected_total);
  211. EXPECT_EQ(success_count.load(), expected_total);
  212. EXPECT_EQ(failure_count.load(), 0);
  213. }
  214. // Instantiate the test for each protocol
  215. INSTANTIATE_TEST_SUITE_P(, Stress,
  216. ::testing::ValuesIn(protocols),
  217. [](const ::testing::TestParamInfo<ProtocolConfig>& info) {
  218. return info.param.name;
  219. });
  220. // filter to run specific tests
  221. // ./test.bin --gtest_filter="Stress.Query/UDP"
  222. // ./test.bin --gtest_filter="Stress.Query/TCP"