natc_test.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package main
  4. import (
  5. "context"
  6. "fmt"
  7. "io"
  8. "net"
  9. "net/netip"
  10. "sync"
  11. "testing"
  12. "time"
  13. "github.com/gaissmai/bart"
  14. "golang.org/x/net/dns/dnsmessage"
  15. "tailscale.com/client/tailscale/apitype"
  16. "tailscale.com/cmd/natc/ippool"
  17. "tailscale.com/tailcfg"
  18. "tailscale.com/util/must"
  19. )
  20. func prefixEqual(a, b netip.Prefix) bool {
  21. return a.Bits() == b.Bits() && a.Addr() == b.Addr()
  22. }
  23. func TestULA(t *testing.T) {
  24. tests := []struct {
  25. name string
  26. siteID uint16
  27. expected string
  28. }{
  29. {"zero", 0, "fd7a:115c:a1e0:a99c:0000::/80"},
  30. {"one", 1, "fd7a:115c:a1e0:a99c:0001::/80"},
  31. {"max", 65535, "fd7a:115c:a1e0:a99c:ffff::/80"},
  32. {"random", 12345, "fd7a:115c:a1e0:a99c:3039::/80"},
  33. }
  34. for _, tc := range tests {
  35. t.Run(tc.name, func(t *testing.T) {
  36. got := ula(tc.siteID)
  37. expected := netip.MustParsePrefix(tc.expected)
  38. if !prefixEqual(got, expected) {
  39. t.Errorf("ula(%d) = %s; want %s", tc.siteID, got, expected)
  40. }
  41. })
  42. }
  43. }
  44. type recordingPacketConn struct {
  45. writes [][]byte
  46. }
  47. func (w *recordingPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
  48. w.writes = append(w.writes, b)
  49. return len(b), nil
  50. }
  51. func (w *recordingPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
  52. return 0, nil, io.EOF
  53. }
  54. func (w *recordingPacketConn) Close() error {
  55. return nil
  56. }
  57. func (w *recordingPacketConn) LocalAddr() net.Addr {
  58. return nil
  59. }
  60. func (w *recordingPacketConn) RemoteAddr() net.Addr {
  61. return nil
  62. }
  63. func (w *recordingPacketConn) SetDeadline(t time.Time) error {
  64. return nil
  65. }
  66. func (w *recordingPacketConn) SetReadDeadline(t time.Time) error {
  67. return nil
  68. }
  69. func (w *recordingPacketConn) SetWriteDeadline(t time.Time) error {
  70. return nil
  71. }
  72. type resolver struct {
  73. resolves map[string][]netip.Addr
  74. fails map[string]bool
  75. }
  76. func (r *resolver) LookupNetIP(ctx context.Context, _net, host string) ([]netip.Addr, error) {
  77. if addrs, ok := r.resolves[host]; ok {
  78. return addrs, nil
  79. }
  80. if _, ok := r.fails[host]; ok {
  81. return nil, &net.DNSError{IsTimeout: false, IsNotFound: false, Name: host, IsTemporary: true}
  82. }
  83. return nil, &net.DNSError{IsNotFound: true, Name: host}
  84. }
  85. type whois struct {
  86. peers map[string]*apitype.WhoIsResponse
  87. }
  88. func (w *whois) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) {
  89. addr := netip.MustParseAddrPort(remoteAddr).Addr().String()
  90. if peer, ok := w.peers[addr]; ok {
  91. return peer, nil
  92. }
  93. return nil, fmt.Errorf("peer not found")
  94. }
  95. func TestDNSResponse(t *testing.T) {
  96. tests := []struct {
  97. name string
  98. questions []dnsmessage.Question
  99. wantEmpty bool
  100. wantAnswers []struct {
  101. name string
  102. qType dnsmessage.Type
  103. addr netip.Addr
  104. }
  105. wantNXDOMAIN bool
  106. wantIgnored bool
  107. }{
  108. {
  109. name: "empty_request",
  110. questions: []dnsmessage.Question{},
  111. wantEmpty: false,
  112. wantAnswers: nil,
  113. },
  114. {
  115. name: "a_record",
  116. questions: []dnsmessage.Question{
  117. {
  118. Name: dnsmessage.MustNewName("example.com."),
  119. Type: dnsmessage.TypeA,
  120. Class: dnsmessage.ClassINET,
  121. },
  122. },
  123. wantAnswers: []struct {
  124. name string
  125. qType dnsmessage.Type
  126. addr netip.Addr
  127. }{
  128. {
  129. name: "example.com.",
  130. qType: dnsmessage.TypeA,
  131. addr: netip.MustParseAddr("100.64.0.0"),
  132. },
  133. },
  134. },
  135. {
  136. name: "aaaa_record",
  137. questions: []dnsmessage.Question{
  138. {
  139. Name: dnsmessage.MustNewName("example.com."),
  140. Type: dnsmessage.TypeAAAA,
  141. Class: dnsmessage.ClassINET,
  142. },
  143. },
  144. wantAnswers: []struct {
  145. name string
  146. qType dnsmessage.Type
  147. addr netip.Addr
  148. }{
  149. {
  150. name: "example.com.",
  151. qType: dnsmessage.TypeAAAA,
  152. addr: netip.MustParseAddr("fd7a:115c:a1e0::"),
  153. },
  154. },
  155. },
  156. {
  157. name: "soa_record",
  158. questions: []dnsmessage.Question{
  159. {
  160. Name: dnsmessage.MustNewName("example.com."),
  161. Type: dnsmessage.TypeSOA,
  162. Class: dnsmessage.ClassINET,
  163. },
  164. },
  165. wantAnswers: nil,
  166. },
  167. {
  168. name: "ns_record",
  169. questions: []dnsmessage.Question{
  170. {
  171. Name: dnsmessage.MustNewName("example.com."),
  172. Type: dnsmessage.TypeNS,
  173. Class: dnsmessage.ClassINET,
  174. },
  175. },
  176. wantAnswers: nil,
  177. },
  178. {
  179. name: "nxdomain",
  180. questions: []dnsmessage.Question{
  181. {
  182. Name: dnsmessage.MustNewName("noexist.example.com."),
  183. Type: dnsmessage.TypeA,
  184. Class: dnsmessage.ClassINET,
  185. },
  186. },
  187. wantNXDOMAIN: true,
  188. },
  189. {
  190. name: "servfail",
  191. questions: []dnsmessage.Question{
  192. {
  193. Name: dnsmessage.MustNewName("fail.example.com."),
  194. Type: dnsmessage.TypeA,
  195. Class: dnsmessage.ClassINET,
  196. },
  197. },
  198. wantEmpty: true, // TODO: pass through instead?
  199. },
  200. {
  201. name: "ignored",
  202. questions: []dnsmessage.Question{
  203. {
  204. Name: dnsmessage.MustNewName("ignore.example.com."),
  205. Type: dnsmessage.TypeA,
  206. Class: dnsmessage.ClassINET,
  207. },
  208. },
  209. wantAnswers: []struct {
  210. name string
  211. qType dnsmessage.Type
  212. addr netip.Addr
  213. }{
  214. {
  215. name: "ignore.example.com.",
  216. qType: dnsmessage.TypeA,
  217. addr: netip.MustParseAddr("8.8.4.4"),
  218. },
  219. },
  220. wantIgnored: true,
  221. },
  222. }
  223. var rpc recordingPacketConn
  224. remoteAddr := must.Get(net.ResolveUDPAddr("udp", "100.64.254.1:12345"))
  225. routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("10.64.0.0/24")})
  226. v6ULA := ula(1)
  227. c := connector{
  228. resolver: &resolver{
  229. resolves: map[string][]netip.Addr{
  230. "example.com.": {
  231. netip.MustParseAddr("8.8.8.8"),
  232. netip.MustParseAddr("2001:4860:4860::8888"),
  233. },
  234. "ignore.example.com.": {
  235. netip.MustParseAddr("8.8.4.4"),
  236. },
  237. },
  238. fails: map[string]bool{
  239. "fail.example.com.": true,
  240. },
  241. },
  242. whois: &whois{
  243. peers: map[string]*apitype.WhoIsResponse{
  244. "100.64.254.1": {
  245. Node: &tailcfg.Node{ID: 123},
  246. },
  247. },
  248. },
  249. ignoreDsts: &bart.Table[bool]{},
  250. routes: routes,
  251. v6ULA: v6ULA,
  252. ipPool: &ippool.SingleMachineIPPool{IPSet: addrPool},
  253. dnsAddr: dnsAddr,
  254. }
  255. c.ignoreDsts.Insert(netip.MustParsePrefix("8.8.4.4/32"), true)
  256. for _, tc := range tests {
  257. t.Run(tc.name, func(t *testing.T) {
  258. rb := dnsmessage.NewBuilder(nil,
  259. dnsmessage.Header{
  260. ID: 1234,
  261. },
  262. )
  263. must.Do(rb.StartQuestions())
  264. for _, q := range tc.questions {
  265. rb.Question(q)
  266. }
  267. c.handleDNS(&rpc, must.Get(rb.Finish()), remoteAddr)
  268. writes := rpc.writes
  269. rpc.writes = rpc.writes[:0]
  270. if tc.wantEmpty {
  271. if len(writes) != 0 {
  272. t.Errorf("handleDNS() returned non-empty response when expected empty")
  273. }
  274. return
  275. }
  276. if !tc.wantEmpty && len(writes) != 1 {
  277. t.Fatalf("handleDNS() returned an unexpected number of responses: %d, want 1", len(writes))
  278. }
  279. resp := writes[0]
  280. var msg dnsmessage.Message
  281. err := msg.Unpack(resp)
  282. if err != nil {
  283. t.Fatalf("Failed to unpack response: %v", err)
  284. }
  285. if !msg.Header.Response {
  286. t.Errorf("Response header is not set")
  287. }
  288. if msg.Header.ID != 1234 {
  289. t.Errorf("Response ID = %d, want %d", msg.Header.ID, 1234)
  290. }
  291. if len(tc.wantAnswers) > 0 {
  292. if len(msg.Answers) != len(tc.wantAnswers) {
  293. t.Errorf("got %d answers, want %d:\n%s", len(msg.Answers), len(tc.wantAnswers), msg.GoString())
  294. } else {
  295. for i, want := range tc.wantAnswers {
  296. ans := msg.Answers[i]
  297. gotName := ans.Header.Name.String()
  298. if gotName != want.name {
  299. t.Errorf("answer[%d] name = %s, want %s", i, gotName, want.name)
  300. }
  301. if ans.Header.Type != want.qType {
  302. t.Errorf("answer[%d] type = %v, want %v", i, ans.Header.Type, want.qType)
  303. }
  304. switch want.qType {
  305. case dnsmessage.TypeA:
  306. if ans.Body.(*dnsmessage.AResource) == nil {
  307. t.Errorf("answer[%d] not an A record", i)
  308. continue
  309. }
  310. case dnsmessage.TypeAAAA:
  311. if ans.Body.(*dnsmessage.AAAAResource) == nil {
  312. t.Errorf("answer[%d] not an AAAA record", i)
  313. continue
  314. }
  315. }
  316. var gotIP netip.Addr
  317. switch want.qType {
  318. case dnsmessage.TypeA:
  319. resource := ans.Body.(*dnsmessage.AResource)
  320. gotIP = netip.AddrFrom4([4]byte(resource.A))
  321. case dnsmessage.TypeAAAA:
  322. resource := ans.Body.(*dnsmessage.AAAAResource)
  323. gotIP = netip.AddrFrom16([16]byte(resource.AAAA))
  324. }
  325. var wantIP netip.Addr
  326. if tc.wantIgnored {
  327. var net string
  328. var fxSelectIP func(netip.Addr) bool
  329. switch want.qType {
  330. case dnsmessage.TypeA:
  331. net = "ip4"
  332. fxSelectIP = func(a netip.Addr) bool {
  333. return a.Is4()
  334. }
  335. case dnsmessage.TypeAAAA:
  336. //TODO(fran) is this branch exercised?
  337. net = "ip6"
  338. fxSelectIP = func(a netip.Addr) bool {
  339. return a.Is6()
  340. }
  341. }
  342. ips := must.Get(c.resolver.LookupNetIP(t.Context(), net, want.name))
  343. for _, ip := range ips {
  344. if fxSelectIP(ip) {
  345. wantIP = ip
  346. break
  347. }
  348. }
  349. } else {
  350. addr := must.Get(c.ipPool.IPForDomain(tailcfg.NodeID(123), want.name))
  351. switch want.qType {
  352. case dnsmessage.TypeA:
  353. wantIP = addr
  354. case dnsmessage.TypeAAAA:
  355. wantIP = v6ForV4(v6ULA.Addr(), addr)
  356. }
  357. }
  358. if gotIP != wantIP {
  359. t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, wantIP)
  360. }
  361. }
  362. }
  363. }
  364. if tc.wantNXDOMAIN {
  365. if msg.RCode != dnsmessage.RCodeNameError {
  366. t.Errorf("expected NXDOMAIN, got %v", msg.RCode)
  367. }
  368. if len(msg.Answers) != 0 {
  369. t.Errorf("expected no answers, got %d", len(msg.Answers))
  370. }
  371. }
  372. })
  373. }
  374. }
  375. func TestIgnoreDestination(t *testing.T) {
  376. ignoreDstTable := &bart.Table[bool]{}
  377. ignoreDstTable.Insert(netip.MustParsePrefix("192.168.1.0/24"), true)
  378. ignoreDstTable.Insert(netip.MustParsePrefix("10.0.0.0/8"), true)
  379. c := &connector{
  380. ignoreDsts: ignoreDstTable,
  381. }
  382. tests := []struct {
  383. name string
  384. addrs []netip.Addr
  385. expected bool
  386. }{
  387. {
  388. name: "no_match",
  389. addrs: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
  390. expected: false,
  391. },
  392. {
  393. name: "one_match",
  394. addrs: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("192.168.1.5")},
  395. expected: true,
  396. },
  397. {
  398. name: "all_match",
  399. addrs: []netip.Addr{netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("192.168.1.5")},
  400. expected: true,
  401. },
  402. {
  403. name: "empty_addrs",
  404. addrs: []netip.Addr{},
  405. expected: false,
  406. },
  407. }
  408. for _, tc := range tests {
  409. t.Run(tc.name, func(t *testing.T) {
  410. got := c.ignoreDestination(tc.addrs)
  411. if got != tc.expected {
  412. t.Errorf("ignoreDestination(%v) = %v, want %v", tc.addrs, got, tc.expected)
  413. }
  414. })
  415. }
  416. }
  417. func TestV6V4(t *testing.T) {
  418. v6ULA := ula(1)
  419. tests := [][]string{
  420. {"100.64.0.0", "fd7a:115c:a1e0:a99c:1:0:6440:0"},
  421. {"0.0.0.0", "fd7a:115c:a1e0:a99c:1::"},
  422. {"255.255.255.255", "fd7a:115c:a1e0:a99c:1:0:ffff:ffff"},
  423. }
  424. for i, test := range tests {
  425. // to v6
  426. v6 := v6ForV4(v6ULA.Addr(), netip.MustParseAddr(test[0]))
  427. want := netip.MustParseAddr(test[1])
  428. if v6 != want {
  429. t.Fatalf("test %d: want: %v, got: %v", i, want, v6)
  430. }
  431. // to v4
  432. v4 := v4ForV6(netip.MustParseAddr(test[1]))
  433. want = netip.MustParseAddr(test[0])
  434. if v4 != want {
  435. t.Fatalf("test %d: want: %v, got: %v", i, want, v4)
  436. }
  437. }
  438. }
  439. // echoServer is a simple server that just echos back data set to it.
  440. type echoServer struct {
  441. listener net.Listener
  442. addr string
  443. wg sync.WaitGroup
  444. done chan struct{}
  445. }
  446. // newEchoServer creates a new test DNS server on the specified network and address
  447. func newEchoServer(t *testing.T, network, addr string) *echoServer {
  448. listener, err := net.Listen(network, addr)
  449. if err != nil {
  450. t.Fatalf("Failed to create test DNS server: %v", err)
  451. }
  452. server := &echoServer{
  453. listener: listener,
  454. addr: listener.Addr().String(),
  455. done: make(chan struct{}),
  456. }
  457. server.wg.Add(1)
  458. go server.serve()
  459. return server
  460. }
  461. func (s *echoServer) serve() {
  462. defer s.wg.Done()
  463. for {
  464. select {
  465. case <-s.done:
  466. return
  467. default:
  468. conn, err := s.listener.Accept()
  469. if err != nil {
  470. select {
  471. case <-s.done:
  472. return
  473. default:
  474. continue
  475. }
  476. }
  477. go s.handleConnection(conn)
  478. }
  479. }
  480. }
  481. func (s *echoServer) handleConnection(conn net.Conn) {
  482. defer conn.Close()
  483. // Simple response - just echo back some data to confirm connectivity
  484. buf := make([]byte, 1024)
  485. n, err := conn.Read(buf)
  486. if err != nil {
  487. return
  488. }
  489. conn.Write(buf[:n])
  490. }
  491. func (s *echoServer) close() {
  492. close(s.done)
  493. s.listener.Close()
  494. s.wg.Wait()
  495. }
  496. func TestGetResolver(t *testing.T) {
  497. tests := []struct {
  498. name string
  499. network string
  500. addr string
  501. }{
  502. {
  503. name: "ipv4_loopback",
  504. network: "tcp4",
  505. addr: "127.0.0.1:0",
  506. },
  507. {
  508. name: "ipv6_loopback",
  509. network: "tcp6",
  510. addr: "[::1]:0",
  511. },
  512. }
  513. for _, tc := range tests {
  514. t.Run(tc.name, func(t *testing.T) {
  515. server := newEchoServer(t, tc.network, tc.addr)
  516. defer server.close()
  517. serverAddr := server.addr
  518. resolver := getResolver(serverAddr)
  519. if resolver == nil {
  520. t.Fatal("getResolver returned nil")
  521. }
  522. netResolver, ok := resolver.(*net.Resolver)
  523. if !ok {
  524. t.Fatal("getResolver did not return a *net.Resolver")
  525. }
  526. if netResolver.Dial == nil {
  527. t.Fatal("resolver.Dial is nil")
  528. }
  529. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  530. defer cancel()
  531. conn, err := netResolver.Dial(ctx, "tcp", "dummy.address:53")
  532. if err != nil {
  533. t.Fatalf("Failed to dial test DNS server: %v", err)
  534. }
  535. defer conn.Close()
  536. testData := []byte("test")
  537. _, err = conn.Write(testData)
  538. if err != nil {
  539. t.Fatalf("Failed to write to connection: %v", err)
  540. }
  541. response := make([]byte, len(testData))
  542. _, err = conn.Read(response)
  543. if err != nil {
  544. t.Fatalf("Failed to read from connection: %v", err)
  545. }
  546. if string(response) != string(testData) {
  547. t.Fatalf("Expected echo response %q, got %q", testData, response)
  548. }
  549. })
  550. }
  551. }
  552. func TestGetResolverMultipleServers(t *testing.T) {
  553. server1 := newEchoServer(t, "tcp4", "127.0.0.1:0")
  554. defer server1.close()
  555. server2 := newEchoServer(t, "tcp4", "127.0.0.1:0")
  556. defer server2.close()
  557. serverFlag := server1.addr + ", " + server2.addr
  558. resolver := getResolver(serverFlag)
  559. netResolver, ok := resolver.(*net.Resolver)
  560. if !ok {
  561. t.Fatal("getResolver did not return a *net.Resolver")
  562. }
  563. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  564. defer cancel()
  565. servers := map[string]bool{
  566. server1.addr: false,
  567. server2.addr: false,
  568. }
  569. // Try up to 1000 times to hit all servers, this should be very quick, and
  570. // if this fails randomness has regressed beyond reason.
  571. for range 1000 {
  572. conn, err := netResolver.Dial(ctx, "tcp", "dummy.address:53")
  573. if err != nil {
  574. t.Fatalf("Failed to dial test DNS server: %v", err)
  575. }
  576. remoteAddr := conn.RemoteAddr().String()
  577. conn.Close()
  578. servers[remoteAddr] = true
  579. var allDone = true
  580. for _, done := range servers {
  581. if !done {
  582. allDone = false
  583. break
  584. }
  585. }
  586. if allDone {
  587. break
  588. }
  589. }
  590. var allDone = true
  591. for _, done := range servers {
  592. if !done {
  593. allDone = false
  594. break
  595. }
  596. }
  597. if !allDone {
  598. t.Errorf("after 1000 queries, not all servers were hit, significant lack of randomness: %#v", servers)
  599. }
  600. }
  601. func TestGetResolverEmpty(t *testing.T) {
  602. resolver := getResolver("")
  603. if resolver != net.DefaultResolver {
  604. t.Fatal(`getResolver("") should return net.DefaultResolver`)
  605. }
  606. }