firewall_test.go 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106
  1. package nebula
  2. import (
  3. "bytes"
  4. "errors"
  5. "math"
  6. "net/netip"
  7. "testing"
  8. "time"
  9. "github.com/gaissmai/bart"
  10. "github.com/slackhq/nebula/cert"
  11. "github.com/slackhq/nebula/config"
  12. "github.com/slackhq/nebula/firewall"
  13. "github.com/slackhq/nebula/test"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/stretchr/testify/require"
  16. )
  17. func TestNewFirewall(t *testing.T) {
  18. l := test.NewLogger()
  19. c := &dummyCert{}
  20. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  21. conntrack := fw.Conntrack
  22. assert.NotNil(t, conntrack)
  23. assert.NotNil(t, conntrack.Conns)
  24. assert.NotNil(t, conntrack.TimerWheel)
  25. assert.NotNil(t, fw.InRules)
  26. assert.NotNil(t, fw.OutRules)
  27. assert.Equal(t, time.Second, fw.TCPTimeout)
  28. assert.Equal(t, time.Minute, fw.UDPTimeout)
  29. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  30. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  31. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  32. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  33. fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
  34. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  35. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  36. fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
  37. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  38. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  39. fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
  40. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  41. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  42. fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
  43. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  44. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  45. fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
  46. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  47. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  48. }
  49. func TestFirewall_AddRule(t *testing.T) {
  50. l := test.NewLogger()
  51. ob := &bytes.Buffer{}
  52. l.SetOutput(ob)
  53. c := &dummyCert{}
  54. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  55. assert.NotNil(t, fw.InRules)
  56. assert.NotNil(t, fw.OutRules)
  57. ti, err := netip.ParsePrefix("1.2.3.4/32")
  58. require.NoError(t, err)
  59. ti6, err := netip.ParsePrefix("fd12::34/128")
  60. require.NoError(t, err)
  61. require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  62. // An empty rule is any
  63. assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
  64. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  65. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  66. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  67. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  68. assert.Nil(t, fw.InRules.UDP[1].Any.Any)
  69. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
  70. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  71. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  72. require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
  73. assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
  74. assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
  75. assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
  76. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  77. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
  78. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  79. _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
  80. assert.True(t, ok)
  81. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  82. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
  83. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  84. _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
  85. assert.True(t, ok)
  86. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  87. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
  88. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  89. _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
  90. assert.True(t, ok)
  91. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  92. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
  93. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  94. _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
  95. assert.True(t, ok)
  96. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  97. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
  98. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  99. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  100. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
  101. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  102. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  103. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
  104. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  105. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  106. anyIp, err := netip.ParsePrefix("0.0.0.0/0")
  107. require.NoError(t, err)
  108. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
  109. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  110. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  111. anyIp6, err := netip.ParsePrefix("::/0")
  112. require.NoError(t, err)
  113. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
  114. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  115. // Test error conditions
  116. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  117. require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  118. require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  119. }
  120. func TestFirewall_Drop(t *testing.T) {
  121. l := test.NewLogger()
  122. ob := &bytes.Buffer{}
  123. l.SetOutput(ob)
  124. myVpnNetworksTable := new(bart.Lite)
  125. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  126. p := firewall.Packet{
  127. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  128. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  129. LocalPort: 10,
  130. RemotePort: 90,
  131. Protocol: firewall.ProtoUDP,
  132. Fragment: false,
  133. }
  134. c := dummyCert{
  135. name: "host1",
  136. networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")},
  137. groups: []string{"default-group"},
  138. issuer: "signer-shasum",
  139. }
  140. h := HostInfo{
  141. ConnectionState: &ConnectionState{
  142. peerCert: &cert.CachedCertificate{
  143. Certificate: &c,
  144. InvertedGroups: map[string]struct{}{"default-group": {}},
  145. },
  146. },
  147. vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
  148. }
  149. h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
  150. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  151. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  152. cp := cert.NewCAPool()
  153. // Drop outbound
  154. assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
  155. // Allow inbound
  156. resetConntrack(fw)
  157. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  158. // Allow outbound because conntrack
  159. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  160. // test remote mismatch
  161. oldRemote := p.RemoteAddr
  162. p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
  163. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
  164. p.RemoteAddr = oldRemote
  165. // ensure signer doesn't get in the way of group checks
  166. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  167. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  168. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  169. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  170. // test caSha doesn't drop on match
  171. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  172. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  173. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  174. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  175. // ensure ca name doesn't get in the way of group checks
  176. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  177. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  178. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  179. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  180. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  181. // test caName doesn't drop on match
  182. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  183. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  184. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  185. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  186. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  187. }
  188. func TestFirewall_DropV6(t *testing.T) {
  189. l := test.NewLogger()
  190. ob := &bytes.Buffer{}
  191. l.SetOutput(ob)
  192. myVpnNetworksTable := new(bart.Lite)
  193. myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
  194. p := firewall.Packet{
  195. LocalAddr: netip.MustParseAddr("fd12::34"),
  196. RemoteAddr: netip.MustParseAddr("fd12::34"),
  197. LocalPort: 10,
  198. RemotePort: 90,
  199. Protocol: firewall.ProtoUDP,
  200. Fragment: false,
  201. }
  202. c := dummyCert{
  203. name: "host1",
  204. networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
  205. groups: []string{"default-group"},
  206. issuer: "signer-shasum",
  207. }
  208. h := HostInfo{
  209. ConnectionState: &ConnectionState{
  210. peerCert: &cert.CachedCertificate{
  211. Certificate: &c,
  212. InvertedGroups: map[string]struct{}{"default-group": {}},
  213. },
  214. },
  215. vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
  216. }
  217. h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
  218. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  219. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  220. cp := cert.NewCAPool()
  221. // Drop outbound
  222. assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
  223. // Allow inbound
  224. resetConntrack(fw)
  225. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  226. // Allow outbound because conntrack
  227. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  228. // test remote mismatch
  229. oldRemote := p.RemoteAddr
  230. p.RemoteAddr = netip.MustParseAddr("fd12::56")
  231. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
  232. p.RemoteAddr = oldRemote
  233. // ensure signer doesn't get in the way of group checks
  234. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  235. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  236. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  237. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  238. // test caSha doesn't drop on match
  239. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  240. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  241. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  242. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  243. // ensure ca name doesn't get in the way of group checks
  244. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  245. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  246. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  247. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  248. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  249. // test caName doesn't drop on match
  250. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  251. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  252. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  253. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  254. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  255. }
  256. func BenchmarkFirewallTable_match(b *testing.B) {
  257. f := &Firewall{}
  258. ft := FirewallTable{
  259. TCP: firewallPort{},
  260. }
  261. pfix := netip.MustParsePrefix("172.1.1.1/32")
  262. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
  263. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
  264. pfix6 := netip.MustParsePrefix("fd11::11/128")
  265. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
  266. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
  267. cp := cert.NewCAPool()
  268. b.Run("fail on proto", func(b *testing.B) {
  269. // This benchmark is showing us the cost of failing to match the protocol
  270. c := &cert.CachedCertificate{
  271. Certificate: &dummyCert{},
  272. }
  273. for n := 0; n < b.N; n++ {
  274. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
  275. }
  276. })
  277. b.Run("pass proto, fail on port", func(b *testing.B) {
  278. // This benchmark is showing us the cost of matching a specific protocol but failing to match the port
  279. c := &cert.CachedCertificate{
  280. Certificate: &dummyCert{},
  281. }
  282. for n := 0; n < b.N; n++ {
  283. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
  284. }
  285. })
  286. b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
  287. c := &cert.CachedCertificate{
  288. Certificate: &dummyCert{},
  289. }
  290. ip := netip.MustParsePrefix("9.254.254.254/32")
  291. for n := 0; n < b.N; n++ {
  292. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
  293. }
  294. })
  295. b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
  296. c := &cert.CachedCertificate{
  297. Certificate: &dummyCert{},
  298. }
  299. ip := netip.MustParsePrefix("fd99::99/128")
  300. for n := 0; n < b.N; n++ {
  301. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
  302. }
  303. })
  304. b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  305. c := &cert.CachedCertificate{
  306. Certificate: &dummyCert{
  307. name: "nope",
  308. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  309. },
  310. InvertedGroups: map[string]struct{}{"nope": {}},
  311. }
  312. for n := 0; n < b.N; n++ {
  313. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  314. }
  315. })
  316. b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
  317. c := &cert.CachedCertificate{
  318. Certificate: &dummyCert{
  319. name: "nope",
  320. networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
  321. },
  322. InvertedGroups: map[string]struct{}{"nope": {}},
  323. }
  324. for n := 0; n < b.N; n++ {
  325. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  326. }
  327. })
  328. b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  329. c := &cert.CachedCertificate{
  330. Certificate: &dummyCert{
  331. name: "nope",
  332. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  333. },
  334. InvertedGroups: map[string]struct{}{"nope": {}},
  335. }
  336. for n := 0; n < b.N; n++ {
  337. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  338. }
  339. })
  340. b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
  341. c := &cert.CachedCertificate{
  342. Certificate: &dummyCert{
  343. name: "nope",
  344. networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
  345. },
  346. InvertedGroups: map[string]struct{}{"nope": {}},
  347. }
  348. for n := 0; n < b.N; n++ {
  349. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
  350. }
  351. })
  352. b.Run("pass on group on any local cidr", func(b *testing.B) {
  353. c := &cert.CachedCertificate{
  354. Certificate: &dummyCert{
  355. name: "nope",
  356. },
  357. InvertedGroups: map[string]struct{}{"good-group": {}},
  358. }
  359. for n := 0; n < b.N; n++ {
  360. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  361. }
  362. })
  363. b.Run("pass on group on specific local cidr", func(b *testing.B) {
  364. c := &cert.CachedCertificate{
  365. Certificate: &dummyCert{
  366. name: "nope",
  367. },
  368. InvertedGroups: map[string]struct{}{"good-group": {}},
  369. }
  370. for n := 0; n < b.N; n++ {
  371. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  372. }
  373. })
  374. b.Run("pass on group on specific local cidr6", func(b *testing.B) {
  375. c := &cert.CachedCertificate{
  376. Certificate: &dummyCert{
  377. name: "nope",
  378. },
  379. InvertedGroups: map[string]struct{}{"good-group": {}},
  380. }
  381. for n := 0; n < b.N; n++ {
  382. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
  383. }
  384. })
  385. b.Run("pass on name", func(b *testing.B) {
  386. c := &cert.CachedCertificate{
  387. Certificate: &dummyCert{
  388. name: "good-host",
  389. },
  390. InvertedGroups: map[string]struct{}{"nope": {}},
  391. }
  392. for n := 0; n < b.N; n++ {
  393. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  394. }
  395. })
  396. }
  397. func TestFirewall_Drop2(t *testing.T) {
  398. l := test.NewLogger()
  399. ob := &bytes.Buffer{}
  400. l.SetOutput(ob)
  401. myVpnNetworksTable := new(bart.Lite)
  402. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  403. p := firewall.Packet{
  404. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  405. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  406. LocalPort: 10,
  407. RemotePort: 90,
  408. Protocol: firewall.ProtoUDP,
  409. Fragment: false,
  410. }
  411. network := netip.MustParsePrefix("1.2.3.4/24")
  412. c := cert.CachedCertificate{
  413. Certificate: &dummyCert{
  414. name: "host1",
  415. networks: []netip.Prefix{network},
  416. },
  417. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  418. }
  419. h := HostInfo{
  420. ConnectionState: &ConnectionState{
  421. peerCert: &c,
  422. },
  423. vpnAddrs: []netip.Addr{network.Addr()},
  424. }
  425. h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
  426. c1 := cert.CachedCertificate{
  427. Certificate: &dummyCert{
  428. name: "host1",
  429. networks: []netip.Prefix{network},
  430. },
  431. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  432. }
  433. h1 := HostInfo{
  434. vpnAddrs: []netip.Addr{network.Addr()},
  435. ConnectionState: &ConnectionState{
  436. peerCert: &c1,
  437. },
  438. }
  439. h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
  440. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  441. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  442. cp := cert.NewCAPool()
  443. // h1/c1 lacks the proper groups
  444. require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
  445. // c has the proper groups
  446. resetConntrack(fw)
  447. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  448. }
  449. func TestFirewall_Drop3(t *testing.T) {
  450. l := test.NewLogger()
  451. ob := &bytes.Buffer{}
  452. l.SetOutput(ob)
  453. myVpnNetworksTable := new(bart.Lite)
  454. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  455. p := firewall.Packet{
  456. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  457. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  458. LocalPort: 1,
  459. RemotePort: 1,
  460. Protocol: firewall.ProtoUDP,
  461. Fragment: false,
  462. }
  463. network := netip.MustParsePrefix("1.2.3.4/24")
  464. c := cert.CachedCertificate{
  465. Certificate: &dummyCert{
  466. name: "host-owner",
  467. networks: []netip.Prefix{network},
  468. },
  469. }
  470. c1 := cert.CachedCertificate{
  471. Certificate: &dummyCert{
  472. name: "host1",
  473. networks: []netip.Prefix{network},
  474. issuer: "signer-sha-bad",
  475. },
  476. }
  477. h1 := HostInfo{
  478. ConnectionState: &ConnectionState{
  479. peerCert: &c1,
  480. },
  481. vpnAddrs: []netip.Addr{network.Addr()},
  482. }
  483. h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
  484. c2 := cert.CachedCertificate{
  485. Certificate: &dummyCert{
  486. name: "host2",
  487. networks: []netip.Prefix{network},
  488. issuer: "signer-sha",
  489. },
  490. }
  491. h2 := HostInfo{
  492. ConnectionState: &ConnectionState{
  493. peerCert: &c2,
  494. },
  495. vpnAddrs: []netip.Addr{network.Addr()},
  496. }
  497. h2.buildNetworks(myVpnNetworksTable, c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
  498. c3 := cert.CachedCertificate{
  499. Certificate: &dummyCert{
  500. name: "host3",
  501. networks: []netip.Prefix{network},
  502. issuer: "signer-sha-bad",
  503. },
  504. }
  505. h3 := HostInfo{
  506. ConnectionState: &ConnectionState{
  507. peerCert: &c3,
  508. },
  509. vpnAddrs: []netip.Addr{network.Addr()},
  510. }
  511. h3.buildNetworks(myVpnNetworksTable, c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
  512. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  513. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
  514. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
  515. cp := cert.NewCAPool()
  516. // c1 should pass because host match
  517. require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  518. // c2 should pass because ca sha match
  519. resetConntrack(fw)
  520. require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
  521. // c3 should fail because no match
  522. resetConntrack(fw)
  523. assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
  524. // Test a remote address match
  525. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  526. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
  527. require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  528. }
  529. func TestFirewall_Drop3V6(t *testing.T) {
  530. l := test.NewLogger()
  531. ob := &bytes.Buffer{}
  532. l.SetOutput(ob)
  533. myVpnNetworksTable := new(bart.Lite)
  534. myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
  535. p := firewall.Packet{
  536. LocalAddr: netip.MustParseAddr("fd12::34"),
  537. RemoteAddr: netip.MustParseAddr("fd12::34"),
  538. LocalPort: 1,
  539. RemotePort: 1,
  540. Protocol: firewall.ProtoUDP,
  541. Fragment: false,
  542. }
  543. network := netip.MustParsePrefix("fd12::34/120")
  544. c := cert.CachedCertificate{
  545. Certificate: &dummyCert{
  546. name: "host-owner",
  547. networks: []netip.Prefix{network},
  548. },
  549. }
  550. h := HostInfo{
  551. ConnectionState: &ConnectionState{
  552. peerCert: &c,
  553. },
  554. vpnAddrs: []netip.Addr{network.Addr()},
  555. }
  556. h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
  557. // Test a remote address match
  558. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  559. cp := cert.NewCAPool()
  560. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
  561. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  562. }
  563. func TestFirewall_DropConntrackReload(t *testing.T) {
  564. l := test.NewLogger()
  565. ob := &bytes.Buffer{}
  566. l.SetOutput(ob)
  567. myVpnNetworksTable := new(bart.Lite)
  568. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  569. p := firewall.Packet{
  570. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  571. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  572. LocalPort: 10,
  573. RemotePort: 90,
  574. Protocol: firewall.ProtoUDP,
  575. Fragment: false,
  576. }
  577. network := netip.MustParsePrefix("1.2.3.4/24")
  578. c := cert.CachedCertificate{
  579. Certificate: &dummyCert{
  580. name: "host1",
  581. networks: []netip.Prefix{network},
  582. groups: []string{"default-group"},
  583. issuer: "signer-shasum",
  584. },
  585. InvertedGroups: map[string]struct{}{"default-group": {}},
  586. }
  587. h := HostInfo{
  588. ConnectionState: &ConnectionState{
  589. peerCert: &c,
  590. },
  591. vpnAddrs: []netip.Addr{network.Addr()},
  592. }
  593. h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
  594. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  595. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  596. cp := cert.NewCAPool()
  597. // Drop outbound
  598. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  599. // Allow inbound
  600. resetConntrack(fw)
  601. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  602. // Allow outbound because conntrack
  603. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  604. oldFw := fw
  605. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  606. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  607. fw.Conntrack = oldFw.Conntrack
  608. fw.rulesVersion = oldFw.rulesVersion + 1
  609. // Allow outbound because conntrack and new rules allow port 10
  610. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  611. oldFw = fw
  612. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  613. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  614. fw.Conntrack = oldFw.Conntrack
  615. fw.rulesVersion = oldFw.rulesVersion + 1
  616. // Drop outbound because conntrack doesn't match new ruleset
  617. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  618. }
  619. func TestFirewall_DropIPSpoofing(t *testing.T) {
  620. l := test.NewLogger()
  621. ob := &bytes.Buffer{}
  622. l.SetOutput(ob)
  623. myVpnNetworksTable := new(bart.Lite)
  624. myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
  625. c := cert.CachedCertificate{
  626. Certificate: &dummyCert{
  627. name: "host-owner",
  628. networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
  629. },
  630. }
  631. c1 := cert.CachedCertificate{
  632. Certificate: &dummyCert{
  633. name: "host",
  634. networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
  635. unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
  636. },
  637. }
  638. h1 := HostInfo{
  639. ConnectionState: &ConnectionState{
  640. peerCert: &c1,
  641. },
  642. vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
  643. }
  644. h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
  645. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  646. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  647. cp := cert.NewCAPool()
  648. // Packet spoofed by `c1`. Note that the remote addr is not a valid one.
  649. p := firewall.Packet{
  650. LocalAddr: netip.MustParseAddr("192.0.2.1"),
  651. RemoteAddr: netip.MustParseAddr("192.0.2.3"),
  652. LocalPort: 1,
  653. RemotePort: 1,
  654. Protocol: firewall.ProtoUDP,
  655. Fragment: false,
  656. }
  657. assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
  658. }
  // benchLookupAnyGroupSet reports whether at least one non-empty set in sets
// has every one of its groups present in m. An empty set never matches.
// It is the operation timed by BenchmarkLookup.
func benchLookupAnyGroupSet(m map[string]struct{}, sets [][]string) bool {
	for _, set := range sets {
		if len(set) == 0 {
			continue
		}
		matched := true
		for _, g := range set {
			if _, ok := m[g]; !ok {
				matched = false
				break
			}
		}
		if matched {
			return true
		}
	}
	return false
}

// BenchmarkLookup times benchLookupAnyGroupSet against the same list of group
// sets in two situations: "best", where the first set matches, and "worst",
// where only the last set matches and every earlier set is scanned first.
func BenchmarkLookup(b *testing.B) {
	sets := [][]string{
		{"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
		{"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
		{"one", "two", "3hr", "4ou", "5iv", "6ix"},
		{"one", "two", "thr", "4ou", "5iv", "6ix"},
		{"one", "two", "thr", "fou", "5iv", "6ix"},
		{"one", "two", "thr", "fou", "fiv", "6ix"},
		{"one", "two", "thr", "fou", "fiv", "six"},
	}

	// run times one lookup per iteration on the sub-benchmark's own *testing.B.
	// Checking the result both asserts the benchmark's precondition and keeps
	// the lookup from being optimized away.
	run := func(b *testing.B, m map[string]struct{}) {
		for n := 0; n < b.N; n++ {
			if !benchLookupAnyGroupSet(m, sets) {
				b.Fatal("expected one of the group sets to match")
			}
		}
	}

	b.Run("array to map best", func(b *testing.B) {
		run(b, map[string]struct{}{
			"1ne": {},
			"2wo": {},
			"3hr": {},
			"4ou": {},
			"5iv": {},
			"6ix": {},
		})
	})
	b.Run("array to map worst", func(b *testing.B) {
		run(b, map[string]struct{}{
			"one": {},
			"two": {},
			"thr": {},
			"fou": {},
			"fiv": {},
			"six": {},
		})
	})
}
  722. func Test_parsePort(t *testing.T) {
  723. _, _, err := parsePort("")
  724. require.EqualError(t, err, "was not a number; ``")
  725. _, _, err = parsePort(" ")
  726. require.EqualError(t, err, "was not a number; ` `")
  727. _, _, err = parsePort("-")
  728. require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  729. _, _, err = parsePort(" - ")
  730. require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  731. _, _, err = parsePort("a-b")
  732. require.EqualError(t, err, "beginning range was not a number; `a`")
  733. _, _, err = parsePort("1-b")
  734. require.EqualError(t, err, "ending range was not a number; `b`")
  735. s, e, err := parsePort(" 1 - 2 ")
  736. assert.Equal(t, int32(1), s)
  737. assert.Equal(t, int32(2), e)
  738. require.NoError(t, err)
  739. s, e, err = parsePort("0-1")
  740. assert.Equal(t, int32(0), s)
  741. assert.Equal(t, int32(0), e)
  742. require.NoError(t, err)
  743. s, e, err = parsePort("9919")
  744. assert.Equal(t, int32(9919), s)
  745. assert.Equal(t, int32(9919), e)
  746. require.NoError(t, err)
  747. s, e, err = parsePort("any")
  748. assert.Equal(t, int32(0), s)
  749. assert.Equal(t, int32(0), e)
  750. require.NoError(t, err)
  751. }
  752. func TestNewFirewallFromConfig(t *testing.T) {
  753. l := test.NewLogger()
  754. // Test a bad rule definition
  755. c := &dummyCert{}
  756. cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
  757. require.NoError(t, err)
  758. conf := config.NewC(l)
  759. conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
  760. _, err = NewFirewallFromConfig(l, cs, conf)
  761. require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  762. // Test both port and code
  763. conf = config.NewC(l)
  764. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
  765. _, err = NewFirewallFromConfig(l, cs, conf)
  766. require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  767. // Test missing host, group, cidr, ca_name and ca_sha
  768. conf = config.NewC(l)
  769. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
  770. _, err = NewFirewallFromConfig(l, cs, conf)
  771. require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
  772. // Test code/port error
  773. conf = config.NewC(l)
  774. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
  775. _, err = NewFirewallFromConfig(l, cs, conf)
  776. require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  777. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
  778. _, err = NewFirewallFromConfig(l, cs, conf)
  779. require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  780. // Test proto error
  781. conf = config.NewC(l)
  782. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
  783. _, err = NewFirewallFromConfig(l, cs, conf)
  784. require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  785. // Test cidr parse error
  786. conf = config.NewC(l)
  787. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
  788. _, err = NewFirewallFromConfig(l, cs, conf)
  789. require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  790. // Test local_cidr parse error
  791. conf = config.NewC(l)
  792. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
  793. _, err = NewFirewallFromConfig(l, cs, conf)
  794. require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  795. // Test both group and groups
  796. conf = config.NewC(l)
  797. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  798. _, err = NewFirewallFromConfig(l, cs, conf)
  799. require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  800. }
  801. func TestAddFirewallRulesFromConfig(t *testing.T) {
  802. l := test.NewLogger()
  803. // Test adding tcp rule
  804. conf := config.NewC(l)
  805. mf := &mockFirewall{}
  806. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
  807. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  808. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  809. // Test adding udp rule
  810. conf = config.NewC(l)
  811. mf = &mockFirewall{}
  812. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
  813. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  814. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  815. // Test adding icmp rule
  816. conf = config.NewC(l)
  817. mf = &mockFirewall{}
  818. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
  819. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  820. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  821. // Test adding any rule
  822. conf = config.NewC(l)
  823. mf = &mockFirewall{}
  824. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  825. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  826. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  827. // Test adding rule with cidr
  828. cidr := netip.MustParsePrefix("10.0.0.0/8")
  829. conf = config.NewC(l)
  830. mf = &mockFirewall{}
  831. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
  832. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  833. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
  834. // Test adding rule with local_cidr
  835. conf = config.NewC(l)
  836. mf = &mockFirewall{}
  837. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
  838. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  839. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
  840. // Test adding rule with cidr ipv6
  841. cidr6 := netip.MustParsePrefix("fd00::/8")
  842. conf = config.NewC(l)
  843. mf = &mockFirewall{}
  844. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
  845. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  846. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
  847. // Test adding rule with local_cidr ipv6
  848. conf = config.NewC(l)
  849. mf = &mockFirewall{}
  850. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
  851. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  852. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
  853. // Test adding rule with ca_sha
  854. conf = config.NewC(l)
  855. mf = &mockFirewall{}
  856. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  857. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  858. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
  859. // Test adding rule with ca_name
  860. conf = config.NewC(l)
  861. mf = &mockFirewall{}
  862. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
  863. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  864. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
  865. // Test single group
  866. conf = config.NewC(l)
  867. mf = &mockFirewall{}
  868. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
  869. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  870. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  871. // Test single groups
  872. conf = config.NewC(l)
  873. mf = &mockFirewall{}
  874. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
  875. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  876. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  877. // Test multiple AND groups
  878. conf = config.NewC(l)
  879. mf = &mockFirewall{}
  880. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  881. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  882. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  883. // Test Add error
  884. conf = config.NewC(l)
  885. mf = &mockFirewall{}
  886. mf.nextCallReturn = errors.New("test error")
  887. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  888. require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
  889. }
  890. func TestFirewall_convertRule(t *testing.T) {
  891. l := test.NewLogger()
  892. ob := &bytes.Buffer{}
  893. l.SetOutput(ob)
  894. // Ensure group array of 1 is converted and a warning is printed
  895. c := map[string]any{
  896. "group": []any{"group1"},
  897. }
  898. r, err := convertRule(l, c, "test", 1)
  899. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  900. require.NoError(t, err)
  901. assert.Equal(t, "group1", r.Group)
  902. // Ensure group array of > 1 is errord
  903. ob.Reset()
  904. c = map[string]any{
  905. "group": []any{"group1", "group2"},
  906. }
  907. r, err = convertRule(l, c, "test", 1)
  908. assert.Empty(t, ob.String())
  909. require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  910. // Make sure a well formed group is alright
  911. ob.Reset()
  912. c = map[string]any{
  913. "group": "group1",
  914. }
  915. r, err = convertRule(l, c, "test", 1)
  916. require.NoError(t, err)
  917. assert.Equal(t, "group1", r.Group)
  918. }
  // addRuleCall captures the arguments of a single AddRule invocation so a test
// can compare them against an expected value.
type addRuleCall struct {
	incoming           bool
	proto              uint8
	startPort, endPort int32
	groups             []string
	host               string
	ip, localIp        netip.Prefix
	caName, caSha      string
}

// mockFirewall stands in for a real firewall in the config-parsing tests. It
// remembers the most recent AddRule call and can be primed to fail once.
type mockFirewall struct {
	// lastCall holds the arguments of the most recent AddRule call.
	lastCall addRuleCall
	// nextCallReturn, when non-nil, is returned by the next AddRule call and
	// then cleared.
	nextCallReturn error
}

// AddRule records its arguments in lastCall and returns whatever error was
// queued in nextCallReturn, clearing it so later calls succeed.
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
	queued := mf.nextCallReturn
	mf.nextCallReturn = nil

	mf.lastCall = addRuleCall{
		incoming:  incoming,
		proto:     proto,
		startPort: startPort,
		endPort:   endPort,
		groups:    groups,
		host:      host,
		ip:        ip,
		localIp:   localIp,
		caName:    caName,
		caSha:     caSha,
	}
	return queued
}
  952. func resetConntrack(fw *Firewall) {
  953. fw.Conntrack.Lock()
  954. fw.Conntrack.Conns = map[firewall.Packet]*conn{}
  955. fw.Conntrack.Unlock()
  956. }