hosts.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. package dns
  2. import (
  3. "context"
  4. "strconv"
  5. "github.com/xtls/xray-core/common/errors"
  6. "github.com/xtls/xray-core/common/net"
  7. "github.com/xtls/xray-core/common/strmatcher"
  8. "github.com/xtls/xray-core/features/dns"
  9. )
  10. // StaticHosts represents static domain-ip mapping in DNS server.
  11. type StaticHosts struct {
  12. ips [][]net.Address
  13. matchers *strmatcher.MatcherGroup
  14. }
  15. // NewStaticHosts creates a new StaticHosts instance.
  16. func NewStaticHosts(hosts []*Config_HostMapping) (*StaticHosts, error) {
  17. g := new(strmatcher.MatcherGroup)
  18. sh := &StaticHosts{
  19. ips: make([][]net.Address, len(hosts)+16),
  20. matchers: g,
  21. }
  22. for _, mapping := range hosts {
  23. matcher, err := toStrMatcher(mapping.Type, mapping.Domain)
  24. if err != nil {
  25. errors.LogErrorInner(context.Background(), err, "failed to create domain matcher, ignore domain rule [type: ", mapping.Type, ", domain: ", mapping.Domain, "]")
  26. continue
  27. }
  28. id := g.Add(matcher)
  29. ips := make([]net.Address, 0, len(mapping.Ip)+1)
  30. switch {
  31. case len(mapping.ProxiedDomain) > 0:
  32. if mapping.ProxiedDomain[0] == '#' {
  33. rcode, err := strconv.Atoi(mapping.ProxiedDomain[1:])
  34. if err != nil {
  35. return nil, err
  36. }
  37. ips = append(ips, dns.RCodeError(rcode))
  38. } else {
  39. ips = append(ips, net.DomainAddress(mapping.ProxiedDomain))
  40. }
  41. case len(mapping.Ip) > 0:
  42. for _, ip := range mapping.Ip {
  43. addr := net.IPAddress(ip)
  44. if addr == nil {
  45. errors.LogError(context.Background(), "invalid IP address in static hosts: ", ip, ", ignore this ip for rule [type: ", mapping.Type, ", domain: ", mapping.Domain, "]")
  46. continue
  47. }
  48. ips = append(ips, addr)
  49. }
  50. if len(ips) == 0 {
  51. continue
  52. }
  53. }
  54. sh.ips[id] = ips
  55. }
  56. return sh, nil
  57. }
  58. func filterIP(ips []net.Address, option dns.IPOption) []net.Address {
  59. filtered := make([]net.Address, 0, len(ips))
  60. for _, ip := range ips {
  61. if (ip.Family().IsIPv4() && option.IPv4Enable) || (ip.Family().IsIPv6() && option.IPv6Enable) {
  62. filtered = append(filtered, ip)
  63. }
  64. }
  65. return filtered
  66. }
  67. func (h *StaticHosts) lookupInternal(domain string) ([]net.Address, error) {
  68. ips := make([]net.Address, 0)
  69. found := false
  70. for _, id := range h.matchers.Match(domain) {
  71. for _, v := range h.ips[id] {
  72. if err, ok := v.(dns.RCodeError); ok {
  73. if uint16(err) == 0 {
  74. return nil, dns.ErrEmptyResponse
  75. }
  76. return nil, err
  77. }
  78. }
  79. ips = append(ips, h.ips[id]...)
  80. found = true
  81. }
  82. if !found {
  83. return nil, nil
  84. }
  85. return ips, nil
  86. }
  87. func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) ([]net.Address, error) {
  88. switch addrs, err := h.lookupInternal(domain); {
  89. case err != nil:
  90. return nil, err
  91. case len(addrs) == 0: // Not recorded in static hosts, return nil
  92. return addrs, nil
  93. case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Try to unwrap domain
  94. errors.LogDebug(context.Background(), "found replaced domain: ", domain, " -> ", addrs[0].Domain(), ". Try to unwrap it")
  95. if maxDepth > 0 {
  96. unwrapped, err := h.lookup(addrs[0].Domain(), option, maxDepth-1)
  97. if err != nil {
  98. return nil, err
  99. }
  100. if unwrapped != nil {
  101. return unwrapped, nil
  102. }
  103. }
  104. return addrs, nil
  105. default: // IP record found, return a non-nil IP array
  106. return filterIP(addrs, option), nil
  107. }
  108. }
  109. // Lookup returns IP addresses or proxied domain for the given domain, if exists in this StaticHosts.
  110. func (h *StaticHosts) Lookup(domain string, option dns.IPOption) ([]net.Address, error) {
  111. return h.lookup(domain, option, 5)
  112. }