hosts.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. package dns
  2. import (
  3. "context"
  4. "strconv"
  5. "github.com/xtls/xray-core/common/errors"
  6. "github.com/xtls/xray-core/common/net"
  7. "github.com/xtls/xray-core/common/strmatcher"
  8. "github.com/xtls/xray-core/features/dns"
  9. )
  10. // StaticHosts represents static domain-ip mapping in DNS server.
  11. type StaticHosts struct {
  12. ips [][]net.Address
  13. matchers *strmatcher.MatcherGroup
  14. }
  15. // NewStaticHosts creates a new StaticHosts instance.
  16. func NewStaticHosts(hosts []*Config_HostMapping) (*StaticHosts, error) {
  17. g := new(strmatcher.MatcherGroup)
  18. sh := &StaticHosts{
  19. ips: make([][]net.Address, len(hosts)+16),
  20. matchers: g,
  21. }
  22. for _, mapping := range hosts {
  23. matcher, err := toStrMatcher(mapping.Type, mapping.Domain)
  24. if err != nil {
  25. return nil, errors.New("failed to create domain matcher").Base(err)
  26. }
  27. id := g.Add(matcher)
  28. ips := make([]net.Address, 0, len(mapping.Ip)+1)
  29. switch {
  30. case len(mapping.ProxiedDomain) > 0:
  31. if mapping.ProxiedDomain[0] == '#' {
  32. rcode, err := strconv.Atoi(mapping.ProxiedDomain[1:])
  33. if err != nil {
  34. return nil, err
  35. }
  36. ips = append(ips, dns.RCodeError(rcode))
  37. } else {
  38. ips = append(ips, net.DomainAddress(mapping.ProxiedDomain))
  39. }
  40. case len(mapping.Ip) > 0:
  41. for _, ip := range mapping.Ip {
  42. addr := net.IPAddress(ip)
  43. if addr == nil {
  44. return nil, errors.New("invalid IP address in static hosts: ", ip).AtWarning()
  45. }
  46. ips = append(ips, addr)
  47. }
  48. }
  49. sh.ips[id] = ips
  50. }
  51. return sh, nil
  52. }
  53. func filterIP(ips []net.Address, option dns.IPOption) []net.Address {
  54. filtered := make([]net.Address, 0, len(ips))
  55. for _, ip := range ips {
  56. if (ip.Family().IsIPv4() && option.IPv4Enable) || (ip.Family().IsIPv6() && option.IPv6Enable) {
  57. filtered = append(filtered, ip)
  58. }
  59. }
  60. return filtered
  61. }
  62. func (h *StaticHosts) lookupInternal(domain string) ([]net.Address, error) {
  63. ips := make([]net.Address, 0)
  64. found := false
  65. for _, id := range h.matchers.Match(domain) {
  66. for _, v := range h.ips[id] {
  67. if err, ok := v.(dns.RCodeError); ok {
  68. if uint16(err) == 0 {
  69. return nil, dns.ErrEmptyResponse
  70. }
  71. return nil, err
  72. }
  73. }
  74. ips = append(ips, h.ips[id]...)
  75. found = true
  76. }
  77. if !found {
  78. return nil, nil
  79. }
  80. return ips, nil
  81. }
  82. func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) ([]net.Address, error) {
  83. switch addrs, err := h.lookupInternal(domain); {
  84. case err != nil:
  85. return nil, err
  86. case len(addrs) == 0: // Not recorded in static hosts, return nil
  87. return addrs, nil
  88. case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Try to unwrap domain
  89. errors.LogDebug(context.Background(), "found replaced domain: ", domain, " -> ", addrs[0].Domain(), ". Try to unwrap it")
  90. if maxDepth > 0 {
  91. unwrapped, err := h.lookup(addrs[0].Domain(), option, maxDepth-1)
  92. if err != nil {
  93. return nil, err
  94. }
  95. if unwrapped != nil {
  96. return unwrapped, nil
  97. }
  98. }
  99. return addrs, nil
  100. default: // IP record found, return a non-nil IP array
  101. return filterIP(addrs, option), nil
  102. }
  103. }
  104. // Lookup returns IP addresses or proxied domain for the given domain, if exists in this StaticHosts.
  105. func (h *StaticHosts) Lookup(domain string, option dns.IPOption) ([]net.Address, error) {
  106. return h.lookup(domain, option, 5)
  107. }