packet_filter_test.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package tsnet
  4. import (
  5. "context"
  6. "fmt"
  7. "net/netip"
  8. "testing"
  9. "time"
  10. "tailscale.com/ipn"
  11. "tailscale.com/tailcfg"
  12. "tailscale.com/tstest"
  13. "tailscale.com/types/ipproto"
  14. "tailscale.com/types/key"
  15. "tailscale.com/types/netmap"
  16. "tailscale.com/util/must"
  17. "tailscale.com/wgengine/filter"
  18. )
  19. // waitFor blocks until a NetMap is seen on the IPN bus that satisfies the given
  20. // function f. Note: has no timeout, should be called with a ctx that has an
  21. // appropriate timeout set.
  22. func waitFor(t testing.TB, ctx context.Context, s *Server, f func(*netmap.NetworkMap) bool) error {
  23. t.Helper()
  24. watcher, err := s.localClient.WatchIPNBus(ctx, ipn.NotifyInitialNetMap)
  25. if err != nil {
  26. t.Fatalf("error watching IPN bus: %s", err)
  27. }
  28. defer watcher.Close()
  29. for {
  30. n, err := watcher.Next()
  31. if err != nil {
  32. return fmt.Errorf("getting next ipn.Notify from IPN bus: %w", err)
  33. }
  34. if n.NetMap != nil {
  35. if f(n.NetMap) {
  36. return nil
  37. }
  38. }
  39. }
  40. }
  41. // TestPacketFilterFromNetmap tests all of the client code for processing
  42. // netmaps and turning them into packet filters together. Only the control-plane
  43. // side is mocked out.
  44. func TestPacketFilterFromNetmap(t *testing.T) {
  45. tstest.Shard(t)
  46. t.Parallel()
  47. var key key.NodePublic
  48. must.Do(key.UnmarshalText([]byte("nodekey:5c8f86d5fc70d924e55f02446165a5dae8f822994ad26bcf4b08fd841f9bf261")))
  49. type check struct {
  50. src string
  51. dst string
  52. port uint16
  53. want filter.Response
  54. }
  55. tests := []struct {
  56. name string
  57. mapResponse *tailcfg.MapResponse
  58. waitTest func(*netmap.NetworkMap) bool
  59. incrementalMapResponse *tailcfg.MapResponse // optional
  60. incrementalWaitTest func(*netmap.NetworkMap) bool // optional
  61. checks []check
  62. }{
  63. {
  64. name: "IP_based_peers",
  65. mapResponse: &tailcfg.MapResponse{
  66. Node: &tailcfg.Node{
  67. Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")},
  68. },
  69. Peers: []*tailcfg.Node{{
  70. ID: 2,
  71. Name: "foo",
  72. Key: key,
  73. Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")},
  74. CapMap: nil,
  75. }},
  76. PacketFilter: []tailcfg.FilterRule{{
  77. SrcIPs: []string{"2.2.2.2/32"},
  78. DstPorts: []tailcfg.NetPortRange{{
  79. IP: "1.1.1.1/32",
  80. Ports: tailcfg.PortRange{
  81. First: 22,
  82. Last: 22,
  83. },
  84. }},
  85. IPProto: []int{int(ipproto.TCP)},
  86. }},
  87. },
  88. waitTest: func(nm *netmap.NetworkMap) bool {
  89. return len(nm.Peers) > 0
  90. },
  91. checks: []check{
  92. {src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept},
  93. {src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port
  94. {src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src
  95. {src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst
  96. },
  97. },
  98. {
  99. name: "capmap_based_peers",
  100. mapResponse: &tailcfg.MapResponse{
  101. Node: &tailcfg.Node{
  102. Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")},
  103. },
  104. Peers: []*tailcfg.Node{{
  105. ID: 2,
  106. Name: "foo",
  107. Key: key,
  108. Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")},
  109. CapMap: tailcfg.NodeCapMap{"X": nil},
  110. }},
  111. PacketFilter: []tailcfg.FilterRule{{
  112. SrcIPs: []string{"cap:X"},
  113. DstPorts: []tailcfg.NetPortRange{{
  114. IP: "1.1.1.1/32",
  115. Ports: tailcfg.PortRange{
  116. First: 22,
  117. Last: 22,
  118. },
  119. }},
  120. IPProto: []int{int(ipproto.TCP)},
  121. }},
  122. },
  123. waitTest: func(nm *netmap.NetworkMap) bool {
  124. return len(nm.Peers) > 0
  125. },
  126. checks: []check{
  127. {src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept},
  128. {src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port
  129. {src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src
  130. {src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst
  131. },
  132. },
  133. {
  134. name: "capmap_based_peers_changed",
  135. mapResponse: &tailcfg.MapResponse{
  136. Node: &tailcfg.Node{
  137. Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")},
  138. CapMap: tailcfg.NodeCapMap{"X-sigil": nil},
  139. },
  140. PacketFilter: []tailcfg.FilterRule{{
  141. SrcIPs: []string{"cap:label-1"},
  142. DstPorts: []tailcfg.NetPortRange{{
  143. IP: "1.1.1.1/32",
  144. Ports: tailcfg.PortRange{
  145. First: 22,
  146. Last: 22,
  147. },
  148. }},
  149. IPProto: []int{int(ipproto.TCP)},
  150. }},
  151. },
  152. waitTest: func(nm *netmap.NetworkMap) bool {
  153. return nm.SelfNode.HasCap("X-sigil")
  154. },
  155. incrementalMapResponse: &tailcfg.MapResponse{
  156. PeersChanged: []*tailcfg.Node{{
  157. ID: 2,
  158. Name: "foo",
  159. Key: key,
  160. Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")},
  161. CapMap: tailcfg.NodeCapMap{"label-1": nil},
  162. }},
  163. },
  164. incrementalWaitTest: func(nm *netmap.NetworkMap) bool {
  165. return len(nm.Peers) > 0
  166. },
  167. checks: []check{
  168. {src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept},
  169. {src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port
  170. {src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src
  171. {src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst
  172. },
  173. },
  174. }
  175. for _, test := range tests {
  176. t.Run(test.name, func(t *testing.T) {
  177. t.Parallel()
  178. ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second)
  179. defer cancel()
  180. controlURL, c := startControl(t)
  181. s, _, pubKey := startServer(t, ctx, controlURL, "node")
  182. if test.waitTest(s.lb.NetMap()) {
  183. t.Fatal("waitTest already passes before sending initial netmap: this will be flaky")
  184. }
  185. if !c.AddRawMapResponse(pubKey, test.mapResponse) {
  186. t.Fatalf("could not send map response to %s", pubKey)
  187. }
  188. if err := waitFor(t, ctx, s, test.waitTest); err != nil {
  189. t.Fatalf("waitFor: %s", err)
  190. }
  191. pf := s.lb.GetFilterForTest()
  192. for _, check := range test.checks {
  193. got := pf.Check(netip.MustParseAddr(check.src), netip.MustParseAddr(check.dst), check.port, ipproto.TCP)
  194. want := check.want
  195. if test.incrementalMapResponse != nil {
  196. want = filter.Drop
  197. }
  198. if got != want {
  199. t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, want)
  200. }
  201. }
  202. if test.incrementalMapResponse != nil {
  203. if test.incrementalWaitTest == nil {
  204. t.Fatal("incrementalWaitTest must be set if incrementalMapResponse is set")
  205. }
  206. if test.incrementalWaitTest(s.lb.NetMap()) {
  207. t.Fatal("incrementalWaitTest already passes before sending incremental netmap: this will be flaky")
  208. }
  209. if !c.AddRawMapResponse(pubKey, test.incrementalMapResponse) {
  210. t.Fatalf("could not send map response to %s", pubKey)
  211. }
  212. if err := waitFor(t, ctx, s, test.incrementalWaitTest); err != nil {
  213. t.Fatalf("waitFor: %s", err)
  214. }
  215. pf := s.lb.GetFilterForTest()
  216. for _, check := range test.checks {
  217. got := pf.Check(netip.MustParseAddr(check.src), netip.MustParseAddr(check.dst), check.port, ipproto.TCP)
  218. if got != check.want {
  219. t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, check.want)
  220. }
  221. }
  222. }
  223. })
  224. }
  225. }