test_provider_callback.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. # coding=utf-8
  2. """
  3. Unit tests for CallbackProvider
  4. @author: GitHub Copilot
  5. """
  6. import os
  7. import ssl
  8. import logging
  9. import random
  10. from time import sleep
  11. from base_test import BaseProviderTestCase, unittest, patch
  12. from ddns.provider.callback import CallbackProvider
  13. class TestCallbackProvider(BaseProviderTestCase):
  14. """Test cases for CallbackProvider"""
  15. def setUp(self):
  16. """Set up test fixtures"""
  17. super(TestCallbackProvider, self).setUp()
  18. self.auth_id = "https://example.com/callback?domain=__DOMAIN__&ip=__IP__"
  19. self.auth_token = "" # Use empty string instead of None for auth_token
  20. def test_init_with_basic_config(self):
  21. """Test CallbackProvider initialization with basic configuration"""
  22. provider = CallbackProvider(self.auth_id, self.auth_token)
  23. self.assertEqual(provider.auth_id, self.auth_id)
  24. self.assertEqual(provider.auth_token, self.auth_token)
  25. self.assertFalse(provider.decode_response)
  26. def test_init_with_token_config(self):
  27. """Test CallbackProvider initialization with token configuration"""
  28. auth_token = '{"api_key": "__DOMAIN__", "value": "__IP__"}'
  29. provider = CallbackProvider(self.auth_id, auth_token)
  30. self.assertEqual(provider.auth_token, auth_token)
  31. def test_validate_success(self):
  32. """Test _validate method with valid configuration"""
  33. provider = CallbackProvider(self.auth_id, self.auth_token)
  34. # Should not raise any exception since we have a valid auth_id
  35. provider._validate()
  36. def test_validate_failure_no_id(self):
  37. """Test _validate method with missing id"""
  38. # _validate is called in __init__, so we need to test it directly
  39. with self.assertRaises(ValueError) as cm:
  40. CallbackProvider(None, self.auth_token) # type: ignore
  41. self.assertIn("id must be configured", str(cm.exception))
  42. def test_validate_failure_empty_id(self):
  43. """Test _validate method with empty id"""
  44. # _validate is called in __init__, so we need to test it directly
  45. with self.assertRaises(ValueError) as cm:
  46. CallbackProvider("", self.auth_token)
  47. self.assertIn("id must be configured", str(cm.exception))
  48. def test_replace_vars_basic(self):
  49. """Test _replace_vars method with basic replacements"""
  50. provider = CallbackProvider(self.auth_id, self.auth_token)
  51. test_str = "Hello __NAME__, your IP is __IP__"
  52. mapping = {"__NAME__": "World", "__IP__": "192.168.1.1"}
  53. result = provider._replace_vars(test_str, mapping)
  54. expected = "Hello World, your IP is 192.168.1.1"
  55. self.assertEqual(result, expected)
  56. def test_replace_vars_no_matches(self):
  57. """Test _replace_vars method with no matching variables"""
  58. provider = CallbackProvider(self.auth_id, self.auth_token)
  59. test_str = "No variables here"
  60. mapping = {"__NAME__": "World"}
  61. result = provider._replace_vars(test_str, mapping)
  62. self.assertEqual(result, test_str)
  63. def test_replace_vars_partial_matches(self):
  64. """Test _replace_vars method with partial matches"""
  65. provider = CallbackProvider(self.auth_id, self.auth_token)
  66. test_str = "__DOMAIN__ and __UNKNOWN__ and __IP__"
  67. mapping = {"__DOMAIN__": "example.com", "__IP__": "1.2.3.4"}
  68. result = provider._replace_vars(test_str, mapping)
  69. expected = "example.com and __UNKNOWN__ and 1.2.3.4"
  70. self.assertEqual(result, expected)
  71. def test_replace_vars_empty_string(self):
  72. """Test _replace_vars method with empty string"""
  73. provider = CallbackProvider(self.auth_id, self.auth_token)
  74. result = provider._replace_vars("", {"__TEST__": "value"})
  75. self.assertEqual(result, "")
  76. def test_replace_vars_empty_mapping(self):
  77. """Test _replace_vars method with empty mapping"""
  78. provider = CallbackProvider(self.auth_id, self.auth_token)
  79. test_str = "__DOMAIN__ test"
  80. result = provider._replace_vars(test_str, {})
  81. self.assertEqual(result, test_str)
  82. def test_replace_vars_none_values(self):
  83. """Test _replace_vars method with None values (should convert to string)"""
  84. provider = CallbackProvider(self.auth_id, self.auth_token)
  85. test_str = "TTL: __TTL__, Line: __LINE__"
  86. mapping = {"__TTL__": None, "__LINE__": None}
  87. result = provider._replace_vars(test_str, mapping)
  88. expected = "TTL: None, Line: None"
  89. self.assertEqual(result, expected)
  90. def test_replace_vars_numeric_values(self):
  91. """Test _replace_vars method with numeric values (should convert to string)"""
  92. provider = CallbackProvider(self.auth_id, self.auth_token)
  93. test_str = "Port: __PORT__, TTL: __TTL__"
  94. mapping = {"__PORT__": 8080, "__TTL__": 300}
  95. result = provider._replace_vars(test_str, mapping)
  96. expected = "Port: 8080, TTL: 300"
  97. self.assertEqual(result, expected)
  98. @patch("ddns.provider.callback.time")
  99. @patch.object(CallbackProvider, "_http")
  100. def test_set_record_get_method(self, mock_http, mock_time):
  101. """Test set_record method using GET method (no token)"""
  102. mock_time.return_value = 1634567890.123
  103. mock_http.return_value = "Success"
  104. provider = CallbackProvider(self.auth_id, None) # type: ignore
  105. result = provider.set_record("example.com", "192.168.1.1", "A", 300, "default")
  106. # Verify the result
  107. self.assertTrue(result)
  108. # Verify _http was called with correct parameters
  109. mock_http.assert_called_once()
  110. args, kwargs = mock_http.call_args
  111. self.assertEqual(args[0], "GET") # method # Check that URL contains replaced variables
  112. url = args[1]
  113. self.assertIn("example.com", url)
  114. self.assertIn("192.168.1.1", url)
  115. @patch("ddns.provider.callback.time")
  116. @patch.object(CallbackProvider, "_http")
  117. def test_set_record_post_method_dict_token(self, mock_http, mock_time):
  118. """Test set_record method using POST method with dict token"""
  119. mock_time.return_value = 1634567890.123
  120. mock_http.return_value = "Success"
  121. auth_token = {"api_key": "test_key", "domain": "__DOMAIN__", "ip": "__IP__"}
  122. provider = CallbackProvider(self.auth_id, auth_token) # type: ignore
  123. result = provider.set_record("example.com", "192.168.1.1", "A", 300, "default")
  124. # Verify the result
  125. self.assertTrue(result) # Verify _http was called with correct parameters
  126. mock_http.assert_called_once()
  127. args, kwargs = mock_http.call_args
  128. self.assertEqual(args[0], "POST") # method
  129. # URL should be replaced with actual values even for POST
  130. url = args[1]
  131. self.assertIn("example.com", url)
  132. self.assertIn("192.168.1.1", url)
  133. # Check params were properly replaced
  134. params = kwargs["body"]
  135. self.assertEqual(params["api_key"], "test_key")
  136. self.assertEqual(params["domain"], "example.com")
  137. self.assertEqual(params["ip"], "192.168.1.1")
  138. @patch("ddns.provider.callback.time")
  139. @patch.object(CallbackProvider, "_http")
  140. def test_set_record_post_method_json_token(self, mock_http, mock_time):
  141. """Test set_record method using POST method with JSON string token"""
  142. mock_time.return_value = 1634567890.123
  143. mock_http.return_value = "Success"
  144. auth_token = '{"api_key": "test_key", "domain": "__DOMAIN__", "ip": "__IP__"}'
  145. provider = CallbackProvider(self.auth_id, auth_token)
  146. result = provider.set_record("example.com", "192.168.1.1", "A", 300, "default")
  147. # Verify the result
  148. self.assertTrue(result) # Verify _http was called with correct parameters
  149. mock_http.assert_called_once()
  150. args, kwargs = mock_http.call_args
  151. self.assertEqual(args[0], "POST") # method
  152. # URL should be replaced with actual values even for POST
  153. url = args[1]
  154. self.assertIn("example.com", url)
  155. self.assertIn("192.168.1.1", url)
  156. # Check params were properly replaced
  157. params = kwargs["body"]
  158. self.assertEqual(params["api_key"], "test_key")
  159. self.assertEqual(params["domain"], "example.com")
  160. self.assertEqual(params["ip"], "192.168.1.1")
  161. @patch("ddns.provider.callback.time")
  162. @patch.object(CallbackProvider, "_http")
  163. def test_set_record_post_method_mixed_types(self, mock_http, mock_time):
  164. """Test set_record method with mixed type values in POST parameters"""
  165. mock_time.return_value = 1634567890.123
  166. mock_http.return_value = "Success"
  167. auth_token = {"api_key": 12345, "domain": "__DOMAIN__", "timeout": 30, "enabled": True}
  168. provider = CallbackProvider(self.auth_id, auth_token) # type: ignore
  169. result = provider.set_record("example.com", "192.168.1.1")
  170. # Verify the result
  171. self.assertTrue(result)
  172. # Verify _http was called with correct parameters
  173. mock_http.assert_called_once()
  174. args, kwargs = mock_http.call_args
  175. self.assertEqual(args[0], "POST") # method
  176. # Check that non-string values were not processed, but string values were replaced
  177. params = kwargs["body"]
  178. self.assertEqual(params["api_key"], 12345) # unchanged (not a string)
  179. self.assertEqual(params["domain"], "example.com") # replaced (was a string)
  180. self.assertEqual(params["timeout"], 30) # unchanged (not a string)
  181. self.assertEqual(params["enabled"], True) # unchanged (not a string)
  182. @patch("ddns.provider.callback.time")
  183. @patch.object(CallbackProvider, "_http")
  184. def test_set_record_http_failure(self, mock_http, mock_time):
  185. """Test set_record method when HTTP request fails"""
  186. mock_time.return_value = 1634567890.123
  187. mock_http.return_value = None # Simulate failure
  188. provider = CallbackProvider(self.auth_id, None) # type: ignore
  189. result = provider.set_record("example.com", "192.168.1.1")
  190. # Verify the result is False on failure
  191. self.assertFalse(result)
  192. @patch("ddns.provider.callback.time")
  193. @patch.object(CallbackProvider, "_http")
  194. def test_set_record_http_none_response(self, mock_http, mock_time):
  195. """Test set_record method with None HTTP response"""
  196. mock_time.return_value = 1634567890.123
  197. mock_http.return_value = None # None response
  198. provider = CallbackProvider(self.auth_id, None) # type: ignore
  199. result = provider.set_record("example.com", "192.168.1.1")
  200. # Empty string is falsy, so result should be False
  201. self.assertFalse(result)
  202. @patch("ddns.provider.callback.jsondecode")
  203. def test_json_decode_error_handling(self, mock_jsondecode):
  204. """Test handling of JSON decode errors in POST method"""
  205. mock_jsondecode.side_effect = ValueError("Invalid JSON")
  206. auth_token = "invalid json"
  207. provider = CallbackProvider(self.auth_id, auth_token)
  208. # This should raise an exception when trying to decode invalid JSON
  209. with self.assertRaises(ValueError):
  210. provider.set_record("example.com", "192.168.1.1")
  211. class TestCallbackProviderRealIntegration(BaseProviderTestCase):
  212. """Real integration tests for CallbackProvider using httpbin.org"""
  213. def setUp(self):
  214. """Set up real test fixtures"""
  215. super(TestCallbackProviderRealIntegration, self).setUp()
  216. # Use httpbin.org as a stable test server
  217. self.real_callback_url = "https://httpbin.org/post"
  218. def _setup_provider_with_mock_logger(self, provider):
  219. """Helper method to setup provider with a mock logger."""
  220. mock_logger = self.mock_logger(provider)
  221. # Ensure the logger is configured to capture info calls
  222. mock_logger.setLevel(logging.INFO)
  223. return mock_logger
  224. def _random_delay(self):
  225. """Add a random delay of 0-3 seconds to avoid rate limiting"""
  226. if os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS") or os.environ.get("GITHUB_REF_NAME"):
  227. # In CI environments, use a shorter delay to speed up tests
  228. delay = random.uniform(0, 3)
  229. else:
  230. delay = random.uniform(0, 1)
  231. sleep(delay)
  232. def _assert_callback_result_logged(self, mock_logger, *expected_strings):
  233. """
  234. Helper to assert that 'Callback result: %s' was logged with expected content.
  235. """
  236. info_calls = mock_logger.info.call_args_list
  237. response_logged = False
  238. for call in info_calls:
  239. if len(call[0]) >= 2 and call[0][0] == "Callback result: %s":
  240. response_content = str(call[0][1])
  241. if all(expected in response_content for expected in expected_strings):
  242. response_logged = True
  243. break
  244. self.assertTrue(
  245. response_logged,
  246. "Expected logger.info to log 'Callback result' containing: {}".format(", ".join(expected_strings)),
  247. )
  248. def test_real_callback_get_method(self):
  249. """Test real callback using GET method with httpbin.org and verify logger calls"""
  250. auth_id = "https://httpbin.org/get?domain=__DOMAIN__&ip=__IP__&record_type=__RECORDTYPE__"
  251. domain = "test.example.com"
  252. ip = "111.111.111.111"
  253. provider = CallbackProvider(auth_id, "")
  254. mock_logger = self._setup_provider_with_mock_logger(provider)
  255. self._random_delay() # Add random delay before real request
  256. result = provider.set_record(domain, ip, "A")
  257. self.assertTrue(result)
  258. self._assert_callback_result_logged(mock_logger, domain, ip)
  259. def test_real_callback_post_method_with_json(self):
  260. """Test real callback using POST method with JSON data and verify logger calls"""
  261. auth_id = "https://httpbin.org/post"
  262. auth_token = '{"domain": "__DOMAIN__", "ip": "__IP__", "record_type": "__RECORDTYPE__", "ttl": "__TTL__"}'
  263. provider = CallbackProvider(auth_id, auth_token)
  264. # Setup provider with mock logger
  265. mock_logger = self._setup_provider_with_mock_logger(provider)
  266. self._random_delay() # Add random delay before real request
  267. result = provider.set_record("test.example.com", "203.0.113.2", "A", 300)
  268. # httpbin.org returns JSON with our posted data, so it should be truthy
  269. self.assertTrue(result)
  270. # Verify that logger.info was called with response containing domain and IP
  271. self._assert_callback_result_logged(mock_logger, "test.example.com", "203.0.113.2")
  272. def test_real_callback_error_handling(self):
  273. """Test real callback error handling with invalid URL"""
  274. # Use an invalid URL to test error handling
  275. auth_id = "https://httpbin.org/status/500" # This returns HTTP 500
  276. provider = CallbackProvider(auth_id, "")
  277. self._random_delay() # Add random delay before real request
  278. result = provider.set_record("test.example.com", "203.0.113.5")
  279. self.assertFalse(result)
  280. def test_real_callback_redirects_handling(self):
  281. """Test real callback with various HTTP redirect scenarios and verify logger calls"""
  282. # Test simple redirect
  283. auth_id = "https://httpbin.org/redirect-to?url=https://httpbin.org/get&domain=__DOMAIN__&ip=__IP__"
  284. domain = "redirect.test.example.com"
  285. ip = "203.0.113.21"
  286. provider = CallbackProvider(auth_id, "")
  287. try:
  288. mock_logger = self._setup_provider_with_mock_logger(provider)
  289. self._random_delay() # Add random delay before real request
  290. result = provider.set_record(domain, ip, "A")
  291. self.assertTrue(result)
  292. self._assert_callback_result_logged(mock_logger, domain, ip)
  293. except Exception as e:
  294. error_str = str(e).lower()
  295. if "ssl" in error_str or "certificate" in error_str:
  296. self.skipTest("SSL certificate issue: {}".format(e))
  297. def test_real_callback_redirects_handling_relative(self):
  298. """Test real callback with relative redirect scenarios and verify logger calls"""
  299. # Test relative redirect
  300. auth_id = "https://httpbin.org/relative-redirect/1?domain=__DOMAIN__&ip=__IP__"
  301. domain = "relative-redirect.example.com"
  302. ip = "203.0.113.203"
  303. provider = CallbackProvider(auth_id, "")
  304. try:
  305. mock_logger = self._setup_provider_with_mock_logger(provider)
  306. self._random_delay() # Add random delay before real request
  307. result = provider.set_record(domain, ip, "A")
  308. self.assertTrue(result)
  309. self._assert_callback_result_logged(mock_logger, domain, ip)
  310. except Exception as e:
  311. error_str = str(e).lower()
  312. if "ssl" in error_str or "certificate" in error_str:
  313. self.skipTest("SSL certificate issue: {}".format(e))
  314. def test_real_callback_redirect_with_post(self):
  315. """Test POST request redirect behavior (should change to GET after 302) and verify logger calls"""
  316. # POST to redirect endpoint - should convert to GET after 302
  317. auth_id = "https://httpbin.org/redirect-to?url=https://httpbin.org/get"
  318. auth_token = '{"domain": "__DOMAIN__", "ip": "__IP__", "method": "POST->GET"}'
  319. provider = CallbackProvider(auth_id, auth_token)
  320. try:
  321. # Setup provider with mock logger
  322. mock_logger = self._setup_provider_with_mock_logger(provider)
  323. self._random_delay() # Add random delay before real request
  324. result = provider.set_record("post-redirect.example.com", "203.0.113.202", "A")
  325. # POST should be redirected as GET and succeed
  326. self.assertTrue(result)
  327. # Verify that logger.info was called with response (domain/IP may be lost in POST->GET redirect)
  328. self._assert_callback_result_logged(mock_logger)
  329. except ssl.SSLError as e:
  330. error_str = str(e).lower()
  331. if "ssl" in error_str or "certificate" in error_str:
  332. self.skipTest("SSL certificate issue: {}".format(e))
  333. if __name__ == "__main__":
  334. unittest.main()