direct_test.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package dns
  4. import (
  5. "context"
  6. "errors"
  7. "fmt"
  8. "io/fs"
  9. "net/netip"
  10. "os"
  11. "path/filepath"
  12. "strings"
  13. "syscall"
  14. "testing"
  15. qt "github.com/frankban/quicktest"
  16. "tailscale.com/util/dnsname"
  17. )
  18. func TestDirectManager(t *testing.T) {
  19. tmp := t.TempDir()
  20. if err := os.MkdirAll(filepath.Join(tmp, "etc"), 0700); err != nil {
  21. t.Fatal(err)
  22. }
  23. testDirect(t, directFS{prefix: tmp})
  24. }
  25. type boundResolvConfFS struct {
  26. directFS
  27. }
  28. func (fs boundResolvConfFS) Rename(old, new string) error {
  29. if old == "/etc/resolv.conf" || new == "/etc/resolv.conf" {
  30. return errors.New("cannot move to/from /etc/resolv.conf")
  31. }
  32. return fs.directFS.Rename(old, new)
  33. }
  34. func (fs boundResolvConfFS) Remove(name string) error {
  35. if name == "/etc/resolv.conf" {
  36. return errors.New("cannot remove /etc/resolv.conf")
  37. }
  38. return fs.directFS.Remove(name)
  39. }
  40. func TestDirectBrokenRename(t *testing.T) {
  41. tmp := t.TempDir()
  42. if err := os.MkdirAll(filepath.Join(tmp, "etc"), 0700); err != nil {
  43. t.Fatal(err)
  44. }
  45. testDirect(t, boundResolvConfFS{directFS{prefix: tmp}})
  46. }
  47. func testDirect(t *testing.T, fs wholeFileFS) {
  48. const orig = "nameserver 9.9.9.9 # orig"
  49. resolvPath := "/etc/resolv.conf"
  50. backupPath := "/etc/resolv.pre-tailscale-backup.conf"
  51. if err := fs.WriteFile(resolvPath, []byte(orig), 0644); err != nil {
  52. t.Fatal(err)
  53. }
  54. readFile := func(t *testing.T, path string) string {
  55. t.Helper()
  56. b, err := fs.ReadFile(path)
  57. if err != nil {
  58. t.Fatal(err)
  59. }
  60. return string(b)
  61. }
  62. assertBaseState := func(t *testing.T) {
  63. if got := readFile(t, resolvPath); got != orig {
  64. t.Fatalf("resolv.conf:\n%s, want:\n%s", got, orig)
  65. }
  66. if _, err := fs.Stat(backupPath); !os.IsNotExist(err) {
  67. t.Fatalf("resolv.conf backup: want it to be gone but: %v", err)
  68. }
  69. }
  70. ctx, cancel := context.WithCancel(context.Background())
  71. defer cancel()
  72. m := directManager{logf: t.Logf, fs: fs, ctx: ctx, ctxClose: cancel}
  73. if err := m.SetDNS(OSConfig{
  74. Nameservers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
  75. SearchDomains: []dnsname.FQDN{"ts.net.", "ts-dns.test."},
  76. MatchDomains: []dnsname.FQDN{"ignored."},
  77. }); err != nil {
  78. t.Fatal(err)
  79. }
  80. want := `# resolv.conf(5) file generated by tailscale
  81. # For more info, see https://tailscale.com/s/resolvconf-overwrite
  82. # DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN
  83. nameserver 8.8.8.8
  84. nameserver 8.8.4.4
  85. search ts.net ts-dns.test
  86. `
  87. if got := readFile(t, resolvPath); got != want {
  88. t.Fatalf("resolv.conf:\n%s, want:\n%s", got, want)
  89. }
  90. if got := readFile(t, backupPath); got != orig {
  91. t.Fatalf("resolv.conf backup:\n%s, want:\n%s", got, orig)
  92. }
  93. // Test that a nil OSConfig cleans up resolv.conf.
  94. if err := m.SetDNS(OSConfig{}); err != nil {
  95. t.Fatal(err)
  96. }
  97. assertBaseState(t)
  98. // Test that Close cleans up resolv.conf.
  99. if err := m.SetDNS(OSConfig{Nameservers: []netip.Addr{netip.MustParseAddr("8.8.8.8")}}); err != nil {
  100. t.Fatal(err)
  101. }
  102. if err := m.Close(); err != nil {
  103. t.Fatal(err)
  104. }
  105. assertBaseState(t)
  106. }
  107. type brokenRemoveFS struct {
  108. directFS
  109. }
  110. func (b brokenRemoveFS) Rename(old, new string) error {
  111. return errors.New("nyaaah I'm a silly container!")
  112. }
  113. func (b brokenRemoveFS) Remove(name string) error {
  114. if strings.Contains(name, "/etc/resolv.conf") {
  115. return fmt.Errorf("Faking remove failure: %q", &fs.PathError{Err: syscall.EBUSY})
  116. }
  117. return b.directFS.Remove(name)
  118. }
  119. func TestDirectBrokenRemove(t *testing.T) {
  120. tmp := t.TempDir()
  121. if err := os.MkdirAll(filepath.Join(tmp, "etc"), 0700); err != nil {
  122. t.Fatal(err)
  123. }
  124. testDirect(t, brokenRemoveFS{directFS{prefix: tmp}})
  125. }
  126. func TestReadResolve(t *testing.T) {
  127. c := qt.New(t)
  128. tests := []struct {
  129. in string
  130. want OSConfig
  131. wantErr bool
  132. }{
  133. {in: `nameserver 192.168.0.100`,
  134. want: OSConfig{
  135. Nameservers: []netip.Addr{
  136. netip.MustParseAddr("192.168.0.100"),
  137. },
  138. },
  139. },
  140. {in: `nameserver 192.168.0.100 # comment`,
  141. want: OSConfig{
  142. Nameservers: []netip.Addr{
  143. netip.MustParseAddr("192.168.0.100"),
  144. },
  145. },
  146. },
  147. {in: `nameserver 192.168.0.100#`,
  148. want: OSConfig{
  149. Nameservers: []netip.Addr{
  150. netip.MustParseAddr("192.168.0.100"),
  151. },
  152. },
  153. },
  154. {in: `nameserver #192.168.0.100`, wantErr: true},
  155. {in: `nameserver`, wantErr: true},
  156. {in: `# nameserver 192.168.0.100`, want: OSConfig{}},
  157. {in: `nameserver192.168.0.100`, wantErr: true},
  158. {in: `search tailscale.com`,
  159. want: OSConfig{
  160. SearchDomains: []dnsname.FQDN{"tailscale.com."},
  161. },
  162. },
  163. {in: `search tailscale.com # comment`,
  164. want: OSConfig{
  165. SearchDomains: []dnsname.FQDN{"tailscale.com."},
  166. },
  167. },
  168. {in: `searchtailscale.com`, wantErr: true},
  169. {in: `search`, wantErr: true},
  170. }
  171. for _, test := range tests {
  172. cfg, err := readResolv(strings.NewReader(test.in))
  173. if test.wantErr {
  174. c.Assert(err, qt.IsNotNil)
  175. } else {
  176. c.Assert(err, qt.IsNil)
  177. }
  178. c.Assert(cfg, qt.DeepEquals, test.want)
  179. }
  180. }