// firewall_test.go (54 KB)
  1. package nebula
  2. import (
  3. "bytes"
  4. "errors"
  5. "math"
  6. "net/netip"
  7. "testing"
  8. "time"
  9. "github.com/gaissmai/bart"
  10. "github.com/sirupsen/logrus"
  11. "github.com/slackhq/nebula/cert"
  12. "github.com/slackhq/nebula/config"
  13. "github.com/slackhq/nebula/firewall"
  14. "github.com/slackhq/nebula/test"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/stretchr/testify/require"
  17. )
  18. func TestNewFirewall(t *testing.T) {
  19. l := test.NewLogger()
  20. c := &dummyCert{}
  21. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  22. conntrack := fw.Conntrack
  23. assert.NotNil(t, conntrack)
  24. assert.NotNil(t, conntrack.Conns)
  25. assert.NotNil(t, conntrack.TimerWheel)
  26. assert.NotNil(t, fw.InRules)
  27. assert.NotNil(t, fw.OutRules)
  28. assert.Equal(t, time.Second, fw.TCPTimeout)
  29. assert.Equal(t, time.Minute, fw.UDPTimeout)
  30. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  31. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  32. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  33. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  34. fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
  35. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  36. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  37. fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
  38. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  39. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  40. fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
  41. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  42. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  43. fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
  44. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  45. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  46. fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
  47. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  48. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  49. }
  50. func TestFirewall_AddRule(t *testing.T) {
  51. l := test.NewLogger()
  52. ob := &bytes.Buffer{}
  53. l.SetOutput(ob)
  54. c := &dummyCert{}
  55. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  56. assert.NotNil(t, fw.InRules)
  57. assert.NotNil(t, fw.OutRules)
  58. ti, err := netip.ParsePrefix("1.2.3.4/32")
  59. require.NoError(t, err)
  60. ti6, err := netip.ParsePrefix("fd12::34/128")
  61. require.NoError(t, err)
  62. require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", ""))
  63. // An empty rule is any
  64. assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
  65. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  66. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  67. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  68. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
  69. assert.Nil(t, fw.InRules.UDP[1].Any.Any)
  70. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
  71. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  72. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  73. require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
  74. //no matter what port is given for icmp, it should end up as "any"
  75. assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any)
  76. assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups)
  77. assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1")
  78. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  79. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
  80. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  81. _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
  82. assert.True(t, ok)
  83. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  84. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
  85. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  86. _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
  87. assert.True(t, ok)
  88. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  89. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
  90. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  91. ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
  92. assert.True(t, ok)
  93. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  94. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
  95. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  96. ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
  97. assert.True(t, ok)
  98. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  99. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
  100. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  101. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  102. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
  103. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  104. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  105. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
  106. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  107. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  108. anyIp, err := netip.ParsePrefix("0.0.0.0/0")
  109. require.NoError(t, err)
  110. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", ""))
  111. assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
  112. table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
  113. assert.True(t, table.Any)
  114. table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
  115. assert.False(t, ok)
  116. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  117. anyIp6, err := netip.ParsePrefix("::/0")
  118. require.NoError(t, err)
  119. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", ""))
  120. assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
  121. table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
  122. assert.True(t, table.Any)
  123. table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
  124. assert.False(t, ok)
  125. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  126. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
  127. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  128. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  129. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
  130. assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  131. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
  132. assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
  133. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  134. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
  135. assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  136. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
  137. assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
  138. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  139. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
  140. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  141. // Test error conditions
  142. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  143. require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
  144. require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
  145. }
  146. func TestFirewall_Drop(t *testing.T) {
  147. l := test.NewLogger()
  148. ob := &bytes.Buffer{}
  149. l.SetOutput(ob)
  150. myVpnNetworksTable := new(bart.Lite)
  151. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  152. p := firewall.Packet{
  153. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  154. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  155. LocalPort: 10,
  156. RemotePort: 90,
  157. Protocol: firewall.ProtoUDP,
  158. Fragment: false,
  159. }
  160. c := dummyCert{
  161. name: "host1",
  162. networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")},
  163. groups: []string{"default-group"},
  164. issuer: "signer-shasum",
  165. }
  166. h := HostInfo{
  167. ConnectionState: &ConnectionState{
  168. peerCert: &cert.CachedCertificate{
  169. Certificate: &c,
  170. InvertedGroups: map[string]struct{}{"default-group": {}},
  171. },
  172. },
  173. vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
  174. }
  175. h.buildNetworks(myVpnNetworksTable, &c)
  176. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  177. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
  178. cp := cert.NewCAPool()
  179. // Drop outbound
  180. assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
  181. // Allow inbound
  182. resetConntrack(fw)
  183. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  184. // Allow outbound because conntrack
  185. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  186. // test remote mismatch
  187. oldRemote := p.RemoteAddr
  188. p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
  189. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
  190. p.RemoteAddr = oldRemote
  191. // ensure signer doesn't get in the way of group checks
  192. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  193. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
  194. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
  195. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  196. // test caSha doesn't drop on match
  197. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  198. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
  199. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
  200. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  201. // ensure ca name doesn't get in the way of group checks
  202. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  203. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  204. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
  205. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
  206. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  207. // test caName doesn't drop on match
  208. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  209. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  210. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
  211. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
  212. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  213. }
  214. func TestFirewall_DropV6(t *testing.T) {
  215. l := test.NewLogger()
  216. ob := &bytes.Buffer{}
  217. l.SetOutput(ob)
  218. myVpnNetworksTable := new(bart.Lite)
  219. myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
  220. p := firewall.Packet{
  221. LocalAddr: netip.MustParseAddr("fd12::34"),
  222. RemoteAddr: netip.MustParseAddr("fd12::34"),
  223. LocalPort: 10,
  224. RemotePort: 90,
  225. Protocol: firewall.ProtoUDP,
  226. Fragment: false,
  227. }
  228. c := dummyCert{
  229. name: "host1",
  230. networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
  231. groups: []string{"default-group"},
  232. issuer: "signer-shasum",
  233. }
  234. h := HostInfo{
  235. ConnectionState: &ConnectionState{
  236. peerCert: &cert.CachedCertificate{
  237. Certificate: &c,
  238. InvertedGroups: map[string]struct{}{"default-group": {}},
  239. },
  240. },
  241. vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
  242. }
  243. h.buildNetworks(myVpnNetworksTable, &c)
  244. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  245. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
  246. cp := cert.NewCAPool()
  247. // Drop outbound
  248. assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
  249. // Allow inbound
  250. resetConntrack(fw)
  251. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  252. // Allow outbound because conntrack
  253. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  254. // test remote mismatch
  255. oldRemote := p.RemoteAddr
  256. p.RemoteAddr = netip.MustParseAddr("fd12::56")
  257. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
  258. p.RemoteAddr = oldRemote
  259. // ensure signer doesn't get in the way of group checks
  260. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  261. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
  262. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
  263. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  264. // test caSha doesn't drop on match
  265. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  266. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
  267. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
  268. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  269. // ensure ca name doesn't get in the way of group checks
  270. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  271. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  272. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
  273. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
  274. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  275. // test caName doesn't drop on match
  276. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  277. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  278. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
  279. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
  280. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  281. }
  282. func BenchmarkFirewallTable_match(b *testing.B) {
  283. f := &Firewall{}
  284. ft := FirewallTable{
  285. TCP: firewallPort{},
  286. }
  287. pfix := netip.MustParsePrefix("172.1.1.1/32")
  288. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix.String(), "", "", "")
  289. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "")
  290. pfix6 := netip.MustParsePrefix("fd11::11/128")
  291. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "")
  292. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix6.String(), "", "")
  293. cp := cert.NewCAPool()
  294. b.Run("fail on proto", func(b *testing.B) {
  295. // This benchmark is showing us the cost of failing to match the protocol
  296. c := &cert.CachedCertificate{
  297. Certificate: &dummyCert{},
  298. }
  299. for n := 0; n < b.N; n++ {
  300. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
  301. }
  302. })
  303. b.Run("pass proto, fail on port", func(b *testing.B) {
  304. // This benchmark is showing us the cost of matching a specific protocol but failing to match the port
  305. c := &cert.CachedCertificate{
  306. Certificate: &dummyCert{},
  307. }
  308. for n := 0; n < b.N; n++ {
  309. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
  310. }
  311. })
  312. b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
  313. c := &cert.CachedCertificate{
  314. Certificate: &dummyCert{},
  315. }
  316. ip := netip.MustParsePrefix("9.254.254.254/32")
  317. for n := 0; n < b.N; n++ {
  318. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
  319. }
  320. })
  321. b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
  322. c := &cert.CachedCertificate{
  323. Certificate: &dummyCert{},
  324. }
  325. ip := netip.MustParsePrefix("fd99::99/128")
  326. for n := 0; n < b.N; n++ {
  327. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
  328. }
  329. })
  330. b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  331. c := &cert.CachedCertificate{
  332. Certificate: &dummyCert{
  333. name: "nope",
  334. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  335. },
  336. InvertedGroups: map[string]struct{}{"nope": {}},
  337. }
  338. for n := 0; n < b.N; n++ {
  339. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  340. }
  341. })
  342. b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
  343. c := &cert.CachedCertificate{
  344. Certificate: &dummyCert{
  345. name: "nope",
  346. networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
  347. },
  348. InvertedGroups: map[string]struct{}{"nope": {}},
  349. }
  350. for n := 0; n < b.N; n++ {
  351. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  352. }
  353. })
  354. b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  355. c := &cert.CachedCertificate{
  356. Certificate: &dummyCert{
  357. name: "nope",
  358. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  359. },
  360. InvertedGroups: map[string]struct{}{"nope": {}},
  361. }
  362. for n := 0; n < b.N; n++ {
  363. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  364. }
  365. })
  366. b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
  367. c := &cert.CachedCertificate{
  368. Certificate: &dummyCert{
  369. name: "nope",
  370. networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
  371. },
  372. InvertedGroups: map[string]struct{}{"nope": {}},
  373. }
  374. for n := 0; n < b.N; n++ {
  375. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
  376. }
  377. })
  378. b.Run("pass on group on any local cidr", func(b *testing.B) {
  379. c := &cert.CachedCertificate{
  380. Certificate: &dummyCert{
  381. name: "nope",
  382. },
  383. InvertedGroups: map[string]struct{}{"good-group": {}},
  384. }
  385. for n := 0; n < b.N; n++ {
  386. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  387. }
  388. })
  389. b.Run("pass on group on specific local cidr", func(b *testing.B) {
  390. c := &cert.CachedCertificate{
  391. Certificate: &dummyCert{
  392. name: "nope",
  393. },
  394. InvertedGroups: map[string]struct{}{"good-group": {}},
  395. }
  396. for n := 0; n < b.N; n++ {
  397. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  398. }
  399. })
  400. b.Run("pass on group on specific local cidr6", func(b *testing.B) {
  401. c := &cert.CachedCertificate{
  402. Certificate: &dummyCert{
  403. name: "nope",
  404. },
  405. InvertedGroups: map[string]struct{}{"good-group": {}},
  406. }
  407. for n := 0; n < b.N; n++ {
  408. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
  409. }
  410. })
  411. b.Run("pass on name", func(b *testing.B) {
  412. c := &cert.CachedCertificate{
  413. Certificate: &dummyCert{
  414. name: "good-host",
  415. },
  416. InvertedGroups: map[string]struct{}{"nope": {}},
  417. }
  418. for n := 0; n < b.N; n++ {
  419. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  420. }
  421. })
  422. }
  423. func TestFirewall_Drop2(t *testing.T) {
  424. l := test.NewLogger()
  425. ob := &bytes.Buffer{}
  426. l.SetOutput(ob)
  427. myVpnNetworksTable := new(bart.Lite)
  428. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  429. p := firewall.Packet{
  430. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  431. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  432. LocalPort: 10,
  433. RemotePort: 90,
  434. Protocol: firewall.ProtoUDP,
  435. Fragment: false,
  436. }
  437. network := netip.MustParsePrefix("1.2.3.4/24")
  438. c := cert.CachedCertificate{
  439. Certificate: &dummyCert{
  440. name: "host1",
  441. networks: []netip.Prefix{network},
  442. },
  443. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  444. }
  445. h := HostInfo{
  446. ConnectionState: &ConnectionState{
  447. peerCert: &c,
  448. },
  449. vpnAddrs: []netip.Addr{network.Addr()},
  450. }
  451. h.buildNetworks(myVpnNetworksTable, c.Certificate)
  452. c1 := cert.CachedCertificate{
  453. Certificate: &dummyCert{
  454. name: "host1",
  455. networks: []netip.Prefix{network},
  456. },
  457. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  458. }
  459. h1 := HostInfo{
  460. vpnAddrs: []netip.Addr{network.Addr()},
  461. ConnectionState: &ConnectionState{
  462. peerCert: &c1,
  463. },
  464. }
  465. h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
  466. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  467. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
  468. cp := cert.NewCAPool()
  469. // h1/c1 lacks the proper groups
  470. require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
  471. // c has the proper groups
  472. resetConntrack(fw)
  473. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  474. }
  475. func TestFirewall_Drop3(t *testing.T) {
  476. l := test.NewLogger()
  477. ob := &bytes.Buffer{}
  478. l.SetOutput(ob)
  479. myVpnNetworksTable := new(bart.Lite)
  480. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  481. p := firewall.Packet{
  482. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  483. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  484. LocalPort: 1,
  485. RemotePort: 1,
  486. Protocol: firewall.ProtoUDP,
  487. Fragment: false,
  488. }
  489. network := netip.MustParsePrefix("1.2.3.4/24")
  490. c := cert.CachedCertificate{
  491. Certificate: &dummyCert{
  492. name: "host-owner",
  493. networks: []netip.Prefix{network},
  494. },
  495. }
  496. c1 := cert.CachedCertificate{
  497. Certificate: &dummyCert{
  498. name: "host1",
  499. networks: []netip.Prefix{network},
  500. issuer: "signer-sha-bad",
  501. },
  502. }
  503. h1 := HostInfo{
  504. ConnectionState: &ConnectionState{
  505. peerCert: &c1,
  506. },
  507. vpnAddrs: []netip.Addr{network.Addr()},
  508. }
  509. h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
  510. c2 := cert.CachedCertificate{
  511. Certificate: &dummyCert{
  512. name: "host2",
  513. networks: []netip.Prefix{network},
  514. issuer: "signer-sha",
  515. },
  516. }
  517. h2 := HostInfo{
  518. ConnectionState: &ConnectionState{
  519. peerCert: &c2,
  520. },
  521. vpnAddrs: []netip.Addr{network.Addr()},
  522. }
  523. h2.buildNetworks(myVpnNetworksTable, c2.Certificate)
  524. c3 := cert.CachedCertificate{
  525. Certificate: &dummyCert{
  526. name: "host3",
  527. networks: []netip.Prefix{network},
  528. issuer: "signer-sha-bad",
  529. },
  530. }
  531. h3 := HostInfo{
  532. ConnectionState: &ConnectionState{
  533. peerCert: &c3,
  534. },
  535. vpnAddrs: []netip.Addr{network.Addr()},
  536. }
  537. h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
  538. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  539. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
  540. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
  541. cp := cert.NewCAPool()
  542. // c1 should pass because host match
  543. require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  544. // c2 should pass because ca sha match
  545. resetConntrack(fw)
  546. require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
  547. // c3 should fail because no match
  548. resetConntrack(fw)
  549. assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
  550. // Test a remote address match
  551. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  552. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
  553. require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  554. }
  555. func TestFirewall_Drop3V6(t *testing.T) {
  556. l := test.NewLogger()
  557. ob := &bytes.Buffer{}
  558. l.SetOutput(ob)
  559. myVpnNetworksTable := new(bart.Lite)
  560. myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
  561. p := firewall.Packet{
  562. LocalAddr: netip.MustParseAddr("fd12::34"),
  563. RemoteAddr: netip.MustParseAddr("fd12::34"),
  564. LocalPort: 1,
  565. RemotePort: 1,
  566. Protocol: firewall.ProtoUDP,
  567. Fragment: false,
  568. }
  569. network := netip.MustParsePrefix("fd12::34/120")
  570. c := cert.CachedCertificate{
  571. Certificate: &dummyCert{
  572. name: "host-owner",
  573. networks: []netip.Prefix{network},
  574. },
  575. }
  576. h := HostInfo{
  577. ConnectionState: &ConnectionState{
  578. peerCert: &c,
  579. },
  580. vpnAddrs: []netip.Addr{network.Addr()},
  581. }
  582. h.buildNetworks(myVpnNetworksTable, c.Certificate)
  583. // Test a remote address match
  584. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  585. cp := cert.NewCAPool()
  586. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
  587. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  588. }
  589. func TestFirewall_DropConntrackReload(t *testing.T) {
  590. l := test.NewLogger()
  591. ob := &bytes.Buffer{}
  592. l.SetOutput(ob)
  593. myVpnNetworksTable := new(bart.Lite)
  594. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  595. p := firewall.Packet{
  596. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  597. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  598. LocalPort: 10,
  599. RemotePort: 90,
  600. Protocol: firewall.ProtoUDP,
  601. Fragment: false,
  602. }
  603. network := netip.MustParsePrefix("1.2.3.4/24")
  604. c := cert.CachedCertificate{
  605. Certificate: &dummyCert{
  606. name: "host1",
  607. networks: []netip.Prefix{network},
  608. groups: []string{"default-group"},
  609. issuer: "signer-shasum",
  610. },
  611. InvertedGroups: map[string]struct{}{"default-group": {}},
  612. }
  613. h := HostInfo{
  614. ConnectionState: &ConnectionState{
  615. peerCert: &c,
  616. },
  617. vpnAddrs: []netip.Addr{network.Addr()},
  618. }
  619. h.buildNetworks(myVpnNetworksTable, c.Certificate)
  620. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  621. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
  622. cp := cert.NewCAPool()
  623. // Drop outbound
  624. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  625. // Allow inbound
  626. resetConntrack(fw)
  627. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  628. // Allow outbound because conntrack
  629. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  630. oldFw := fw
  631. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  632. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
  633. fw.Conntrack = oldFw.Conntrack
  634. fw.rulesVersion = oldFw.rulesVersion + 1
  635. // Allow outbound because conntrack and new rules allow port 10
  636. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  637. oldFw = fw
  638. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  639. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
  640. fw.Conntrack = oldFw.Conntrack
  641. fw.rulesVersion = oldFw.rulesVersion + 1
  642. // Drop outbound because conntrack doesn't match new ruleset
  643. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  644. }
  645. func TestFirewall_ICMPPortBehavior(t *testing.T) {
  646. l := test.NewLogger()
  647. ob := &bytes.Buffer{}
  648. l.SetOutput(ob)
  649. myVpnNetworksTable := new(bart.Lite)
  650. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  651. network := netip.MustParsePrefix("1.2.3.4/24")
  652. c := cert.CachedCertificate{
  653. Certificate: &dummyCert{
  654. name: "host1",
  655. networks: []netip.Prefix{network},
  656. groups: []string{"default-group"},
  657. issuer: "signer-shasum",
  658. },
  659. InvertedGroups: map[string]struct{}{"default-group": {}},
  660. }
  661. h := HostInfo{
  662. ConnectionState: &ConnectionState{
  663. peerCert: &c,
  664. },
  665. vpnAddrs: []netip.Addr{network.Addr()},
  666. }
  667. h.buildNetworks(myVpnNetworksTable, c.Certificate)
  668. cp := cert.NewCAPool()
  669. templ := firewall.Packet{
  670. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  671. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  672. Protocol: firewall.ProtoICMP,
  673. Fragment: false,
  674. }
  675. t.Run("ICMP allowed", func(t *testing.T) {
  676. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  677. require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))
  678. t.Run("zero ports", func(t *testing.T) {
  679. p := templ.Copy()
  680. p.LocalPort = 0
  681. p.RemotePort = 0
  682. // Drop outbound
  683. assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
  684. // Allow inbound
  685. resetConntrack(fw)
  686. require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
  687. //now also allow outbound
  688. require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
  689. })
  690. t.Run("nonzero ports", func(t *testing.T) {
  691. p := templ.Copy()
  692. p.LocalPort = 0xabcd
  693. p.RemotePort = 0x1234
  694. // Drop outbound
  695. assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
  696. // Allow inbound
  697. resetConntrack(fw)
  698. require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
  699. //now also allow outbound
  700. require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
  701. })
  702. })
  703. t.Run("Any proto, some ports allowed", func(t *testing.T) {
  704. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  705. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))
  706. t.Run("zero ports, still blocked", func(t *testing.T) {
  707. p := templ.Copy()
  708. p.LocalPort = 0
  709. p.RemotePort = 0
  710. // Drop outbound
  711. assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
  712. // Allow inbound
  713. resetConntrack(fw)
  714. assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
  715. //now also allow outbound
  716. assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
  717. })
  718. t.Run("nonzero ports, still blocked", func(t *testing.T) {
  719. p := templ.Copy()
  720. p.LocalPort = 0xabcd
  721. p.RemotePort = 0x1234
  722. // Drop outbound
  723. assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
  724. // Allow inbound
  725. resetConntrack(fw)
  726. assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
  727. //now also allow outbound
  728. assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
  729. })
  730. t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
  731. p := templ.Copy()
  732. p.LocalPort = 80
  733. p.RemotePort = 80
  734. // Drop outbound
  735. assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
  736. // Allow inbound
  737. resetConntrack(fw)
  738. assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
  739. //now also allow outbound
  740. assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
  741. })
  742. })
  743. t.Run("Any proto, any port", func(t *testing.T) {
  744. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  745. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
  746. t.Run("zero ports, allowed", func(t *testing.T) {
  747. resetConntrack(fw)
  748. p := templ.Copy()
  749. p.LocalPort = 0
  750. p.RemotePort = 0
  751. // Drop outbound
  752. assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
  753. // Allow inbound
  754. resetConntrack(fw)
  755. require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
  756. //now also allow outbound
  757. require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
  758. })
  759. t.Run("nonzero ports, allowed", func(t *testing.T) {
  760. resetConntrack(fw)
  761. p := templ.Copy()
  762. p.LocalPort = 0xabcd
  763. p.RemotePort = 0x1234
  764. // Drop outbound
  765. assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
  766. // Allow inbound
  767. resetConntrack(fw)
  768. require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
  769. //now also allow outbound
  770. require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
  771. //different ID is blocked
  772. p.RemotePort++
  773. require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
  774. })
  775. })
  776. }
  777. func TestFirewall_DropIPSpoofing(t *testing.T) {
  778. l := test.NewLogger()
  779. ob := &bytes.Buffer{}
  780. l.SetOutput(ob)
  781. myVpnNetworksTable := new(bart.Lite)
  782. myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
  783. c := cert.CachedCertificate{
  784. Certificate: &dummyCert{
  785. name: "host-owner",
  786. networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
  787. },
  788. }
  789. c1 := cert.CachedCertificate{
  790. Certificate: &dummyCert{
  791. name: "host",
  792. networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
  793. unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
  794. },
  795. }
  796. h1 := HostInfo{
  797. ConnectionState: &ConnectionState{
  798. peerCert: &c1,
  799. },
  800. vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
  801. }
  802. h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
  803. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  804. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
  805. cp := cert.NewCAPool()
  806. // Packet spoofed by `c1`. Note that the remote addr is not a valid one.
  807. p := firewall.Packet{
  808. LocalAddr: netip.MustParseAddr("192.0.2.1"),
  809. RemoteAddr: netip.MustParseAddr("192.0.2.3"),
  810. LocalPort: 1,
  811. RemotePort: 1,
  812. Protocol: firewall.ProtoUDP,
  813. Fragment: false,
  814. }
  815. assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
  816. }
  817. func BenchmarkLookup(b *testing.B) {
  818. ml := func(m map[string]struct{}, a [][]string) {
  819. for n := 0; n < b.N; n++ {
  820. for _, sg := range a {
  821. found := false
  822. for _, g := range sg {
  823. if _, ok := m[g]; !ok {
  824. found = false
  825. break
  826. }
  827. found = true
  828. }
  829. if found {
  830. return
  831. }
  832. }
  833. }
  834. }
  835. b.Run("array to map best", func(b *testing.B) {
  836. m := map[string]struct{}{
  837. "1ne": {},
  838. "2wo": {},
  839. "3hr": {},
  840. "4ou": {},
  841. "5iv": {},
  842. "6ix": {},
  843. }
  844. a := [][]string{
  845. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  846. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  847. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  848. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  849. {"one", "two", "thr", "fou", "5iv", "6ix"},
  850. {"one", "two", "thr", "fou", "fiv", "6ix"},
  851. {"one", "two", "thr", "fou", "fiv", "six"},
  852. }
  853. for n := 0; n < b.N; n++ {
  854. ml(m, a)
  855. }
  856. })
  857. b.Run("array to map worst", func(b *testing.B) {
  858. m := map[string]struct{}{
  859. "one": {},
  860. "two": {},
  861. "thr": {},
  862. "fou": {},
  863. "fiv": {},
  864. "six": {},
  865. }
  866. a := [][]string{
  867. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  868. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  869. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  870. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  871. {"one", "two", "thr", "fou", "5iv", "6ix"},
  872. {"one", "two", "thr", "fou", "fiv", "6ix"},
  873. {"one", "two", "thr", "fou", "fiv", "six"},
  874. }
  875. for n := 0; n < b.N; n++ {
  876. ml(m, a)
  877. }
  878. })
  879. }
  880. func Test_parsePort(t *testing.T) {
  881. _, _, err := parsePort("")
  882. require.EqualError(t, err, "was not a number; ``")
  883. _, _, err = parsePort(" ")
  884. require.EqualError(t, err, "was not a number; ` `")
  885. _, _, err = parsePort("-")
  886. require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  887. _, _, err = parsePort(" - ")
  888. require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  889. _, _, err = parsePort("a-b")
  890. require.EqualError(t, err, "beginning range was not a number; `a`")
  891. _, _, err = parsePort("1-b")
  892. require.EqualError(t, err, "ending range was not a number; `b`")
  893. s, e, err := parsePort(" 1 - 2 ")
  894. assert.Equal(t, int32(1), s)
  895. assert.Equal(t, int32(2), e)
  896. require.NoError(t, err)
  897. s, e, err = parsePort("0-1")
  898. assert.Equal(t, int32(0), s)
  899. assert.Equal(t, int32(0), e)
  900. require.NoError(t, err)
  901. s, e, err = parsePort("9919")
  902. assert.Equal(t, int32(9919), s)
  903. assert.Equal(t, int32(9919), e)
  904. require.NoError(t, err)
  905. s, e, err = parsePort("any")
  906. assert.Equal(t, int32(0), s)
  907. assert.Equal(t, int32(0), e)
  908. require.NoError(t, err)
  909. }
  910. func TestNewFirewallFromConfig(t *testing.T) {
  911. l := test.NewLogger()
  912. // Test a bad rule definition
  913. c := &dummyCert{}
  914. cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
  915. require.NoError(t, err)
  916. conf := config.NewC(l)
  917. conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
  918. _, err = NewFirewallFromConfig(l, cs, conf)
  919. require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  920. // Test both port and code
  921. conf = config.NewC(l)
  922. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
  923. _, err = NewFirewallFromConfig(l, cs, conf)
  924. require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  925. // Test missing host, group, cidr, ca_name and ca_sha
  926. conf = config.NewC(l)
  927. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
  928. _, err = NewFirewallFromConfig(l, cs, conf)
  929. require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
  930. // Test code/port error
  931. conf = config.NewC(l)
  932. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
  933. _, err = NewFirewallFromConfig(l, cs, conf)
  934. require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  935. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}}
  936. _, err = NewFirewallFromConfig(l, cs, conf)
  937. require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  938. // Test proto error
  939. conf = config.NewC(l)
  940. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
  941. _, err = NewFirewallFromConfig(l, cs, conf)
  942. require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  943. // Test cidr parse error
  944. conf = config.NewC(l)
  945. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
  946. _, err = NewFirewallFromConfig(l, cs, conf)
  947. require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  948. // Test local_cidr parse error
  949. conf = config.NewC(l)
  950. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
  951. _, err = NewFirewallFromConfig(l, cs, conf)
  952. require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  953. // Test both group and groups
  954. conf = config.NewC(l)
  955. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  956. _, err = NewFirewallFromConfig(l, cs, conf)
  957. require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  958. }
  959. func TestAddFirewallRulesFromConfig(t *testing.T) {
  960. l := test.NewLogger()
  961. // Test adding tcp rule
  962. conf := config.NewC(l)
  963. mf := &mockFirewall{}
  964. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
  965. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  966. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
  967. // Test adding udp rule
  968. conf = config.NewC(l)
  969. mf = &mockFirewall{}
  970. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
  971. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  972. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
  973. // Test adding icmp rule
  974. conf = config.NewC(l)
  975. mf = &mockFirewall{}
  976. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
  977. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  978. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
  979. // Test adding icmp rule no port
  980. conf = config.NewC(l)
  981. mf = &mockFirewall{}
  982. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}}
  983. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  984. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
  985. // Test adding any rule
  986. conf = config.NewC(l)
  987. mf = &mockFirewall{}
  988. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  989. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  990. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
  991. // Test adding rule with cidr
  992. cidr := netip.MustParsePrefix("10.0.0.0/8")
  993. conf = config.NewC(l)
  994. mf = &mockFirewall{}
  995. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
  996. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  997. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
  998. // Test adding rule with local_cidr
  999. conf = config.NewC(l)
  1000. mf = &mockFirewall{}
  1001. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
  1002. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1003. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall)
  1004. // Test adding rule with cidr ipv6
  1005. cidr6 := netip.MustParsePrefix("fd00::/8")
  1006. conf = config.NewC(l)
  1007. mf = &mockFirewall{}
  1008. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
  1009. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1010. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
  1011. // Test adding rule with any cidr
  1012. conf = config.NewC(l)
  1013. mf = &mockFirewall{}
  1014. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
  1015. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1016. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
  1017. // Test adding rule with junk cidr
  1018. conf = config.NewC(l)
  1019. mf = &mockFirewall{}
  1020. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
  1021. require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
  1022. // Test adding rule with local_cidr ipv6
  1023. conf = config.NewC(l)
  1024. mf = &mockFirewall{}
  1025. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
  1026. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1027. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
  1028. // Test adding rule with any local_cidr
  1029. conf = config.NewC(l)
  1030. mf = &mockFirewall{}
  1031. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
  1032. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1033. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
  1034. // Test adding rule with junk local_cidr
  1035. conf = config.NewC(l)
  1036. mf = &mockFirewall{}
  1037. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
  1038. require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
  1039. // Test adding rule with ca_sha
  1040. conf = config.NewC(l)
  1041. mf = &mockFirewall{}
  1042. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  1043. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1044. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
  1045. // Test adding rule with ca_name
  1046. conf = config.NewC(l)
  1047. mf = &mockFirewall{}
  1048. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
  1049. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1050. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
  1051. // Test single group
  1052. conf = config.NewC(l)
  1053. mf = &mockFirewall{}
  1054. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
  1055. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1056. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
  1057. // Test single groups
  1058. conf = config.NewC(l)
  1059. mf = &mockFirewall{}
  1060. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
  1061. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1062. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
  1063. // Test multiple AND groups
  1064. conf = config.NewC(l)
  1065. mf = &mockFirewall{}
  1066. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  1067. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1068. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
  1069. // Test Add error
  1070. conf = config.NewC(l)
  1071. mf = &mockFirewall{}
  1072. mf.nextCallReturn = errors.New("test error")
  1073. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  1074. require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
  1075. }
  1076. func TestFirewall_convertRule(t *testing.T) {
  1077. l := test.NewLogger()
  1078. ob := &bytes.Buffer{}
  1079. l.SetOutput(ob)
  1080. // Ensure group array of 1 is converted and a warning is printed
  1081. c := map[string]any{
  1082. "group": []any{"group1"},
  1083. }
  1084. r, err := convertRule(l, c, "test", 1)
  1085. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  1086. require.NoError(t, err)
  1087. assert.Equal(t, []string{"group1"}, r.Groups)
  1088. // Ensure group array of > 1 is errord
  1089. ob.Reset()
  1090. c = map[string]any{
  1091. "group": []any{"group1", "group2"},
  1092. }
  1093. r, err = convertRule(l, c, "test", 1)
  1094. assert.Empty(t, ob.String())
  1095. require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  1096. // Make sure a well formed group is alright
  1097. ob.Reset()
  1098. c = map[string]any{
  1099. "group": "group1",
  1100. }
  1101. r, err = convertRule(l, c, "test", 1)
  1102. require.NoError(t, err)
  1103. assert.Equal(t, []string{"group1"}, r.Groups)
  1104. }
  1105. func TestFirewall_convertRuleSanity(t *testing.T) {
  1106. l := test.NewLogger()
  1107. ob := &bytes.Buffer{}
  1108. l.SetOutput(ob)
  1109. noWarningPlease := []map[string]any{
  1110. {"group": "group1"},
  1111. {"groups": []any{"group2"}},
  1112. {"host": "bob"},
  1113. {"cidr": "1.1.1.1/1"},
  1114. {"groups": []any{"group2"}, "host": "bob"},
  1115. {"cidr": "1.1.1.1/1", "host": "bob"},
  1116. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
  1117. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
  1118. }
  1119. for _, c := range noWarningPlease {
  1120. r, err := convertRule(l, c, "test", 1)
  1121. require.NoError(t, err)
  1122. require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
  1123. }
  1124. yesWarningPlease := []map[string]any{
  1125. {"group": "group1"},
  1126. {"groups": []any{"group2"}},
  1127. {"cidr": "1.1.1.1/1"},
  1128. {"groups": []any{"group2"}, "host": "bob"},
  1129. {"cidr": "1.1.1.1/1", "host": "bob"},
  1130. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
  1131. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
  1132. }
  1133. for _, c := range yesWarningPlease {
  1134. c["host"] = "any"
  1135. r, err := convertRule(l, c, "test", 1)
  1136. require.NoError(t, err)
  1137. err = r.sanity()
  1138. require.Error(t, err, "I wanted a warning: %+v", c)
  1139. }
  1140. //reset the list
  1141. yesWarningPlease = []map[string]any{
  1142. {"group": "group1"},
  1143. {"groups": []any{"group2"}},
  1144. {"cidr": "1.1.1.1/1"},
  1145. {"groups": []any{"group2"}, "host": "bob"},
  1146. {"cidr": "1.1.1.1/1", "host": "bob"},
  1147. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
  1148. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
  1149. }
  1150. for _, c := range yesWarningPlease {
  1151. r, err := convertRule(l, c, "test", 1)
  1152. require.NoError(t, err)
  1153. r.Groups = append(r.Groups, "any")
  1154. err = r.sanity()
  1155. require.Error(t, err, "I wanted a warning: %+v", c)
  1156. }
  1157. }
  1158. type testcase struct {
  1159. h *HostInfo
  1160. p firewall.Packet
  1161. c cert.Certificate
  1162. err error
  1163. }
  1164. func (c *testcase) Test(t *testing.T, fw *Firewall) {
  1165. t.Helper()
  1166. cp := cert.NewCAPool()
  1167. resetConntrack(fw)
  1168. err := fw.Drop(c.p, true, c.h, cp, nil)
  1169. if c.err == nil {
  1170. require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
  1171. } else {
  1172. require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
  1173. }
  1174. }
  1175. func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
  1176. c1 := dummyCert{
  1177. name: "host1",
  1178. networks: theirPrefixes,
  1179. groups: []string{"default-group"},
  1180. issuer: "signer-shasum",
  1181. }
  1182. h := HostInfo{
  1183. ConnectionState: &ConnectionState{
  1184. peerCert: &cert.CachedCertificate{
  1185. Certificate: &c1,
  1186. InvertedGroups: map[string]struct{}{"default-group": {}},
  1187. },
  1188. },
  1189. vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
  1190. }
  1191. for i := range theirPrefixes {
  1192. h.vpnAddrs[i] = theirPrefixes[i].Addr()
  1193. }
  1194. h.buildNetworks(setup.myVpnNetworksTable, &c1)
  1195. p := firewall.Packet{
  1196. LocalAddr: setup.c.Networks()[0].Addr(), //todo?
  1197. RemoteAddr: theirPrefixes[0].Addr(),
  1198. LocalPort: 10,
  1199. RemotePort: 90,
  1200. Protocol: firewall.ProtoUDP,
  1201. Fragment: false,
  1202. }
  1203. return testcase{
  1204. h: &h,
  1205. p: p,
  1206. c: &c1,
  1207. err: err,
  1208. }
  1209. }
  1210. type testsetup struct {
  1211. c dummyCert
  1212. myVpnNetworksTable *bart.Lite
  1213. fw *Firewall
  1214. }
  1215. func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
  1216. c := dummyCert{
  1217. name: "me",
  1218. networks: myPrefixes,
  1219. groups: []string{"default-group"},
  1220. issuer: "signer-shasum",
  1221. }
  1222. return newSetupFromCert(t, l, c)
  1223. }
  1224. func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
  1225. myVpnNetworksTable := new(bart.Lite)
  1226. for _, prefix := range c.Networks() {
  1227. myVpnNetworksTable.Insert(prefix)
  1228. }
  1229. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  1230. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
  1231. return testsetup{
  1232. c: c,
  1233. fw: fw,
  1234. myVpnNetworksTable: myVpnNetworksTable,
  1235. }
  1236. }
  1237. func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
  1238. t.Parallel()
  1239. l := test.NewLogger()
  1240. ob := &bytes.Buffer{}
  1241. l.SetOutput(ob)
  1242. myPrefix := netip.MustParsePrefix("1.1.1.1/8")
  1243. // for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
  1244. t.Run("allow inbound all matching", func(t *testing.T) {
  1245. t.Parallel()
  1246. setup := newSetup(t, l, myPrefix)
  1247. tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
  1248. tc.Test(t, setup.fw)
  1249. })
  1250. t.Run("allow inbound local matching", func(t *testing.T) {
  1251. t.Parallel()
  1252. setup := newSetup(t, l, myPrefix)
  1253. tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
  1254. tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
  1255. tc.Test(t, setup.fw)
  1256. })
  1257. t.Run("block inbound remote mismatched", func(t *testing.T) {
  1258. t.Parallel()
  1259. setup := newSetup(t, l, myPrefix)
  1260. tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
  1261. tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
  1262. tc.Test(t, setup.fw)
  1263. })
  1264. t.Run("Block a vpn peer packet", func(t *testing.T) {
  1265. t.Parallel()
  1266. setup := newSetup(t, l, myPrefix)
  1267. tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
  1268. tc.Test(t, setup.fw)
  1269. })
  1270. twoPrefixes := []netip.Prefix{
  1271. netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
  1272. }
  1273. t.Run("allow inbound one matching", func(t *testing.T) {
  1274. t.Parallel()
  1275. setup := newSetup(t, l, myPrefix)
  1276. tc := buildTestCase(setup, nil, twoPrefixes...)
  1277. tc.Test(t, setup.fw)
  1278. })
  1279. t.Run("block inbound multimismatch", func(t *testing.T) {
  1280. t.Parallel()
  1281. setup := newSetup(t, l, myPrefix)
  1282. tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
  1283. tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
  1284. tc.Test(t, setup.fw)
  1285. })
  1286. t.Run("allow inbound 2nd one matching", func(t *testing.T) {
  1287. t.Parallel()
  1288. setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
  1289. tc := buildTestCase(setup2, nil, twoPrefixes...)
  1290. tc.p.RemoteAddr = twoPrefixes[1].Addr()
  1291. tc.Test(t, setup2.fw)
  1292. })
  1293. t.Run("allow inbound unsafe route", func(t *testing.T) {
  1294. t.Parallel()
  1295. unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
  1296. c := dummyCert{
  1297. name: "me",
  1298. networks: []netip.Prefix{myPrefix},
  1299. unsafeNetworks: []netip.Prefix{unsafePrefix},
  1300. groups: []string{"default-group"},
  1301. issuer: "signer-shasum",
  1302. }
  1303. unsafeSetup := newSetupFromCert(t, l, c)
  1304. tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
  1305. tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
  1306. tc.err = ErrNoMatchingRule
  1307. tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
  1308. require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
  1309. tc.err = nil
  1310. tc.Test(t, unsafeSetup.fw) //should pass
  1311. })
  1312. }
  // addRuleCall is a snapshot of the arguments passed to one
  // mockFirewall.AddRule invocation, one field per parameter.
  type addRuleCall struct {
  	incoming           bool
  	proto              uint8
  	startPort, endPort int32
  	groups             []string
  	host, ip, localIp  string
  	caName, caSha      string
  }
  1325. type mockFirewall struct {
  1326. lastCall addRuleCall
  1327. nextCallReturn error
  1328. }
  1329. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error {
  1330. mf.lastCall = addRuleCall{
  1331. incoming: incoming,
  1332. proto: proto,
  1333. startPort: startPort,
  1334. endPort: endPort,
  1335. groups: groups,
  1336. host: host,
  1337. ip: ip,
  1338. localIp: localIp,
  1339. caName: caName,
  1340. caSha: caSha,
  1341. }
  1342. err := mf.nextCallReturn
  1343. mf.nextCallReturn = nil
  1344. return err
  1345. }
  1346. func resetConntrack(fw *Firewall) {
  1347. fw.Conntrack.Lock()
  1348. fw.Conntrack.Conns = map[firewall.Packet]*conn{}
  1349. fw.Conntrack.Unlock()
  1350. }