firewall_test.go 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842
  1. package nebula
  2. import (
  3. "bytes"
  4. "errors"
  5. "math"
  6. "net/netip"
  7. "testing"
  8. "time"
  9. "github.com/slackhq/nebula/cert"
  10. "github.com/slackhq/nebula/config"
  11. "github.com/slackhq/nebula/firewall"
  12. "github.com/slackhq/nebula/test"
  13. "github.com/stretchr/testify/assert"
  14. )
  15. func TestNewFirewall(t *testing.T) {
  16. l := test.NewLogger()
  17. c := &dummyCert{}
  18. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  19. conntrack := fw.Conntrack
  20. assert.NotNil(t, conntrack)
  21. assert.NotNil(t, conntrack.Conns)
  22. assert.NotNil(t, conntrack.TimerWheel)
  23. assert.NotNil(t, fw.InRules)
  24. assert.NotNil(t, fw.OutRules)
  25. assert.Equal(t, time.Second, fw.TCPTimeout)
  26. assert.Equal(t, time.Minute, fw.UDPTimeout)
  27. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  28. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  29. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  30. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  31. fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
  32. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  33. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  34. fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
  35. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  36. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  37. fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
  38. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  39. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  40. fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
  41. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  42. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  43. fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
  44. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  45. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  46. }
  47. func TestFirewall_AddRule(t *testing.T) {
  48. l := test.NewLogger()
  49. ob := &bytes.Buffer{}
  50. l.SetOutput(ob)
  51. c := &dummyCert{}
  52. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  53. assert.NotNil(t, fw.InRules)
  54. assert.NotNil(t, fw.OutRules)
  55. ti, err := netip.ParsePrefix("1.2.3.4/32")
  56. assert.NoError(t, err)
  57. assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  58. // An empty rule is any
  59. assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
  60. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  61. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  62. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  63. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  64. assert.Nil(t, fw.InRules.UDP[1].Any.Any)
  65. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
  66. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  67. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  68. assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
  69. assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
  70. assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
  71. assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
  72. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  73. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
  74. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  75. _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
  76. assert.True(t, ok)
  77. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  78. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
  79. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  80. _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
  81. assert.True(t, ok)
  82. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  83. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
  84. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  85. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  86. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
  87. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  88. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  89. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
  90. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  91. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  92. anyIp, err := netip.ParsePrefix("0.0.0.0/0")
  93. assert.NoError(t, err)
  94. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
  95. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  96. // Test error conditions
  97. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  98. assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  99. assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  100. }
  101. func TestFirewall_Drop(t *testing.T) {
  102. l := test.NewLogger()
  103. ob := &bytes.Buffer{}
  104. l.SetOutput(ob)
  105. p := firewall.Packet{
  106. LocalIP: netip.MustParseAddr("1.2.3.4"),
  107. RemoteIP: netip.MustParseAddr("1.2.3.4"),
  108. LocalPort: 10,
  109. RemotePort: 90,
  110. Protocol: firewall.ProtoUDP,
  111. Fragment: false,
  112. }
  113. c := dummyCert{
  114. name: "host1",
  115. networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")},
  116. groups: []string{"default-group"},
  117. issuer: "signer-shasum",
  118. }
  119. h := HostInfo{
  120. ConnectionState: &ConnectionState{
  121. peerCert: &cert.CachedCertificate{
  122. Certificate: &c,
  123. InvertedGroups: map[string]struct{}{"default-group": {}},
  124. },
  125. },
  126. vpnIp: netip.MustParseAddr("1.2.3.4"),
  127. }
  128. h.CreateRemoteCIDR(&c)
  129. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  130. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  131. cp := cert.NewCAPool()
  132. // Drop outbound
  133. assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
  134. // Allow inbound
  135. resetConntrack(fw)
  136. assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
  137. // Allow outbound because conntrack
  138. assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
  139. // test remote mismatch
  140. oldRemote := p.RemoteIP
  141. p.RemoteIP = netip.MustParseAddr("1.2.3.10")
  142. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
  143. p.RemoteIP = oldRemote
  144. // ensure signer doesn't get in the way of group checks
  145. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  146. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  147. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  148. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  149. // test caSha doesn't drop on match
  150. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  151. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  152. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  153. assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
  154. // ensure ca name doesn't get in the way of group checks
  155. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  156. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  157. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  158. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  159. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  160. // test caName doesn't drop on match
  161. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  162. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  163. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  164. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  165. assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
  166. }
  167. func BenchmarkFirewallTable_match(b *testing.B) {
  168. f := &Firewall{}
  169. ft := FirewallTable{
  170. TCP: firewallPort{},
  171. }
  172. pfix := netip.MustParsePrefix("172.1.1.1/32")
  173. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
  174. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
  175. cp := cert.NewCAPool()
  176. b.Run("fail on proto", func(b *testing.B) {
  177. // This benchmark is showing us the cost of failing to match the protocol
  178. c := &cert.CachedCertificate{
  179. Certificate: &dummyCert{},
  180. }
  181. for n := 0; n < b.N; n++ {
  182. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
  183. }
  184. })
  185. b.Run("pass proto, fail on port", func(b *testing.B) {
  186. // This benchmark is showing us the cost of matching a specific protocol but failing to match the port
  187. c := &cert.CachedCertificate{
  188. Certificate: &dummyCert{},
  189. }
  190. for n := 0; n < b.N; n++ {
  191. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
  192. }
  193. })
  194. b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
  195. c := &cert.CachedCertificate{
  196. Certificate: &dummyCert{},
  197. }
  198. ip := netip.MustParsePrefix("9.254.254.254/32")
  199. for n := 0; n < b.N; n++ {
  200. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
  201. }
  202. })
  203. b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  204. c := &cert.CachedCertificate{
  205. Certificate: &dummyCert{
  206. name: "nope",
  207. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  208. },
  209. InvertedGroups: map[string]struct{}{"nope": {}},
  210. }
  211. for n := 0; n < b.N; n++ {
  212. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  213. }
  214. })
  215. b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  216. c := &cert.CachedCertificate{
  217. Certificate: &dummyCert{
  218. name: "nope",
  219. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  220. },
  221. InvertedGroups: map[string]struct{}{"nope": {}},
  222. }
  223. for n := 0; n < b.N; n++ {
  224. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
  225. }
  226. })
  227. b.Run("pass on group on any local cidr", func(b *testing.B) {
  228. c := &cert.CachedCertificate{
  229. Certificate: &dummyCert{
  230. name: "nope",
  231. },
  232. InvertedGroups: map[string]struct{}{"good-group": {}},
  233. }
  234. for n := 0; n < b.N; n++ {
  235. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  236. }
  237. })
  238. b.Run("pass on group on specific local cidr", func(b *testing.B) {
  239. c := &cert.CachedCertificate{
  240. Certificate: &dummyCert{
  241. name: "nope",
  242. },
  243. InvertedGroups: map[string]struct{}{"good-group": {}},
  244. }
  245. for n := 0; n < b.N; n++ {
  246. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
  247. }
  248. })
  249. b.Run("pass on name", func(b *testing.B) {
  250. c := &cert.CachedCertificate{
  251. Certificate: &dummyCert{
  252. name: "good-host",
  253. },
  254. InvertedGroups: map[string]struct{}{"nope": {}},
  255. }
  256. for n := 0; n < b.N; n++ {
  257. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  258. }
  259. })
  260. }
  261. func TestFirewall_Drop2(t *testing.T) {
  262. l := test.NewLogger()
  263. ob := &bytes.Buffer{}
  264. l.SetOutput(ob)
  265. p := firewall.Packet{
  266. LocalIP: netip.MustParseAddr("1.2.3.4"),
  267. RemoteIP: netip.MustParseAddr("1.2.3.4"),
  268. LocalPort: 10,
  269. RemotePort: 90,
  270. Protocol: firewall.ProtoUDP,
  271. Fragment: false,
  272. }
  273. network := netip.MustParsePrefix("1.2.3.4/24")
  274. c := cert.CachedCertificate{
  275. Certificate: &dummyCert{
  276. name: "host1",
  277. networks: []netip.Prefix{network},
  278. },
  279. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  280. }
  281. h := HostInfo{
  282. ConnectionState: &ConnectionState{
  283. peerCert: &c,
  284. },
  285. vpnIp: network.Addr(),
  286. }
  287. h.CreateRemoteCIDR(c.Certificate)
  288. c1 := cert.CachedCertificate{
  289. Certificate: &dummyCert{
  290. name: "host1",
  291. networks: []netip.Prefix{network},
  292. },
  293. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  294. }
  295. h1 := HostInfo{
  296. ConnectionState: &ConnectionState{
  297. peerCert: &c1,
  298. },
  299. }
  300. h1.CreateRemoteCIDR(c1.Certificate)
  301. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  302. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  303. cp := cert.NewCAPool()
  304. // h1/c1 lacks the proper groups
  305. assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
  306. // c has the proper groups
  307. resetConntrack(fw)
  308. assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
  309. }
  310. func TestFirewall_Drop3(t *testing.T) {
  311. l := test.NewLogger()
  312. ob := &bytes.Buffer{}
  313. l.SetOutput(ob)
  314. p := firewall.Packet{
  315. LocalIP: netip.MustParseAddr("1.2.3.4"),
  316. RemoteIP: netip.MustParseAddr("1.2.3.4"),
  317. LocalPort: 1,
  318. RemotePort: 1,
  319. Protocol: firewall.ProtoUDP,
  320. Fragment: false,
  321. }
  322. network := netip.MustParsePrefix("1.2.3.4/24")
  323. c := cert.CachedCertificate{
  324. Certificate: &dummyCert{
  325. name: "host-owner",
  326. networks: []netip.Prefix{network},
  327. },
  328. }
  329. c1 := cert.CachedCertificate{
  330. Certificate: &dummyCert{
  331. name: "host1",
  332. networks: []netip.Prefix{network},
  333. issuer: "signer-sha-bad",
  334. },
  335. }
  336. h1 := HostInfo{
  337. ConnectionState: &ConnectionState{
  338. peerCert: &c1,
  339. },
  340. vpnIp: network.Addr(),
  341. }
  342. h1.CreateRemoteCIDR(c1.Certificate)
  343. c2 := cert.CachedCertificate{
  344. Certificate: &dummyCert{
  345. name: "host2",
  346. networks: []netip.Prefix{network},
  347. issuer: "signer-sha",
  348. },
  349. }
  350. h2 := HostInfo{
  351. ConnectionState: &ConnectionState{
  352. peerCert: &c2,
  353. },
  354. vpnIp: network.Addr(),
  355. }
  356. h2.CreateRemoteCIDR(c2.Certificate)
  357. c3 := cert.CachedCertificate{
  358. Certificate: &dummyCert{
  359. name: "host3",
  360. networks: []netip.Prefix{network},
  361. issuer: "signer-sha-bad",
  362. },
  363. }
  364. h3 := HostInfo{
  365. ConnectionState: &ConnectionState{
  366. peerCert: &c3,
  367. },
  368. vpnIp: network.Addr(),
  369. }
  370. h3.CreateRemoteCIDR(c3.Certificate)
  371. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  372. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
  373. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
  374. cp := cert.NewCAPool()
  375. // c1 should pass because host match
  376. assert.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  377. // c2 should pass because ca sha match
  378. resetConntrack(fw)
  379. assert.NoError(t, fw.Drop(p, true, &h2, cp, nil))
  380. // c3 should fail because no match
  381. resetConntrack(fw)
  382. assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
  383. }
  384. func TestFirewall_DropConntrackReload(t *testing.T) {
  385. l := test.NewLogger()
  386. ob := &bytes.Buffer{}
  387. l.SetOutput(ob)
  388. p := firewall.Packet{
  389. LocalIP: netip.MustParseAddr("1.2.3.4"),
  390. RemoteIP: netip.MustParseAddr("1.2.3.4"),
  391. LocalPort: 10,
  392. RemotePort: 90,
  393. Protocol: firewall.ProtoUDP,
  394. Fragment: false,
  395. }
  396. network := netip.MustParsePrefix("1.2.3.4/24")
  397. c := cert.CachedCertificate{
  398. Certificate: &dummyCert{
  399. name: "host1",
  400. networks: []netip.Prefix{network},
  401. groups: []string{"default-group"},
  402. issuer: "signer-shasum",
  403. },
  404. InvertedGroups: map[string]struct{}{"default-group": {}},
  405. }
  406. h := HostInfo{
  407. ConnectionState: &ConnectionState{
  408. peerCert: &c,
  409. },
  410. vpnIp: network.Addr(),
  411. }
  412. h.CreateRemoteCIDR(c.Certificate)
  413. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  414. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  415. cp := cert.NewCAPool()
  416. // Drop outbound
  417. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  418. // Allow inbound
  419. resetConntrack(fw)
  420. assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
  421. // Allow outbound because conntrack
  422. assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
  423. oldFw := fw
  424. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  425. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  426. fw.Conntrack = oldFw.Conntrack
  427. fw.rulesVersion = oldFw.rulesVersion + 1
  428. // Allow outbound because conntrack and new rules allow port 10
  429. assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
  430. oldFw = fw
  431. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  432. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  433. fw.Conntrack = oldFw.Conntrack
  434. fw.rulesVersion = oldFw.rulesVersion + 1
  435. // Drop outbound because conntrack doesn't match new ruleset
  436. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  437. }
  438. func BenchmarkLookup(b *testing.B) {
  439. ml := func(m map[string]struct{}, a [][]string) {
  440. for n := 0; n < b.N; n++ {
  441. for _, sg := range a {
  442. found := false
  443. for _, g := range sg {
  444. if _, ok := m[g]; !ok {
  445. found = false
  446. break
  447. }
  448. found = true
  449. }
  450. if found {
  451. return
  452. }
  453. }
  454. }
  455. }
  456. b.Run("array to map best", func(b *testing.B) {
  457. m := map[string]struct{}{
  458. "1ne": {},
  459. "2wo": {},
  460. "3hr": {},
  461. "4ou": {},
  462. "5iv": {},
  463. "6ix": {},
  464. }
  465. a := [][]string{
  466. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  467. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  468. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  469. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  470. {"one", "two", "thr", "fou", "5iv", "6ix"},
  471. {"one", "two", "thr", "fou", "fiv", "6ix"},
  472. {"one", "two", "thr", "fou", "fiv", "six"},
  473. }
  474. for n := 0; n < b.N; n++ {
  475. ml(m, a)
  476. }
  477. })
  478. b.Run("array to map worst", func(b *testing.B) {
  479. m := map[string]struct{}{
  480. "one": {},
  481. "two": {},
  482. "thr": {},
  483. "fou": {},
  484. "fiv": {},
  485. "six": {},
  486. }
  487. a := [][]string{
  488. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  489. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  490. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  491. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  492. {"one", "two", "thr", "fou", "5iv", "6ix"},
  493. {"one", "two", "thr", "fou", "fiv", "6ix"},
  494. {"one", "two", "thr", "fou", "fiv", "six"},
  495. }
  496. for n := 0; n < b.N; n++ {
  497. ml(m, a)
  498. }
  499. })
  500. //TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
  501. }
  502. func Test_parsePort(t *testing.T) {
  503. _, _, err := parsePort("")
  504. assert.EqualError(t, err, "was not a number; ``")
  505. _, _, err = parsePort(" ")
  506. assert.EqualError(t, err, "was not a number; ` `")
  507. _, _, err = parsePort("-")
  508. assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  509. _, _, err = parsePort(" - ")
  510. assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  511. _, _, err = parsePort("a-b")
  512. assert.EqualError(t, err, "beginning range was not a number; `a`")
  513. _, _, err = parsePort("1-b")
  514. assert.EqualError(t, err, "ending range was not a number; `b`")
  515. s, e, err := parsePort(" 1 - 2 ")
  516. assert.Equal(t, int32(1), s)
  517. assert.Equal(t, int32(2), e)
  518. assert.Nil(t, err)
  519. s, e, err = parsePort("0-1")
  520. assert.Equal(t, int32(0), s)
  521. assert.Equal(t, int32(0), e)
  522. assert.Nil(t, err)
  523. s, e, err = parsePort("9919")
  524. assert.Equal(t, int32(9919), s)
  525. assert.Equal(t, int32(9919), e)
  526. assert.Nil(t, err)
  527. s, e, err = parsePort("any")
  528. assert.Equal(t, int32(0), s)
  529. assert.Equal(t, int32(0), e)
  530. assert.Nil(t, err)
  531. }
  532. func TestNewFirewallFromConfig(t *testing.T) {
  533. l := test.NewLogger()
  534. // Test a bad rule definition
  535. c := &dummyCert{}
  536. conf := config.NewC(l)
  537. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
  538. _, err := NewFirewallFromConfig(l, c, conf)
  539. assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  540. // Test both port and code
  541. conf = config.NewC(l)
  542. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
  543. _, err = NewFirewallFromConfig(l, c, conf)
  544. assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  545. // Test missing host, group, cidr, ca_name and ca_sha
  546. conf = config.NewC(l)
  547. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
  548. _, err = NewFirewallFromConfig(l, c, conf)
  549. assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
  550. // Test code/port error
  551. conf = config.NewC(l)
  552. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
  553. _, err = NewFirewallFromConfig(l, c, conf)
  554. assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  555. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
  556. _, err = NewFirewallFromConfig(l, c, conf)
  557. assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  558. // Test proto error
  559. conf = config.NewC(l)
  560. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
  561. _, err = NewFirewallFromConfig(l, c, conf)
  562. assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  563. // Test cidr parse error
  564. conf = config.NewC(l)
  565. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
  566. _, err = NewFirewallFromConfig(l, c, conf)
  567. assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  568. // Test local_cidr parse error
  569. conf = config.NewC(l)
  570. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
  571. _, err = NewFirewallFromConfig(l, c, conf)
  572. assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  573. // Test both group and groups
  574. conf = config.NewC(l)
  575. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  576. _, err = NewFirewallFromConfig(l, c, conf)
  577. assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  578. }
  579. func TestAddFirewallRulesFromConfig(t *testing.T) {
  580. l := test.NewLogger()
  581. // Test adding tcp rule
  582. conf := config.NewC(l)
  583. mf := &mockFirewall{}
  584. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
  585. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  586. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  587. // Test adding udp rule
  588. conf = config.NewC(l)
  589. mf = &mockFirewall{}
  590. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
  591. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  592. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  593. // Test adding icmp rule
  594. conf = config.NewC(l)
  595. mf = &mockFirewall{}
  596. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
  597. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  598. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  599. // Test adding any rule
  600. conf = config.NewC(l)
  601. mf = &mockFirewall{}
  602. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  603. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  604. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  605. // Test adding rule with cidr
  606. cidr := netip.MustParsePrefix("10.0.0.0/8")
  607. conf = config.NewC(l)
  608. mf = &mockFirewall{}
  609. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
  610. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  611. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
  612. // Test adding rule with local_cidr
  613. conf = config.NewC(l)
  614. mf = &mockFirewall{}
  615. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
  616. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  617. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
  618. // Test adding rule with ca_sha
  619. conf = config.NewC(l)
  620. mf = &mockFirewall{}
  621. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  622. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  623. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
  624. // Test adding rule with ca_name
  625. conf = config.NewC(l)
  626. mf = &mockFirewall{}
  627. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
  628. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  629. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
  630. // Test single group
  631. conf = config.NewC(l)
  632. mf = &mockFirewall{}
  633. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
  634. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  635. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  636. // Test single groups
  637. conf = config.NewC(l)
  638. mf = &mockFirewall{}
  639. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
  640. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  641. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  642. // Test multiple AND groups
  643. conf = config.NewC(l)
  644. mf = &mockFirewall{}
  645. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  646. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  647. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  648. // Test Add error
  649. conf = config.NewC(l)
  650. mf = &mockFirewall{}
  651. mf.nextCallReturn = errors.New("test error")
  652. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  653. assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
  654. }
  655. func TestFirewall_convertRule(t *testing.T) {
  656. l := test.NewLogger()
  657. ob := &bytes.Buffer{}
  658. l.SetOutput(ob)
  659. // Ensure group array of 1 is converted and a warning is printed
  660. c := map[interface{}]interface{}{
  661. "group": []interface{}{"group1"},
  662. }
  663. r, err := convertRule(l, c, "test", 1)
  664. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  665. assert.Nil(t, err)
  666. assert.Equal(t, "group1", r.Group)
  667. // Ensure group array of > 1 is errord
  668. ob.Reset()
  669. c = map[interface{}]interface{}{
  670. "group": []interface{}{"group1", "group2"},
  671. }
  672. r, err = convertRule(l, c, "test", 1)
  673. assert.Equal(t, "", ob.String())
  674. assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  675. // Make sure a well formed group is alright
  676. ob.Reset()
  677. c = map[interface{}]interface{}{
  678. "group": "group1",
  679. }
  680. r, err = convertRule(l, c, "test", 1)
  681. assert.Nil(t, err)
  682. assert.Equal(t, "group1", r.Group)
  683. }
  // addRuleCall is a snapshot of the arguments handed to mockFirewall.AddRule.
type addRuleCall struct {
	incoming           bool
	proto              uint8
	startPort, endPort int32
	groups             []string
	host               string
	ip, localIp        netip.Prefix
	caName, caSha      string
}

// mockFirewall is a test double passed to AddFirewallRulesFromConfig in place of
// a real firewall. It remembers only the most recent AddRule call and can be
// primed to make the next call fail.
type mockFirewall struct {
	lastCall       addRuleCall // arguments of the most recent AddRule call
	nextCallReturn error       // returned by the next AddRule call, then cleared
}

// AddRule records its arguments in mf.lastCall and returns whatever error was
// primed in mf.nextCallReturn, clearing it so the following call succeeds
// unless the mock is primed again.
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
	mf.lastCall = addRuleCall{
		caSha:     caSha,
		caName:    caName,
		localIp:   localIp,
		ip:        ip,
		host:      host,
		groups:    groups,
		endPort:   endPort,
		startPort: startPort,
		proto:     proto,
		incoming:  incoming,
	}

	// Hand the primed error back exactly once.
	var err error
	err, mf.nextCallReturn = mf.nextCallReturn, nil
	return err
}
  717. func resetConntrack(fw *Firewall) {
  718. fw.Conntrack.Lock()
  719. fw.Conntrack.Conns = map[firewall.Packet]*conn{}
  720. fw.Conntrack.Unlock()
  721. }