_base.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. # coding=utf-8
  2. """
  3. ## SimpleProvider 简单DNS抽象基类
  4. * set_record()
  5. ## BaseProvider 标准DNS抽象基类
  6. 定义所有 DNS 服务商 API 类应继承的抽象基类,统一接口,便于扩展适配多服务商。
  7. Abstract base class for DNS provider APIs.
  8. Defines a unified interface to support extension and adaptation across providers.
  9. * _query_zone_id
  10. * _query_record
  11. * _update_record
  12. * _create_record
  13. ┌──────────────────────────────────────────────────┐
  14. │ 用户调用 set_record(domain, value...) │
  15. └──────────────────────────────────────────────────┘
  16. ┌──────────────────────────────────────┐
  17. │ 快速解析 是否包含 ~ 或 + 分隔符? │
  18. └──────────────────────────────────────┘
  19. │ │
  20. [是,拆解成功] [否,无法拆解]
  21. sub 和 main│ │ domain
  22. ▼ ▼
  23. ┌────────────────────────┐ ┌──────────────────────────┐
  24. │ 查询 zone_id │ │ 自动循环解析 while: │
  25. │ _query_zone_id(main) │ │ _query_zone_id(...) │
  26. └────────────────────────┘ └──────────────────────────┘
  27. │ │
  28. ▼ ▼
  29. zone_id ←──────────────┬─── sub
  30. ┌─────────────────────────────────────┐
  31. │ 查询 record: │
  32. │ _query_record(zone_id, sub, ...) │
  33. └─────────────────────────────────────┘
  34. ┌─────────────┴────────────────┐
  35. │ record_id 是否存在? │
  36. └────────────┬─────────────────┘
  37. ┌──────────────┴─────────────┐
  38. │ │
  39. ▼ ▼
  40. ┌─────────────────────┐ ┌─────────────────────┐
  41. │ 更新记录 │ │ 创建记录 │
  42. │ _update_record(...) │ │ _create_record(...) │
  43. └─────────────────────┘ └─────────────────────┘
  44. │ │
  45. ▼ ▼
  46. ┌───────────────────────────────┐
  47. │ 返回操作结果 │
  48. └───────────────────────────────┘
  49. @author: NewFuture
  50. """
  51. from os import environ
  52. from abc import ABCMeta, abstractmethod
  53. from json import loads as jsondecode, dumps as jsonencode
  54. from logging import Logger, getLogger # noqa:F401 # type: ignore[no-redef]
  55. from ..util.http import send_http_request
  56. try: # python 3
  57. from urllib.parse import quote, urlencode
  58. except ImportError: # python 2
  59. from urllib import urlencode, quote # type: ignore[no-redef,import-untyped]
  60. TYPE_FORM = "application/x-www-form-urlencoded"
  61. TYPE_JSON = "application/json"
  62. class SimpleProvider(object):
  63. """
  64. 简单DNS服务商接口的抽象基类, 必须实现 `set_record` 方法。
  65. Abstract base class for all simple DNS provider APIs.
  66. Subclasses must implement `set_record`.
  67. * set_record(domain, value, record_type="A", ttl=None, line=None, **extra)
  68. """
  69. __metaclass__ = ABCMeta
  70. # API endpoint domain (to be defined in subclass)
  71. API = "" # type: str # https://exampledns.com
  72. # Content-Type for requests (to be defined in subclass)
  73. content_type = TYPE_FORM # type: Literal["application/x-www-form-urlencoded"] | Literal["application/json"]
  74. # 默认 accept 头部, 空则不设置
  75. accept = TYPE_JSON # type: str | None
  76. # Decode Response as JSON by default
  77. decode_response = True
  78. # 是否验证 SSL 证书,默认为 True
  79. verify_ssl = "auto" # type: bool | str
  80. # 版本
  81. version = environ.get("DDNS_VERSION", "0.0.0")
  82. # Description
  83. remark = "Managed by [DDNS v{}](https://ddns.newfuture.cc)".format(version)
  84. def __init__(self, auth_id, auth_token, logger=None, verify_ssl=None, **options):
  85. # type: (str, str, Logger | None, bool|str| None, **object) -> None
  86. """
  87. 初始化服务商对象
  88. Initialize provider instance.
  89. Args:
  90. auth_id (str): 身份认证 ID / Authentication ID
  91. auth_token (str): 密钥 / Authentication Token
  92. options (dict): 其它参数,如代理、调试等 / Additional options
  93. """
  94. self.auth_id = auth_id # type: str
  95. self.auth_token = auth_token # type: str
  96. self.options = options
  97. name = self.__class__.__name__
  98. self.logger = (logger or getLogger()).getChild(name)
  99. self.proxy = None # type: str | None
  100. if verify_ssl is not None:
  101. self.verify_ssl = verify_ssl
  102. self._zone_map = {} # type: dict[str, str]
  103. self.logger.debug("%s initialized with: %s", self.__class__.__name__, auth_id)
  104. self._validate() # 验证身份认证信息
  105. @abstractmethod
  106. def set_record(self, domain, value, record_type="A", ttl=None, line=None, **extra):
  107. # type: (str, str, str, str | int | None, str | None, **object) -> bool
  108. """
  109. 设置 DNS 记录(创建或更新)
  110. Set or update DNS record.
  111. Args:
  112. domain (str): 完整域名
  113. value (str): 新记录值
  114. record_type (str): 记录类型
  115. ttl (int | None): TTL 值,可选
  116. line (str | None): 线路信息
  117. extra (dict): 额外参数
  118. Returns:
  119. Any: 执行结果
  120. """
  121. raise NotImplementedError("This set_record should be implemented by subclasses")
  122. def set_proxy(self, proxy_str):
  123. # type: (str | None) -> SimpleProvider
  124. """
  125. 设置代理服务器
  126. Set HTTPS proxy string.
  127. Args:
  128. proxy_str (str): 代理地址
  129. Returns:
  130. Self: 自身
  131. """
  132. self.proxy = proxy_str
  133. return self
  134. def _validate(self):
  135. # type: () -> None
  136. """
  137. 验证身份认证信息是否填写
  138. Validate authentication credentials.
  139. """
  140. if not self.auth_id:
  141. raise ValueError("id must be configured")
  142. if not self.auth_token:
  143. raise ValueError("token must be configured")
  144. if not self.API:
  145. raise ValueError("API endpoint must be defined in {}".format(self.__class__.__name__))
  146. def _http(self, method, url, params=None, body=None, queries=None, headers=None): # noqa: C901
  147. # type: (str, str, dict[str,Any]|str|None, dict[str,Any]|str|None, dict[str,Any]|None, dict|None) -> Any
  148. """
  149. 发送 HTTP/HTTPS 请求,自动根据 API/url 选择协议。
  150. Args:
  151. method (str): 请求方法,如 GET、POST
  152. url (str): 请求路径
  153. params (dict[str, Any] | None): 请求参数,自动处理 query string 或者body
  154. body (dict[str, Any] | str | None): 请求体内容
  155. queries (dict[str, Any] | None): 查询参数,自动处理为 URL 查询字符串
  156. headers (dict): 头部,可选
  157. Returns:
  158. Any: 解析后的响应内容
  159. Raises:
  160. RuntimeError: 当响应状态码为400/401或5xx(服务器错误)时抛出异常
  161. """
  162. method = method.upper()
  163. # 简化参数处理逻辑
  164. query_params = queries or {}
  165. if params:
  166. if method in ("GET", "DELETE"):
  167. if isinstance(params, dict):
  168. query_params.update(params)
  169. else:
  170. # params是字符串,直接作为查询字符串
  171. url += ("&" if "?" in url else "?") + str(params)
  172. params = None
  173. elif body is None:
  174. body = params
  175. # 构建查询字符串
  176. if len(query_params) > 0:
  177. url += ("&" if "?" in url else "?") + self._encode(query_params)
  178. # 构建完整URL
  179. if not url.startswith("http://") and not url.startswith("https://"):
  180. if not url.startswith("/") and self.API.endswith("/"):
  181. url = "/" + url
  182. url = self.API + url
  183. # 记录请求日志
  184. self.logger.info("%s %s", method, self._mask_sensitive_data(url))
  185. # 处理请求体
  186. body_data, headers = None, headers or {}
  187. if body:
  188. if "content-type" not in headers:
  189. headers["content-type"] = self.content_type
  190. if isinstance(body, (str, bytes)):
  191. body_data = body
  192. elif self.content_type == TYPE_FORM:
  193. body_data = self._encode(body)
  194. else:
  195. body_data = jsonencode(body)
  196. self.logger.debug("body:\n%s", self._mask_sensitive_data(body_data))
  197. # 处理headers
  198. if self.accept and "accept" not in headers and "Accept" not in headers:
  199. headers["accept"] = self.accept
  200. if len(headers) > 2:
  201. self.logger.debug("headers:\n%s", {k: self._mask_sensitive_data(v) for k, v in headers.items()})
  202. response = send_http_request(
  203. url=url,
  204. method=method,
  205. body=body_data,
  206. headers=headers,
  207. proxy=self.proxy,
  208. max_redirects=5,
  209. verify_ssl=self.verify_ssl,
  210. )
  211. # 处理响应
  212. status_code = response.status
  213. if not (200 <= status_code < 300):
  214. self.logger.warning("response status: %s %s", status_code, response.reason)
  215. res = response.body
  216. # 针对客户端错误、认证/授权错误和服务器错误直接抛出异常
  217. if status_code >= 500 or status_code in (400, 401, 403):
  218. self.logger.error("HTTP error:\n%s", res)
  219. if status_code == 400:
  220. raise RuntimeError("请求参数错误 [400]: " + response.reason)
  221. elif status_code == 401:
  222. raise RuntimeError("认证失败 [401]: " + response.reason)
  223. elif status_code == 403:
  224. raise RuntimeError("权限不足 [403]: " + response.reason)
  225. else:
  226. raise RuntimeError("服务器错误 [{}]: {}".format(status_code, response.reason))
  227. self.logger.debug("response:\n%s", res)
  228. if not self.decode_response:
  229. return res
  230. try:
  231. return jsondecode(res)
  232. except Exception as e:
  233. self.logger.error("fail to decode response: %s", e)
  234. return res
  235. @staticmethod
  236. def _encode(params):
  237. # type: (dict|list|str|bytes|None) -> str
  238. """
  239. 编码参数为 URL 查询字符串
  240. Args:
  241. params (dict|list|str|bytes|None): 参数字典、列表或字符串
  242. Returns:
  243. str: 编码后的查询字符串
  244. """
  245. if not params:
  246. return ""
  247. elif isinstance(params, (str, bytes)):
  248. return params # type: ignore[return-value]
  249. return urlencode(params, doseq=True)
  250. @staticmethod
  251. def _quote(data, safe="/"):
  252. # type: (str, str) -> str
  253. """
  254. 对字符串进行 URL 编码
  255. Args:
  256. data (str): 待编码字符串
  257. Returns:
  258. str: 编码后的字符串
  259. """
  260. return quote(data, safe=safe)
  261. def _mask_sensitive_data(self, data):
  262. # type: (str | bytes | None) -> str | bytes | None
  263. """
  264. 对敏感数据进行打码处理,用于日志输出,支持URL编码的敏感信息
  265. Args:
  266. data (str | bytes | None): 需要处理的数据
  267. Returns:
  268. str | bytes | None: 打码后的字符串
  269. """
  270. if not data or not self.auth_token:
  271. return data
  272. # 生成打码后的token
  273. token_masked = self.auth_token[:2] + "***" + self.auth_token[-2:] if len(self.auth_token) > 4 else "***"
  274. token_encoded = quote(self.auth_token, safe="")
  275. if isinstance(data, bytes): # 处理字节数据
  276. return data.replace(self.auth_token.encode(), token_masked.encode()).replace(
  277. token_encoded.encode(), token_masked.encode()
  278. )
  279. if hasattr(data, "replace"): # 处理字符串数据
  280. return data.replace(self.auth_token, token_masked).replace(token_encoded, token_masked)
  281. return data
  282. class BaseProvider(SimpleProvider):
  283. """
  284. 标准DNS服务商接口的抽象基类
  285. Abstract base class for all standard DNS provider APIs.
  286. Subclasses must implement the abstract methods to support various providers.
  287. * _query_zone_id()
  288. * _query_record_id()
  289. * _update_record()
  290. * _create_record()
  291. """
  292. def set_record(self, domain, value, record_type="A", ttl=None, line=None, **extra):
  293. # type: (str, str, str, str | int | None, str | None, **Any) -> bool
  294. """
  295. 设置 DNS 记录(创建或更新)
  296. Set or update DNS record.
  297. Args:
  298. domain (str): 完整域名
  299. value (str): 新记录值
  300. record_type (str): 记录类型
  301. ttl (int | None): TTL 值,可选
  302. line (str | None): 线路信息
  303. extra (dict): 额外参数
  304. Returns:
  305. bool: 执行结果
  306. """
  307. domain = domain.lower()
  308. self.logger.info("%s => %s(%s)", domain, value, record_type)
  309. # 优化域名解析逻辑
  310. sub, main = self._split_custom_domain(domain)
  311. try:
  312. if sub is not None:
  313. # 使用自定义分隔符格式
  314. zone_id = self.get_zone_id(main)
  315. else:
  316. # 自动分析域名
  317. zone_id, sub, main = self._split_zone_and_sub(domain)
  318. self.logger.info("sub: %s, main: %s(id=%s)", sub, main, zone_id)
  319. if not zone_id or sub is None:
  320. self.logger.critical("找不到 zone_id 或 subdomain: %s", domain)
  321. return False
  322. # 查询现有记录
  323. record = self._query_record(zone_id, sub, main, record_type=record_type, line=line, extra=extra)
  324. # 更新或创建记录
  325. if record:
  326. self.logger.info("Found existing record: %s", record)
  327. return self._update_record(zone_id, record, value, record_type, ttl=ttl, line=line, extra=extra)
  328. else:
  329. self.logger.warning("No existing record found, creating new one")
  330. return self._create_record(zone_id, sub, main, value, record_type, ttl=ttl, line=line, extra=extra)
  331. except Exception as e:
  332. self.logger.exception("Error setting record for %s: %s", domain, e)
  333. return False
  334. def get_zone_id(self, domain):
  335. # type: (str) -> str | None
  336. """
  337. 查询指定域名对应的 zone_id
  338. Get zone_id for the domain.
  339. Args:
  340. domain (str): 主域名 / main name
  341. Returns:
  342. str | None: 区域 ID / Zone identifier
  343. """
  344. if domain in self._zone_map:
  345. return self._zone_map[domain]
  346. zone_id = self._query_zone_id(domain)
  347. if zone_id:
  348. self._zone_map[domain] = zone_id
  349. return zone_id
  350. @abstractmethod
  351. def _query_zone_id(self, domain):
  352. # type: (str) -> str | None
  353. """
  354. 查询主域名的 zone ID
  355. Args:
  356. domain (str): 主域名
  357. Returns:
  358. str | None: Zone ID
  359. """
  360. return domain
  361. @abstractmethod
  362. def _query_record(self, zone_id, subdomain, main_domain, record_type, line, extra):
  363. # type: (str, str, str, str, str | None, dict) -> Any
  364. """
  365. 查询 DNS 记录 ID
  366. Args:
  367. zone_id (str): 区域 ID
  368. subdomain (str): 子域名
  369. main_domain (str): 主域名
  370. record_type (str): 记录类型,例如 A、AAAA
  371. line (str | None): 线路选项,可选
  372. extra (dict): 额外参数
  373. Returns:
  374. Any | None: 记录
  375. """
  376. raise NotImplementedError("This _query_record should be implemented by subclasses")
  377. @abstractmethod
  378. def _create_record(self, zone_id, subdomain, main_domain, value, record_type, ttl, line, extra):
  379. # type: (str, str, str, str, str, int | str | None, str | None, dict) -> bool
  380. """
  381. 创建新 DNS 记录
  382. Args:
  383. zone_id (str): 区域 ID
  384. subdomain (str): 子域名
  385. main_domain (str): 主域名
  386. value (str): 记录值
  387. record_type (str): 类型,如 A
  388. ttl (int | None): TTL 可选
  389. line (str | None): 线路选项
  390. extra (dict | None): 额外字段
  391. Returns:
  392. Any: 操作结果
  393. """
  394. raise NotImplementedError("This _create_record should be implemented by subclasses")
  395. @abstractmethod
  396. def _update_record(self, zone_id, old_record, value, record_type, ttl, line, extra):
  397. # type: (str, dict, str, str, int | str | None, str | None, dict) -> bool
  398. """
  399. 更新已有 DNS 记录
  400. Args:
  401. zone_id (str): 区域 ID
  402. old_record (dict): 旧记录信息
  403. value (str): 新的记录值
  404. record_type (str): 类型
  405. ttl (int | None): TTL
  406. line (str | None): 线路
  407. extra (dict | None): 额外参数
  408. Returns:
  409. bool: 操作结果
  410. """
  411. raise NotImplementedError("This _update_record should be implemented by subclasses")
  412. def _split_zone_and_sub(self, domain):
  413. # type: (str) -> tuple[str | None, str | None, str ]
  414. """
  415. 从完整域名拆分主域名和子域名
  416. Args:
  417. domain (str): 完整域名
  418. Returns:
  419. (zone_id, sub): 元组
  420. """
  421. domain_split = domain.split(".")
  422. zone_id = None
  423. index = 2
  424. main = ""
  425. while not zone_id and index <= len(domain_split):
  426. main = ".".join(domain_split[-index:])
  427. zone_id = self.get_zone_id(main)
  428. index += 1
  429. if zone_id:
  430. sub = ".".join(domain_split[: -index + 1]) or "@"
  431. self.logger.debug("zone_id: %s, sub: %s", zone_id, sub)
  432. return zone_id, sub, main
  433. return None, None, main
  434. @staticmethod
  435. def _split_custom_domain(domain):
  436. # type: (str) -> tuple[str | None, str]
  437. """
  438. 拆分支持 ~ 或 + 的自定义格式域名为 (子域, 主域)
  439. 如 sub~example.com => ('sub', 'example.com')
  440. Returns:
  441. (sub, main): 子域 + 主域
  442. """
  443. for sep in ("~", "+"):
  444. if sep in domain:
  445. sub, main = domain.split(sep, 1)
  446. return sub, main
  447. return None, domain
  448. @staticmethod
  449. def _join_domain(sub, main):
  450. # type: (str | None, str) -> str
  451. """
  452. 合并子域名和主域名为完整域名
  453. Args:
  454. sub (str | None): 子域名
  455. main (str): 主域名
  456. Returns:
  457. str: 完整域名
  458. """
  459. sub = sub and sub.strip(".").strip().lower()
  460. main = main and main.strip(".").strip().lower()
  461. if not sub or sub == "@":
  462. if not main:
  463. raise ValueError("Both sub and main cannot be empty")
  464. return main
  465. if not main:
  466. return sub
  467. return "{}.{}".format(sub, main)