hosts_file.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. package hosts
  2. import (
  3. "bufio"
  4. "errors"
  5. "io"
  6. "net/netip"
  7. "os"
  8. "strings"
  9. "sync"
  10. "time"
  11. "github.com/miekg/dns"
  12. )
  13. const cacheMaxAge = 5 * time.Second
  14. type File struct {
  15. path string
  16. access sync.Mutex
  17. byName map[string][]netip.Addr
  18. expire time.Time
  19. modTime time.Time
  20. size int64
  21. }
  22. func NewFile(path string) *File {
  23. return &File{
  24. path: path,
  25. }
  26. }
  27. func (f *File) Lookup(name string) []netip.Addr {
  28. f.access.Lock()
  29. defer f.access.Unlock()
  30. f.update()
  31. return f.byName[name]
  32. }
  33. func (f *File) update() {
  34. now := time.Now()
  35. if now.Before(f.expire) && len(f.byName) > 0 {
  36. return
  37. }
  38. stat, err := os.Stat(f.path)
  39. if err != nil {
  40. return
  41. }
  42. if f.modTime.Equal(stat.ModTime()) && f.size == stat.Size() {
  43. f.expire = now.Add(cacheMaxAge)
  44. return
  45. }
  46. byName := make(map[string][]netip.Addr)
  47. file, err := os.Open(f.path)
  48. if err != nil {
  49. return
  50. }
  51. defer file.Close()
  52. reader := bufio.NewReader(file)
  53. var (
  54. prefix []byte
  55. line []byte
  56. isPrefix bool
  57. )
  58. for {
  59. line, isPrefix, err = reader.ReadLine()
  60. if err != nil {
  61. if errors.Is(err, io.EOF) {
  62. break
  63. }
  64. return
  65. }
  66. if isPrefix {
  67. prefix = append(prefix, line...)
  68. continue
  69. } else if len(prefix) > 0 {
  70. line = append(prefix, line...)
  71. prefix = nil
  72. }
  73. commentIndex := strings.IndexRune(string(line), '#')
  74. if commentIndex != -1 {
  75. line = line[:commentIndex]
  76. }
  77. fields := strings.Fields(string(line))
  78. if len(fields) < 2 {
  79. continue
  80. }
  81. var addr netip.Addr
  82. addr, err = netip.ParseAddr(fields[0])
  83. if err != nil {
  84. continue
  85. }
  86. for index := 1; index < len(fields); index++ {
  87. canonicalName := dns.CanonicalName(fields[index])
  88. byName[canonicalName] = append(byName[canonicalName], addr)
  89. }
  90. }
  91. f.expire = now.Add(cacheMaxAge)
  92. f.modTime = stat.ModTime()
  93. f.size = stat.Size()
  94. f.byName = byName
  95. }