_base.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
# coding=utf-8
"""
## SimpleProvider 简单DNS抽象基类 (simple DNS provider base class)
* set_record()

## BaseProvider 标准DNS抽象基类 (standard DNS provider base class)
定义所有 DNS 服务商 API 类应继承的抽象基类,统一接口,便于扩展适配多服务商。
Defines a unified interface to support extension and adaptation across providers.
* _query_zone_id
* _query_record
* _update_record
* _create_record

BaseProvider.set_record flow:

            set_record(domain, value, ...)
                          |
                          v
                "~" or "+" in domain?
                          |
            +-------------+-------------+
           yes                          no
            v                           v
    get_zone_id(main)      _split_zone_and_sub(domain)
                          (get_zone_id on each suffix)
            |                           |
            +-------------+-------------+
                          |  zone_id, sub, main
                          v
       _query_record(zone_id, sub, main, ...)
                          |
                          v
                    record found?
                          |
            +-------------+-------------+
           yes                          no
            v                           v
   _update_record(...)         _create_record(...)
            |                           |
            +-------------+-------------+
                          |
                          v
              return the result (bool)

If no zone_id/subdomain can be determined, or any exception is raised,
the problem is logged and set_record returns False.

@author: NewFuture
"""
  50. from abc import ABCMeta, abstractmethod
  51. from json import loads as jsondecode, dumps as jsonencode
  52. from logging import Logger, getLogger # noqa:F401 # type: ignore[no-redef]
  53. from ..util.http import request, quote, urlencode
  54. TYPE_FORM = "application/x-www-form-urlencoded"
  55. TYPE_JSON = "application/json"
  56. def encode_params(params):
  57. # type: (dict|list|str|bytes|None) -> str
  58. """
  59. 编码参数为 URL 查询字符串,参数顺序会排序
  60. Args:
  61. params (dict|list|str|bytes|None): 参数字典、列表或字符串
  62. Returns:
  63. str: 编码后的查询字符串
  64. """
  65. if not params:
  66. return ""
  67. elif isinstance(params, (str, bytes)):
  68. return params # type: ignore[return-value]
  69. items = params.items() if isinstance(params, dict) else params
  70. return urlencode(sorted(items), doseq=True)
  71. class SimpleProvider(object):
  72. """
  73. 简单DNS服务商接口的抽象基类, 必须实现 `set_record` 方法。
  74. Abstract base class for all simple DNS provider APIs.
  75. Subclasses must implement `set_record`.
  76. * set_record(domain, value, record_type="A", ttl=None, line=None, **extra)
  77. """
  78. __metaclass__ = ABCMeta
  79. # API endpoint domain (to be defined in subclass)
  80. endpoint = "" # type: str # https://exampledns.com
  81. # Content-Type for requests (to be defined in subclass)
  82. content_type = TYPE_FORM # type: Literal["application/x-www-form-urlencoded"] | Literal["application/json"]
  83. # 默认 accept 头部, 空则不设置
  84. accept = TYPE_JSON # type: str | None
  85. # Decode Response as JSON by default
  86. decode_response = True # type: bool
  87. # Description
  88. remark = "Managed by [DDNS](https://ddns.newfuture.cc)"
  89. def __init__(self, id, token, logger=None, ssl="auto", proxy=None, endpoint=None, **options):
  90. # type: (str, str, Logger | None, bool|str, list[str]|None, str|None, **object) -> None
  91. """
  92. 初始化服务商对象
  93. Initialize provider instance.
  94. Args:
  95. id (str): 身份认证 ID / Authentication ID
  96. token (str): 密钥 / Authentication Token
  97. proxy (list[str | None] | None): 代理配置,支持代理列表
  98. options (dict): 其它参数 / Additional options
  99. """
  100. self.id = id
  101. self.token = token
  102. if endpoint:
  103. self.endpoint = endpoint
  104. # 处理代理配置
  105. self._proxy = proxy # 代理列表或None
  106. self._ssl = ssl
  107. self.options = options
  108. name = self.__class__.__name__
  109. self.logger = (logger or getLogger()).getChild(name)
  110. self._zone_map = {} # type: dict[str, str]
  111. self.logger.debug("%s initialized with: %s", self.__class__.__name__, id)
  112. self._validate() # 验证身份认证信息
  113. @abstractmethod
  114. def set_record(self, domain, value, record_type="A", ttl=None, line=None, **extra):
  115. # type: (str, str, str, str | int | None, str | None, **object) -> bool
  116. """
  117. 设置 DNS 记录(创建或更新)
  118. Set or update DNS record.
  119. Args:
  120. domain (str): 完整域名
  121. value (str): 新记录值
  122. record_type (str): 记录类型
  123. ttl (int | None): TTL 值,可选
  124. line (str | None): 线路信息
  125. extra (dict): 额外参数
  126. Returns:
  127. Any: 执行结果
  128. """
  129. raise NotImplementedError("This set_record should be implemented by subclasses")
  130. def _validate(self):
  131. # type: () -> None
  132. """
  133. 验证身份认证信息是否填写
  134. Validate authentication credentials.
  135. """
  136. if not self.id:
  137. raise ValueError("id must be configured")
  138. if not self.token:
  139. raise ValueError("token must be configured")
  140. if not self.endpoint:
  141. raise ValueError("API endpoint must be defined in {}".format(self.__class__.__name__))
  142. def _http(self, method, url, params=None, body=None, queries=None, headers=None): # noqa: C901
  143. # type: (str, str, dict[str,Any]|str|None, dict[str,Any]|str|None, dict[str,Any]|None, dict|None) -> Any
  144. """
  145. 发送 HTTP/HTTPS 请求,自动根据 API/url 选择协议。
  146. Args:
  147. method (str): 请求方法,如 GET、POST
  148. url (str): 请求路径
  149. params (dict[str, Any] | None): 请求参数,自动处理 query string 或者body
  150. body (dict[str, Any] | str | None): 请求体内容
  151. queries (dict[str, Any] | None): 查询参数,自动处理为 URL 查询字符串
  152. headers (dict): 头部,可选
  153. Returns:
  154. Any: 解析后的响应内容
  155. Raises:
  156. RuntimeError: 当响应状态码为400/401或5xx(服务器错误)时抛出异常
  157. """
  158. method = method.upper()
  159. # 简化参数处理逻辑
  160. query_params = queries or {}
  161. if params:
  162. if method in ("GET", "DELETE"):
  163. if isinstance(params, dict):
  164. query_params.update(params)
  165. else:
  166. # params是字符串,直接作为查询字符串
  167. url += ("&" if "?" in url else "?") + str(params)
  168. params = None
  169. elif body is None:
  170. body = params
  171. # 构建查询字符串
  172. if len(query_params) > 0:
  173. url += ("&" if "?" in url else "?") + encode_params(query_params)
  174. # 构建完整URL
  175. if not url.startswith("http://") and not url.startswith("https://"):
  176. if not url.startswith("/") and self.endpoint.endswith("/"):
  177. url = "/" + url
  178. url = self.endpoint + url
  179. # 记录请求日志
  180. self.logger.info("%s %s", method, self._mask_sensitive_data(url))
  181. # 处理请求体
  182. body_data, headers = None, headers or {}
  183. if body:
  184. if "content-type" not in headers:
  185. headers["content-type"] = self.content_type
  186. body_data = self._encode_body(body)
  187. self.logger.debug("body:\n%s", self._mask_sensitive_data(body_data))
  188. # 处理headers
  189. if self.accept and "accept" not in headers and "Accept" not in headers:
  190. headers["accept"] = self.accept
  191. if len(headers) > 2:
  192. self.logger.debug("headers:\n%s", {k: self._mask_sensitive_data(v) for k, v in headers.items()})
  193. # 直接传递代理列表给request函数
  194. response = request(method, url, body_data, headers=headers, proxies=self._proxy, verify=self._ssl, retries=2)
  195. # 处理响应
  196. status_code = response.status
  197. if not (200 <= status_code < 300):
  198. self.logger.warning("response status: %s %s", status_code, response.reason)
  199. res = response.body
  200. # 针对客户端错误、认证/授权错误和服务器错误直接抛出异常
  201. if status_code >= 500 or status_code in (400, 401, 403):
  202. self.logger.error("HTTP error:\n%s", res)
  203. if status_code == 400:
  204. raise RuntimeError("参数错误 [400]: " + response.reason)
  205. elif status_code == 401:
  206. raise RuntimeError("认证失败 [401]: " + response.reason)
  207. elif status_code == 403:
  208. raise RuntimeError("禁止访问 [403]: " + response.reason)
  209. else:
  210. raise RuntimeError("服务器错误 [{}]: {}".format(status_code, response.reason))
  211. self.logger.debug("response:\n%s", res)
  212. if not self.decode_response:
  213. return res
  214. try:
  215. return jsondecode(res)
  216. except Exception as e:
  217. self.logger.error("fail to decode response: %s", e)
  218. return res
  219. def _encode_body(self, data):
  220. # type: (dict | list | str | bytes | None) -> str
  221. """
  222. 自动编码数据为字符串或字节, 根据 content_type 选择编码方式。
  223. Args:
  224. data (dict | list | str | bytes | None): 待编码数据
  225. Returns:
  226. str | bytes | None: 编码后的数据
  227. """
  228. if isinstance(data, (str, bytes)):
  229. return data # type: ignore[return-value]
  230. if not data:
  231. return ""
  232. if self.content_type == TYPE_FORM:
  233. return encode_params(data)
  234. return jsonencode(data)
  235. def _mask_sensitive_data(self, data):
  236. # type: (str | bytes | None) -> str | bytes | None
  237. """
  238. 对敏感数据进行打码处理,用于日志输出,支持URL编码的敏感信息
  239. Args:
  240. data (str | bytes | None): 需要处理的数据
  241. Returns:
  242. str | bytes | None: 打码后的字符串
  243. """
  244. if not data or not self.token:
  245. return data
  246. # 生成打码后的token
  247. token_masked = self.token[:2] + "***" + self.token[-2:] if len(self.token) > 4 else "***"
  248. token_encoded = quote(self.token, safe="")
  249. if isinstance(data, bytes): # 处理字节数据
  250. return data.replace(self.token.encode(), token_masked.encode()).replace(
  251. token_encoded.encode(), token_masked.encode()
  252. )
  253. if hasattr(data, "replace"): # 处理字符串数据
  254. return data.replace(self.token, token_masked).replace(token_encoded, token_masked)
  255. return data
  256. class BaseProvider(SimpleProvider):
  257. """
  258. 标准DNS服务商接口的抽象基类
  259. Abstract base class for all standard DNS provider APIs.
  260. Subclasses must implement the abstract methods to support various providers.
  261. * _query_zone_id()
  262. * _query_record_id()
  263. * _update_record()
  264. * _create_record()
  265. """
  266. def set_record(self, domain, value, record_type="A", ttl=None, line=None, **extra):
  267. # type: (str, str, str, str | int | None, str | None, **Any) -> bool
  268. """
  269. 设置 DNS 记录(创建或更新)
  270. Set or update DNS record.
  271. Args:
  272. domain (str): 完整域名
  273. value (str): 新记录值
  274. record_type (str): 记录类型
  275. ttl (int | None): TTL 值,可选
  276. line (str | None): 线路信息
  277. extra (dict): 额外参数
  278. Returns:
  279. bool: 执行结果
  280. """
  281. domain = domain.lower()
  282. self.logger.info("%s => %s(%s)", domain, value, record_type)
  283. sub, main = _split_custom_domain(domain)
  284. try:
  285. if sub is not None:
  286. # 使用自定义分隔符格式
  287. zone_id = self.get_zone_id(main)
  288. else:
  289. # 自动分析域名
  290. zone_id, sub, main = self._split_zone_and_sub(domain)
  291. self.logger.info("sub: %s, main: %s(id=%s)", sub, main, zone_id)
  292. if not zone_id or sub is None:
  293. self.logger.critical("找不到 zone_id 或 subdomain: %s", domain)
  294. return False
  295. # 查询现有记录
  296. record = self._query_record(zone_id, sub, main, record_type=record_type, line=line, extra=extra)
  297. # 更新或创建记录
  298. if record:
  299. self.logger.info("Found existing record: %s", record)
  300. return self._update_record(zone_id, record, value, record_type, ttl=ttl, line=line, extra=extra)
  301. else:
  302. self.logger.warning("No existing record found, creating new one")
  303. return self._create_record(zone_id, sub, main, value, record_type, ttl=ttl, line=line, extra=extra)
  304. except Exception as e:
  305. self.logger.exception("Error setting record for %s: %s", domain, e)
  306. return False
  307. def get_zone_id(self, domain):
  308. # type: (str) -> str | None
  309. """
  310. 查询指定域名对应的 zone_id
  311. Get zone_id for the domain.
  312. Args:
  313. domain (str): 主域名 / main name
  314. Returns:
  315. str | None: 区域 ID / Zone identifier
  316. """
  317. if domain in self._zone_map:
  318. return self._zone_map[domain]
  319. zone_id = self._query_zone_id(domain)
  320. if zone_id:
  321. self._zone_map[domain] = zone_id
  322. return zone_id
  323. @abstractmethod
  324. def _query_zone_id(self, domain):
  325. # type: (str) -> str | None
  326. """
  327. 查询主域名的 zone ID
  328. Args:
  329. domain (str): 主域名
  330. Returns:
  331. str | None: Zone ID
  332. """
  333. return domain
  334. @abstractmethod
  335. def _query_record(self, zone_id, subdomain, main_domain, record_type, line, extra):
  336. # type: (str, str, str, str, str | None, dict) -> Any
  337. """
  338. 查询 DNS 记录 ID
  339. Args:
  340. zone_id (str): 区域 ID
  341. subdomain (str): 子域名
  342. main_domain (str): 主域名
  343. record_type (str): 记录类型,例如 A、AAAA
  344. line (str | None): 线路选项,可选
  345. extra (dict): 额外参数
  346. Returns:
  347. Any | None: 记录
  348. """
  349. raise NotImplementedError("This _query_record should be implemented by subclasses")
  350. @abstractmethod
  351. def _create_record(self, zone_id, subdomain, main_domain, value, record_type, ttl, line, extra):
  352. # type: (str, str, str, str, str, int | str | None, str | None, dict) -> bool
  353. """
  354. 创建新 DNS 记录
  355. Args:
  356. zone_id (str): 区域 ID
  357. subdomain (str): 子域名
  358. main_domain (str): 主域名
  359. value (str): 记录值
  360. record_type (str): 类型,如 A
  361. ttl (int | None): TTL 可选
  362. line (str | None): 线路选项
  363. extra (dict | None): 额外字段
  364. Returns:
  365. Any: 操作结果
  366. """
  367. raise NotImplementedError("This _create_record should be implemented by subclasses")
  368. @abstractmethod
  369. def _update_record(self, zone_id, old_record, value, record_type, ttl, line, extra):
  370. # type: (str, dict, str, str, int | str | None, str | None, dict) -> bool
  371. """
  372. 更新已有 DNS 记录
  373. Args:
  374. zone_id (str): 区域 ID
  375. old_record (dict): 旧记录信息
  376. value (str): 新的记录值
  377. record_type (str): 类型
  378. ttl (int | None): TTL
  379. line (str | None): 线路
  380. extra (dict | None): 额外参数
  381. Returns:
  382. bool: 操作结果
  383. """
  384. raise NotImplementedError("This _update_record should be implemented by subclasses")
  385. def _split_zone_and_sub(self, domain):
  386. # type: (str) -> tuple[str | None, str | None, str ]
  387. """
  388. 从完整域名拆分主域名和子域名
  389. Args:
  390. domain (str): 完整域名
  391. Returns:
  392. (zone_id, sub): 元组
  393. """
  394. domain_split = domain.split(".")
  395. zone_id = None
  396. index = 2
  397. main = ""
  398. while not zone_id and index <= len(domain_split):
  399. main = ".".join(domain_split[-index:])
  400. zone_id = self.get_zone_id(main)
  401. index += 1
  402. if zone_id:
  403. sub = ".".join(domain_split[: -index + 1]) or "@"
  404. self.logger.debug("zone_id: %s, sub: %s", zone_id, sub)
  405. return zone_id, sub, main
  406. return None, None, main
  407. def _split_custom_domain(domain):
  408. # type: (str) -> tuple[str | None, str]
  409. """
  410. 拆分支持 ~ 或 + 的自定义格式域名为 (子域, 主域)
  411. 如 sub~example.com => ('sub', 'example.com')
  412. Returns:
  413. (sub, main): 子域 + 主域
  414. """
  415. for sep in ("~", "+"):
  416. if sep in domain:
  417. sub, main = domain.split(sep, 1)
  418. return sub, main
  419. return None, domain
  420. def join_domain(sub, main):
  421. # type: (str | None, str) -> str
  422. """
  423. 合并子域名和主域名为完整域名
  424. Args:
  425. sub (str | None): 子域名
  426. main (str): 主域名
  427. Returns:
  428. str: 完整域名
  429. """
  430. sub = sub and sub.strip(".").strip().lower()
  431. main = main and main.strip(".").strip().lower()
  432. if not sub or sub == "@":
  433. if not main:
  434. raise ValueError("Both sub and main cannot be empty")
  435. return main
  436. if not main:
  437. return sub
  438. return "{}.{}".format(sub, main)