router.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. package dns
  2. import (
  3. "context"
  4. "errors"
  5. "net/netip"
  6. "strings"
  7. "time"
  8. "github.com/sagernet/sing-box/adapter"
  9. "github.com/sagernet/sing-box/common/taskmonitor"
  10. C "github.com/sagernet/sing-box/constant"
  11. "github.com/sagernet/sing-box/log"
  12. "github.com/sagernet/sing-box/option"
  13. R "github.com/sagernet/sing-box/route/rule"
  14. "github.com/sagernet/sing-tun"
  15. "github.com/sagernet/sing/common"
  16. E "github.com/sagernet/sing/common/exceptions"
  17. F "github.com/sagernet/sing/common/format"
  18. "github.com/sagernet/sing/common/logger"
  19. M "github.com/sagernet/sing/common/metadata"
  20. "github.com/sagernet/sing/contrab/freelru"
  21. "github.com/sagernet/sing/contrab/maphash"
  22. "github.com/sagernet/sing/service"
  23. mDNS "github.com/miekg/dns"
  24. )
  25. var _ adapter.DNSRouter = (*Router)(nil)
  26. type Router struct {
  27. ctx context.Context
  28. logger logger.ContextLogger
  29. transport adapter.DNSTransportManager
  30. outbound adapter.OutboundManager
  31. client adapter.DNSClient
  32. rules []adapter.DNSRule
  33. defaultDomainStrategy C.DomainStrategy
  34. dnsReverseMapping freelru.Cache[netip.Addr, string]
  35. platformInterface adapter.PlatformInterface
  36. }
  37. func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOptions) *Router {
  38. router := &Router{
  39. ctx: ctx,
  40. logger: logFactory.NewLogger("dns"),
  41. transport: service.FromContext[adapter.DNSTransportManager](ctx),
  42. outbound: service.FromContext[adapter.OutboundManager](ctx),
  43. rules: make([]adapter.DNSRule, 0, len(options.Rules)),
  44. defaultDomainStrategy: C.DomainStrategy(options.Strategy),
  45. }
  46. router.client = NewClient(ClientOptions{
  47. DisableCache: options.DNSClientOptions.DisableCache,
  48. DisableExpire: options.DNSClientOptions.DisableExpire,
  49. IndependentCache: options.DNSClientOptions.IndependentCache,
  50. CacheCapacity: options.DNSClientOptions.CacheCapacity,
  51. ClientSubnet: options.DNSClientOptions.ClientSubnet.Build(netip.Prefix{}),
  52. RDRC: func() adapter.RDRCStore {
  53. cacheFile := service.FromContext[adapter.CacheFile](ctx)
  54. if cacheFile == nil {
  55. return nil
  56. }
  57. if !cacheFile.StoreRDRC() {
  58. return nil
  59. }
  60. return cacheFile
  61. },
  62. Logger: router.logger,
  63. })
  64. if options.ReverseMapping {
  65. router.dnsReverseMapping = common.Must1(freelru.NewSharded[netip.Addr, string](1024, maphash.NewHasher[netip.Addr]().Hash32))
  66. }
  67. return router
  68. }
  69. func (r *Router) Initialize(rules []option.DNSRule) error {
  70. for i, ruleOptions := range rules {
  71. dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true)
  72. if err != nil {
  73. return E.Cause(err, "parse dns rule[", i, "]")
  74. }
  75. r.rules = append(r.rules, dnsRule)
  76. }
  77. return nil
  78. }
  79. func (r *Router) Start(stage adapter.StartStage) error {
  80. monitor := taskmonitor.New(r.logger, C.StartTimeout)
  81. switch stage {
  82. case adapter.StartStateStart:
  83. monitor.Start("initialize DNS client")
  84. r.client.Start()
  85. monitor.Finish()
  86. for i, rule := range r.rules {
  87. monitor.Start("initialize DNS rule[", i, "]")
  88. err := rule.Start()
  89. monitor.Finish()
  90. if err != nil {
  91. return E.Cause(err, "initialize DNS rule[", i, "]")
  92. }
  93. }
  94. }
  95. return nil
  96. }
  97. func (r *Router) Close() error {
  98. monitor := taskmonitor.New(r.logger, C.StopTimeout)
  99. var err error
  100. for i, rule := range r.rules {
  101. monitor.Start("close dns rule[", i, "]")
  102. err = E.Append(err, rule.Close(), func(err error) error {
  103. return E.Cause(err, "close dns rule[", i, "]")
  104. })
  105. monitor.Finish()
  106. }
  107. return err
  108. }
  109. func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
  110. metadata := adapter.ContextFrom(ctx)
  111. if metadata == nil {
  112. panic("no context")
  113. }
  114. var currentRuleIndex int
  115. if ruleIndex != -1 {
  116. currentRuleIndex = ruleIndex + 1
  117. }
  118. for ; currentRuleIndex < len(r.rules); currentRuleIndex++ {
  119. currentRule := r.rules[currentRuleIndex]
  120. if currentRule.WithAddressLimit() && !isAddressQuery {
  121. continue
  122. }
  123. metadata.ResetRuleCache()
  124. if currentRule.Match(metadata) {
  125. displayRuleIndex := currentRuleIndex
  126. if displayRuleIndex != -1 {
  127. displayRuleIndex += displayRuleIndex + 1
  128. }
  129. ruleDescription := currentRule.String()
  130. if ruleDescription != "" {
  131. r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] ", currentRule, " => ", currentRule.Action())
  132. } else {
  133. r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
  134. }
  135. switch action := currentRule.Action().(type) {
  136. case *R.RuleActionDNSRoute:
  137. transport, loaded := r.transport.Transport(action.Server)
  138. if !loaded {
  139. r.logger.ErrorContext(ctx, "transport not found: ", action.Server)
  140. continue
  141. }
  142. isFakeIP := transport.Type() == C.DNSTypeFakeIP
  143. if isFakeIP && !allowFakeIP {
  144. continue
  145. }
  146. if action.Strategy != C.DomainStrategyAsIS {
  147. options.Strategy = action.Strategy
  148. }
  149. if isFakeIP || action.DisableCache {
  150. options.DisableCache = true
  151. }
  152. if action.RewriteTTL != nil {
  153. options.RewriteTTL = action.RewriteTTL
  154. }
  155. if action.ClientSubnet.IsValid() {
  156. options.ClientSubnet = action.ClientSubnet
  157. }
  158. if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy {
  159. if options.Strategy == C.DomainStrategyAsIS {
  160. options.Strategy = legacyTransport.LegacyStrategy()
  161. }
  162. if !options.ClientSubnet.IsValid() {
  163. options.ClientSubnet = legacyTransport.LegacyClientSubnet()
  164. }
  165. }
  166. return transport, currentRule, currentRuleIndex
  167. case *R.RuleActionDNSRouteOptions:
  168. if action.Strategy != C.DomainStrategyAsIS {
  169. options.Strategy = action.Strategy
  170. }
  171. if action.DisableCache {
  172. options.DisableCache = true
  173. }
  174. if action.RewriteTTL != nil {
  175. options.RewriteTTL = action.RewriteTTL
  176. }
  177. if action.ClientSubnet.IsValid() {
  178. options.ClientSubnet = action.ClientSubnet
  179. }
  180. case *R.RuleActionReject:
  181. return nil, currentRule, currentRuleIndex
  182. case *R.RuleActionPredefined:
  183. return nil, currentRule, currentRuleIndex
  184. }
  185. }
  186. }
  187. return r.transport.Default(), nil, -1
  188. }
  189. func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions) (*mDNS.Msg, error) {
  190. if len(message.Question) != 1 {
  191. r.logger.WarnContext(ctx, "bad question size: ", len(message.Question))
  192. responseMessage := mDNS.Msg{
  193. MsgHdr: mDNS.MsgHdr{
  194. Id: message.Id,
  195. Response: true,
  196. Rcode: mDNS.RcodeFormatError,
  197. },
  198. Question: message.Question,
  199. }
  200. return &responseMessage, nil
  201. }
  202. r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String()))
  203. var (
  204. transport adapter.DNSTransport
  205. err error
  206. )
  207. response, cached := r.client.ExchangeCache(ctx, message)
  208. if !cached {
  209. var metadata *adapter.InboundContext
  210. ctx, metadata = adapter.ExtendContext(ctx)
  211. metadata.Destination = M.Socksaddr{}
  212. metadata.QueryType = message.Question[0].Qtype
  213. switch metadata.QueryType {
  214. case mDNS.TypeA:
  215. metadata.IPVersion = 4
  216. case mDNS.TypeAAAA:
  217. metadata.IPVersion = 6
  218. }
  219. metadata.Domain = FqdnToDomain(message.Question[0].Name)
  220. if options.Transport != nil {
  221. transport = options.Transport
  222. if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy {
  223. if options.Strategy == C.DomainStrategyAsIS {
  224. options.Strategy = legacyTransport.LegacyStrategy()
  225. }
  226. if !options.ClientSubnet.IsValid() {
  227. options.ClientSubnet = legacyTransport.LegacyClientSubnet()
  228. }
  229. }
  230. if options.Strategy == C.DomainStrategyAsIS {
  231. options.Strategy = r.defaultDomainStrategy
  232. }
  233. response, err = r.client.Exchange(ctx, transport, message, options, nil)
  234. } else {
  235. var (
  236. rule adapter.DNSRule
  237. ruleIndex int
  238. )
  239. ruleIndex = -1
  240. for {
  241. dnsCtx := adapter.OverrideContext(ctx)
  242. dnsOptions := options
  243. transport, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message), &dnsOptions)
  244. if rule != nil {
  245. switch action := rule.Action().(type) {
  246. case *R.RuleActionReject:
  247. switch action.Method {
  248. case C.RuleActionRejectMethodDefault:
  249. return &mDNS.Msg{
  250. MsgHdr: mDNS.MsgHdr{
  251. Id: message.Id,
  252. Rcode: mDNS.RcodeRefused,
  253. Response: true,
  254. },
  255. Question: []mDNS.Question{message.Question[0]},
  256. }, nil
  257. case C.RuleActionRejectMethodDrop:
  258. return nil, tun.ErrDrop
  259. }
  260. case *R.RuleActionPredefined:
  261. return action.Response(message), nil
  262. }
  263. }
  264. var responseCheck func(responseAddrs []netip.Addr) bool
  265. if rule != nil && rule.WithAddressLimit() {
  266. responseCheck = func(responseAddrs []netip.Addr) bool {
  267. metadata.DestinationAddresses = responseAddrs
  268. return rule.MatchAddressLimit(metadata)
  269. }
  270. }
  271. if dnsOptions.Strategy == C.DomainStrategyAsIS {
  272. dnsOptions.Strategy = r.defaultDomainStrategy
  273. }
  274. response, err = r.client.Exchange(dnsCtx, transport, message, dnsOptions, responseCheck)
  275. var rejected bool
  276. if err != nil {
  277. if errors.Is(err, ErrResponseRejectedCached) {
  278. rejected = true
  279. r.logger.DebugContext(ctx, E.Cause(err, "response rejected for ", FormatQuestion(message.Question[0].String())), " (cached)")
  280. } else if errors.Is(err, ErrResponseRejected) {
  281. rejected = true
  282. r.logger.DebugContext(ctx, E.Cause(err, "response rejected for ", FormatQuestion(message.Question[0].String())))
  283. } else if len(message.Question) > 0 {
  284. r.logger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", FormatQuestion(message.Question[0].String())))
  285. } else {
  286. r.logger.ErrorContext(ctx, E.Cause(err, "exchange failed for <empty query>"))
  287. }
  288. }
  289. if responseCheck != nil && rejected {
  290. continue
  291. }
  292. break
  293. }
  294. }
  295. }
  296. if err != nil {
  297. return nil, err
  298. }
  299. if r.dnsReverseMapping != nil && len(message.Question) > 0 && response != nil && len(response.Answer) > 0 {
  300. if transport == nil || transport.Type() != C.DNSTypeFakeIP {
  301. for _, answer := range response.Answer {
  302. switch record := answer.(type) {
  303. case *mDNS.A:
  304. r.dnsReverseMapping.AddWithLifetime(M.AddrFromIP(record.A), FqdnToDomain(record.Hdr.Name), time.Duration(record.Hdr.Ttl)*time.Second)
  305. case *mDNS.AAAA:
  306. r.dnsReverseMapping.AddWithLifetime(M.AddrFromIP(record.AAAA), FqdnToDomain(record.Hdr.Name), time.Duration(record.Hdr.Ttl)*time.Second)
  307. }
  308. }
  309. }
  310. }
  311. return response, nil
  312. }
  313. func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
  314. var (
  315. responseAddrs []netip.Addr
  316. cached bool
  317. err error
  318. )
  319. printResult := func() {
  320. if err == nil && len(responseAddrs) == 0 {
  321. err = E.New("empty result")
  322. }
  323. if err != nil {
  324. if errors.Is(err, ErrResponseRejectedCached) {
  325. r.logger.DebugContext(ctx, "response rejected for ", domain, " (cached)")
  326. } else if errors.Is(err, ErrResponseRejected) {
  327. r.logger.DebugContext(ctx, "response rejected for ", domain)
  328. } else {
  329. r.logger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
  330. }
  331. }
  332. if err != nil {
  333. err = E.Cause(err, "lookup ", domain)
  334. }
  335. }
  336. responseAddrs, cached = r.client.LookupCache(domain, options.Strategy)
  337. if cached {
  338. if len(responseAddrs) == 0 {
  339. return nil, E.New("lookup ", domain, ": empty result (cached)")
  340. }
  341. return responseAddrs, nil
  342. }
  343. r.logger.DebugContext(ctx, "lookup domain ", domain)
  344. ctx, metadata := adapter.ExtendContext(ctx)
  345. metadata.Destination = M.Socksaddr{}
  346. metadata.Domain = FqdnToDomain(domain)
  347. if options.Transport != nil {
  348. transport := options.Transport
  349. if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy {
  350. if options.Strategy == C.DomainStrategyAsIS {
  351. options.Strategy = r.defaultDomainStrategy
  352. }
  353. if !options.ClientSubnet.IsValid() {
  354. options.ClientSubnet = legacyTransport.LegacyClientSubnet()
  355. }
  356. }
  357. if options.Strategy == C.DomainStrategyAsIS {
  358. options.Strategy = r.defaultDomainStrategy
  359. }
  360. responseAddrs, err = r.client.Lookup(ctx, transport, domain, options, nil)
  361. } else {
  362. var (
  363. transport adapter.DNSTransport
  364. rule adapter.DNSRule
  365. ruleIndex int
  366. )
  367. ruleIndex = -1
  368. for {
  369. dnsCtx := adapter.OverrideContext(ctx)
  370. dnsOptions := options
  371. transport, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true, &dnsOptions)
  372. if rule != nil {
  373. switch action := rule.Action().(type) {
  374. case *R.RuleActionReject:
  375. switch action.Method {
  376. case C.RuleActionRejectMethodDefault:
  377. return nil, nil
  378. case C.RuleActionRejectMethodDrop:
  379. return nil, tun.ErrDrop
  380. }
  381. case *R.RuleActionPredefined:
  382. if action.Rcode != mDNS.RcodeSuccess {
  383. err = RcodeError(action.Rcode)
  384. } else {
  385. for _, answer := range action.Answer {
  386. switch record := answer.(type) {
  387. case *mDNS.A:
  388. responseAddrs = append(responseAddrs, M.AddrFromIP(record.A))
  389. case *mDNS.AAAA:
  390. responseAddrs = append(responseAddrs, M.AddrFromIP(record.AAAA))
  391. }
  392. }
  393. }
  394. goto response
  395. }
  396. }
  397. var responseCheck func(responseAddrs []netip.Addr) bool
  398. if rule != nil && rule.WithAddressLimit() {
  399. responseCheck = func(responseAddrs []netip.Addr) bool {
  400. metadata.DestinationAddresses = responseAddrs
  401. return rule.MatchAddressLimit(metadata)
  402. }
  403. }
  404. if dnsOptions.Strategy == C.DomainStrategyAsIS {
  405. dnsOptions.Strategy = r.defaultDomainStrategy
  406. }
  407. responseAddrs, err = r.client.Lookup(dnsCtx, transport, domain, dnsOptions, responseCheck)
  408. if responseCheck == nil || err == nil {
  409. break
  410. }
  411. printResult()
  412. }
  413. }
  414. response:
  415. printResult()
  416. if len(responseAddrs) > 0 {
  417. r.logger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " "))
  418. }
  419. return responseAddrs, err
  420. }
  421. func isAddressQuery(message *mDNS.Msg) bool {
  422. for _, question := range message.Question {
  423. if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA || question.Qtype == mDNS.TypeHTTPS {
  424. return true
  425. }
  426. }
  427. return false
  428. }
  429. func (r *Router) ClearCache() {
  430. r.client.ClearCache()
  431. if r.platformInterface != nil {
  432. r.platformInterface.ClearDNSCache()
  433. }
  434. }
  435. func (r *Router) LookupReverseMapping(ip netip.Addr) (string, bool) {
  436. if r.dnsReverseMapping == nil {
  437. return "", false
  438. }
  439. domain, loaded := r.dnsReverseMapping.Get(ip)
  440. return domain, loaded
  441. }
  442. func (r *Router) ResetNetwork() {
  443. r.ClearCache()
  444. for _, transport := range r.transport.Transports() {
  445. transport.Close()
  446. }
  447. }