12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091 |
- package nebula
- import (
- "bytes"
- "errors"
- "math"
- "net/netip"
- "testing"
- "time"
- "github.com/slackhq/nebula/cert"
- "github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/firewall"
- "github.com/slackhq/nebula/test"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- )
- func TestNewFirewall(t *testing.T) {
- l := test.NewLogger()
- c := &dummyCert{}
- fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- conntrack := fw.Conntrack
- assert.NotNil(t, conntrack)
- assert.NotNil(t, conntrack.Conns)
- assert.NotNil(t, conntrack.TimerWheel)
- assert.NotNil(t, fw.InRules)
- assert.NotNil(t, fw.OutRules)
- assert.Equal(t, time.Second, fw.TCPTimeout)
- assert.Equal(t, time.Minute, fw.UDPTimeout)
- assert.Equal(t, time.Hour, fw.DefaultTimeout)
- assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
- assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
- assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
- fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
- assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
- assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
- fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
- assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
- assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
- fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
- assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
- assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
- fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
- assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
- assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
- fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
- assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
- assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
- }
- func TestFirewall_AddRule(t *testing.T) {
- l := test.NewLogger()
- ob := &bytes.Buffer{}
- l.SetOutput(ob)
- c := &dummyCert{}
- fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.NotNil(t, fw.InRules)
- assert.NotNil(t, fw.OutRules)
- ti, err := netip.ParsePrefix("1.2.3.4/32")
- require.NoError(t, err)
- ti6, err := netip.ParsePrefix("fd12::34/128")
- require.NoError(t, err)
- require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
- // An empty rule is any
- assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
- assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
- assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
- assert.Nil(t, fw.InRules.UDP[1].Any.Any)
- assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
- assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
- assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
- assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
- assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
- assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
- _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
- assert.True(t, ok)
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
- assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
- _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
- assert.True(t, ok)
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
- assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
- _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
- assert.True(t, ok)
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
- assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
- _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
- assert.True(t, ok)
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
- assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
- assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
- assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- anyIp, err := netip.ParsePrefix("0.0.0.0/0")
- require.NoError(t, err)
- require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
- assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- anyIp6, err := netip.ParsePrefix("::/0")
- require.NoError(t, err)
- require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
- assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
- // Test error conditions
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
- require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
- }
- func TestFirewall_Drop(t *testing.T) {
- l := test.NewLogger()
- ob := &bytes.Buffer{}
- l.SetOutput(ob)
- p := firewall.Packet{
- LocalAddr: netip.MustParseAddr("1.2.3.4"),
- RemoteAddr: netip.MustParseAddr("1.2.3.4"),
- LocalPort: 10,
- RemotePort: 90,
- Protocol: firewall.ProtoUDP,
- Fragment: false,
- }
- c := dummyCert{
- name: "host1",
- networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")},
- groups: []string{"default-group"},
- issuer: "signer-shasum",
- }
- h := HostInfo{
- ConnectionState: &ConnectionState{
- peerCert: &cert.CachedCertificate{
- Certificate: &c,
- InvertedGroups: map[string]struct{}{"default-group": {}},
- },
- },
- vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
- }
- h.buildNetworks(c.networks, c.unsafeNetworks)
- fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
- cp := cert.NewCAPool()
- // Drop outbound
- assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
- // Allow inbound
- resetConntrack(fw)
- require.NoError(t, fw.Drop(p, true, &h, cp, nil))
- // Allow outbound because conntrack
- require.NoError(t, fw.Drop(p, false, &h, cp, nil))
- // test remote mismatch
- oldRemote := p.RemoteAddr
- p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
- assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
- p.RemoteAddr = oldRemote
- // ensure signer doesn't get in the way of group checks
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
- assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
- // test caSha doesn't drop on match
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
- require.NoError(t, fw.Drop(p, true, &h, cp, nil))
- // ensure ca name doesn't get in the way of group checks
- cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
- assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
- // test caName doesn't drop on match
- cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
- require.NoError(t, fw.Drop(p, true, &h, cp, nil))
- }
- func TestFirewall_DropV6(t *testing.T) {
- l := test.NewLogger()
- ob := &bytes.Buffer{}
- l.SetOutput(ob)
- p := firewall.Packet{
- LocalAddr: netip.MustParseAddr("fd12::34"),
- RemoteAddr: netip.MustParseAddr("fd12::34"),
- LocalPort: 10,
- RemotePort: 90,
- Protocol: firewall.ProtoUDP,
- Fragment: false,
- }
- c := dummyCert{
- name: "host1",
- networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
- groups: []string{"default-group"},
- issuer: "signer-shasum",
- }
- h := HostInfo{
- ConnectionState: &ConnectionState{
- peerCert: &cert.CachedCertificate{
- Certificate: &c,
- InvertedGroups: map[string]struct{}{"default-group": {}},
- },
- },
- vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
- }
- h.buildNetworks(c.networks, c.unsafeNetworks)
- fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
- cp := cert.NewCAPool()
- // Drop outbound
- assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
- // Allow inbound
- resetConntrack(fw)
- require.NoError(t, fw.Drop(p, true, &h, cp, nil))
- // Allow outbound because conntrack
- require.NoError(t, fw.Drop(p, false, &h, cp, nil))
- // test remote mismatch
- oldRemote := p.RemoteAddr
- p.RemoteAddr = netip.MustParseAddr("fd12::56")
- assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
- p.RemoteAddr = oldRemote
- // ensure signer doesn't get in the way of group checks
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
- assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
- // test caSha doesn't drop on match
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
- require.NoError(t, fw.Drop(p, true, &h, cp, nil))
- // ensure ca name doesn't get in the way of group checks
- cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
- assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
- // test caName doesn't drop on match
- cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
- require.NoError(t, fw.Drop(p, true, &h, cp, nil))
- }
- func BenchmarkFirewallTable_match(b *testing.B) {
- f := &Firewall{}
- ft := FirewallTable{
- TCP: firewallPort{},
- }
- pfix := netip.MustParsePrefix("172.1.1.1/32")
- _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
- _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
- pfix6 := netip.MustParsePrefix("fd11::11/128")
- _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
- _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
- cp := cert.NewCAPool()
- b.Run("fail on proto", func(b *testing.B) {
- // This benchmark is showing us the cost of failing to match the protocol
- c := &cert.CachedCertificate{
- Certificate: &dummyCert{},
- }
- for n := 0; n < b.N; n++ {
- assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
- }
- })
- b.Run("pass proto, fail on port", func(b *testing.B) {
- // This benchmark is showing us the cost of matching a specific protocol but failing to match the port
- c := &cert.CachedCertificate{
- Certificate: &dummyCert{},
- }
- for n := 0; n < b.N; n++ {
- assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
- }
- })
- b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
- c := &cert.CachedCertificate{
- Certificate: &dummyCert{},
- }
- ip := netip.MustParsePrefix("9.254.254.254/32")
- for n := 0; n < b.N; n++ {
- assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
- }
- })
- b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
- c := &cert.CachedCertificate{
- Certificate: &dummyCert{},
- }
- ip := netip.MustParsePrefix("fd99::99/128")
- for n := 0; n < b.N; n++ {
- assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
- }
- })
- b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
- c := &cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "nope",
- networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
- },
- InvertedGroups: map[string]struct{}{"nope": {}},
- }
- for n := 0; n < b.N; n++ {
- assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
- }
- })
- b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
- c := &cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "nope",
- networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
- },
- InvertedGroups: map[string]struct{}{"nope": {}},
- }
- for n := 0; n < b.N; n++ {
- assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
- }
- })
- b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
- c := &cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "nope",
- networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
- },
- InvertedGroups: map[string]struct{}{"nope": {}},
- }
- for n := 0; n < b.N; n++ {
- assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
- }
- })
- b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
- c := &cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "nope",
- networks: []netip.Prefix{netip.MustParsePrefix("fd99:99/128")},
- },
- InvertedGroups: map[string]struct{}{"nope": {}},
- }
- for n := 0; n < b.N; n++ {
- assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
- }
- })
- b.Run("pass on group on any local cidr", func(b *testing.B) {
- c := &cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "nope",
- },
- InvertedGroups: map[string]struct{}{"good-group": {}},
- }
- for n := 0; n < b.N; n++ {
- assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
- }
- })
- b.Run("pass on group on specific local cidr", func(b *testing.B) {
- c := &cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "nope",
- },
- InvertedGroups: map[string]struct{}{"good-group": {}},
- }
- for n := 0; n < b.N; n++ {
- assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
- }
- })
- b.Run("pass on group on specific local cidr6", func(b *testing.B) {
- c := &cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "nope",
- },
- InvertedGroups: map[string]struct{}{"good-group": {}},
- }
- for n := 0; n < b.N; n++ {
- assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
- }
- })
- b.Run("pass on name", func(b *testing.B) {
- c := &cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "good-host",
- },
- InvertedGroups: map[string]struct{}{"nope": {}},
- }
- for n := 0; n < b.N; n++ {
- ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
- }
- })
- }
- func TestFirewall_Drop2(t *testing.T) {
- l := test.NewLogger()
- ob := &bytes.Buffer{}
- l.SetOutput(ob)
- p := firewall.Packet{
- LocalAddr: netip.MustParseAddr("1.2.3.4"),
- RemoteAddr: netip.MustParseAddr("1.2.3.4"),
- LocalPort: 10,
- RemotePort: 90,
- Protocol: firewall.ProtoUDP,
- Fragment: false,
- }
- network := netip.MustParsePrefix("1.2.3.4/24")
- c := cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "host1",
- networks: []netip.Prefix{network},
- },
- InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
- }
- h := HostInfo{
- ConnectionState: &ConnectionState{
- peerCert: &c,
- },
- vpnAddrs: []netip.Addr{network.Addr()},
- }
- h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
- c1 := cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "host1",
- networks: []netip.Prefix{network},
- },
- InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
- }
- h1 := HostInfo{
- vpnAddrs: []netip.Addr{network.Addr()},
- ConnectionState: &ConnectionState{
- peerCert: &c1,
- },
- }
- h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
- fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
- cp := cert.NewCAPool()
- // h1/c1 lacks the proper groups
- require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
- // c has the proper groups
- resetConntrack(fw)
- require.NoError(t, fw.Drop(p, true, &h, cp, nil))
- }
- func TestFirewall_Drop3(t *testing.T) {
- l := test.NewLogger()
- ob := &bytes.Buffer{}
- l.SetOutput(ob)
- p := firewall.Packet{
- LocalAddr: netip.MustParseAddr("1.2.3.4"),
- RemoteAddr: netip.MustParseAddr("1.2.3.4"),
- LocalPort: 1,
- RemotePort: 1,
- Protocol: firewall.ProtoUDP,
- Fragment: false,
- }
- network := netip.MustParsePrefix("1.2.3.4/24")
- c := cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "host-owner",
- networks: []netip.Prefix{network},
- },
- }
- c1 := cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "host1",
- networks: []netip.Prefix{network},
- issuer: "signer-sha-bad",
- },
- }
- h1 := HostInfo{
- ConnectionState: &ConnectionState{
- peerCert: &c1,
- },
- vpnAddrs: []netip.Addr{network.Addr()},
- }
- h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
- c2 := cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "host2",
- networks: []netip.Prefix{network},
- issuer: "signer-sha",
- },
- }
- h2 := HostInfo{
- ConnectionState: &ConnectionState{
- peerCert: &c2,
- },
- vpnAddrs: []netip.Addr{network.Addr()},
- }
- h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
- c3 := cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "host3",
- networks: []netip.Prefix{network},
- issuer: "signer-sha-bad",
- },
- }
- h3 := HostInfo{
- ConnectionState: &ConnectionState{
- peerCert: &c3,
- },
- vpnAddrs: []netip.Addr{network.Addr()},
- }
- h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
- fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
- cp := cert.NewCAPool()
- // c1 should pass because host match
- require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
- // c2 should pass because ca sha match
- resetConntrack(fw)
- require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
- // c3 should fail because no match
- resetConntrack(fw)
- assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
- // Test a remote address match
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
- require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
- }
- func TestFirewall_Drop3V6(t *testing.T) {
- l := test.NewLogger()
- ob := &bytes.Buffer{}
- l.SetOutput(ob)
- p := firewall.Packet{
- LocalAddr: netip.MustParseAddr("fd12::34"),
- RemoteAddr: netip.MustParseAddr("fd12::34"),
- LocalPort: 1,
- RemotePort: 1,
- Protocol: firewall.ProtoUDP,
- Fragment: false,
- }
- network := netip.MustParsePrefix("fd12::34/120")
- c := cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "host-owner",
- networks: []netip.Prefix{network},
- },
- }
- h := HostInfo{
- ConnectionState: &ConnectionState{
- peerCert: &c,
- },
- vpnAddrs: []netip.Addr{network.Addr()},
- }
- h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
- // Test a remote address match
- fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
- cp := cert.NewCAPool()
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
- require.NoError(t, fw.Drop(p, true, &h, cp, nil))
- }
- func TestFirewall_DropConntrackReload(t *testing.T) {
- l := test.NewLogger()
- ob := &bytes.Buffer{}
- l.SetOutput(ob)
- p := firewall.Packet{
- LocalAddr: netip.MustParseAddr("1.2.3.4"),
- RemoteAddr: netip.MustParseAddr("1.2.3.4"),
- LocalPort: 10,
- RemotePort: 90,
- Protocol: firewall.ProtoUDP,
- Fragment: false,
- }
- network := netip.MustParsePrefix("1.2.3.4/24")
- c := cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "host1",
- networks: []netip.Prefix{network},
- groups: []string{"default-group"},
- issuer: "signer-shasum",
- },
- InvertedGroups: map[string]struct{}{"default-group": {}},
- }
- h := HostInfo{
- ConnectionState: &ConnectionState{
- peerCert: &c,
- },
- vpnAddrs: []netip.Addr{network.Addr()},
- }
- h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
- fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
- cp := cert.NewCAPool()
- // Drop outbound
- assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
- // Allow inbound
- resetConntrack(fw)
- require.NoError(t, fw.Drop(p, true, &h, cp, nil))
- // Allow outbound because conntrack
- require.NoError(t, fw.Drop(p, false, &h, cp, nil))
- oldFw := fw
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
- fw.Conntrack = oldFw.Conntrack
- fw.rulesVersion = oldFw.rulesVersion + 1
- // Allow outbound because conntrack and new rules allow port 10
- require.NoError(t, fw.Drop(p, false, &h, cp, nil))
- oldFw = fw
- fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
- fw.Conntrack = oldFw.Conntrack
- fw.rulesVersion = oldFw.rulesVersion + 1
- // Drop outbound because conntrack doesn't match new ruleset
- assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
- }
- func TestFirewall_DropIPSpoofing(t *testing.T) {
- l := test.NewLogger()
- ob := &bytes.Buffer{}
- l.SetOutput(ob)
- c := cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "host-owner",
- networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
- },
- }
- c1 := cert.CachedCertificate{
- Certificate: &dummyCert{
- name: "host",
- networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
- unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
- },
- }
- h1 := HostInfo{
- ConnectionState: &ConnectionState{
- peerCert: &c1,
- },
- vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
- }
- h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
- fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
- require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
- cp := cert.NewCAPool()
- // Packet spoofed by `c1`. Note that the remote addr is not a valid one.
- p := firewall.Packet{
- LocalAddr: netip.MustParseAddr("192.0.2.1"),
- RemoteAddr: netip.MustParseAddr("192.0.2.3"),
- LocalPort: 1,
- RemotePort: 1,
- Protocol: firewall.ProtoUDP,
- Fragment: false,
- }
- assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
- }
- func BenchmarkLookup(b *testing.B) {
- ml := func(m map[string]struct{}, a [][]string) {
- for n := 0; n < b.N; n++ {
- for _, sg := range a {
- found := false
- for _, g := range sg {
- if _, ok := m[g]; !ok {
- found = false
- break
- }
- found = true
- }
- if found {
- return
- }
- }
- }
- }
- b.Run("array to map best", func(b *testing.B) {
- m := map[string]struct{}{
- "1ne": {},
- "2wo": {},
- "3hr": {},
- "4ou": {},
- "5iv": {},
- "6ix": {},
- }
- a := [][]string{
- {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
- {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
- {"one", "two", "3hr", "4ou", "5iv", "6ix"},
- {"one", "two", "thr", "4ou", "5iv", "6ix"},
- {"one", "two", "thr", "fou", "5iv", "6ix"},
- {"one", "two", "thr", "fou", "fiv", "6ix"},
- {"one", "two", "thr", "fou", "fiv", "six"},
- }
- for n := 0; n < b.N; n++ {
- ml(m, a)
- }
- })
- b.Run("array to map worst", func(b *testing.B) {
- m := map[string]struct{}{
- "one": {},
- "two": {},
- "thr": {},
- "fou": {},
- "fiv": {},
- "six": {},
- }
- a := [][]string{
- {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
- {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
- {"one", "two", "3hr", "4ou", "5iv", "6ix"},
- {"one", "two", "thr", "4ou", "5iv", "6ix"},
- {"one", "two", "thr", "fou", "5iv", "6ix"},
- {"one", "two", "thr", "fou", "fiv", "6ix"},
- {"one", "two", "thr", "fou", "fiv", "six"},
- }
- for n := 0; n < b.N; n++ {
- ml(m, a)
- }
- })
- }
- func Test_parsePort(t *testing.T) {
- _, _, err := parsePort("")
- require.EqualError(t, err, "was not a number; ``")
- _, _, err = parsePort(" ")
- require.EqualError(t, err, "was not a number; ` `")
- _, _, err = parsePort("-")
- require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
- _, _, err = parsePort(" - ")
- require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
- _, _, err = parsePort("a-b")
- require.EqualError(t, err, "beginning range was not a number; `a`")
- _, _, err = parsePort("1-b")
- require.EqualError(t, err, "ending range was not a number; `b`")
- s, e, err := parsePort(" 1 - 2 ")
- assert.Equal(t, int32(1), s)
- assert.Equal(t, int32(2), e)
- require.NoError(t, err)
- s, e, err = parsePort("0-1")
- assert.Equal(t, int32(0), s)
- assert.Equal(t, int32(0), e)
- require.NoError(t, err)
- s, e, err = parsePort("9919")
- assert.Equal(t, int32(9919), s)
- assert.Equal(t, int32(9919), e)
- require.NoError(t, err)
- s, e, err = parsePort("any")
- assert.Equal(t, int32(0), s)
- assert.Equal(t, int32(0), e)
- require.NoError(t, err)
- }
- func TestNewFirewallFromConfig(t *testing.T) {
- l := test.NewLogger()
- // Test a bad rule definition
- c := &dummyCert{}
- cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
- require.NoError(t, err)
- conf := config.NewC(l)
- conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
- _, err = NewFirewallFromConfig(l, cs, conf)
- require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
- // Test both port and code
- conf = config.NewC(l)
- conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
- _, err = NewFirewallFromConfig(l, cs, conf)
- require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
- // Test missing host, group, cidr, ca_name and ca_sha
- conf = config.NewC(l)
- conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
- _, err = NewFirewallFromConfig(l, cs, conf)
- require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
- // Test code/port error
- conf = config.NewC(l)
- conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
- _, err = NewFirewallFromConfig(l, cs, conf)
- require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
- conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
- _, err = NewFirewallFromConfig(l, cs, conf)
- require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
- // Test proto error
- conf = config.NewC(l)
- conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
- _, err = NewFirewallFromConfig(l, cs, conf)
- require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
- // Test cidr parse error
- conf = config.NewC(l)
- conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
- _, err = NewFirewallFromConfig(l, cs, conf)
- require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
- // Test local_cidr parse error
- conf = config.NewC(l)
- conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
- _, err = NewFirewallFromConfig(l, cs, conf)
- require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
- // Test both group and groups
- conf = config.NewC(l)
- conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
- _, err = NewFirewallFromConfig(l, cs, conf)
- require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
- }
- func TestAddFirewallRulesFromConfig(t *testing.T) {
- l := test.NewLogger()
- // Test adding tcp rule
- conf := config.NewC(l)
- mf := &mockFirewall{}
- conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
- require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
- assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
- // Test adding udp rule
- conf = config.NewC(l)
- mf = &mockFirewall{}
- conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
- require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
- assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
- // Test adding icmp rule
- conf = config.NewC(l)
- mf = &mockFirewall{}
- conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
- require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
- assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
- // Test adding any rule
- conf = config.NewC(l)
- mf = &mockFirewall{}
- conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
- require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
- // Test adding rule with cidr
- cidr := netip.MustParsePrefix("10.0.0.0/8")
- conf = config.NewC(l)
- mf = &mockFirewall{}
- conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
- require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
- // Test adding rule with local_cidr
- conf = config.NewC(l)
- mf = &mockFirewall{}
- conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
- require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
- // Test adding rule with cidr ipv6
- cidr6 := netip.MustParsePrefix("fd00::/8")
- conf = config.NewC(l)
- mf = &mockFirewall{}
- conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
- require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
- // Test adding rule with local_cidr ipv6
- conf = config.NewC(l)
- mf = &mockFirewall{}
- conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
- require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
- // Test adding rule with ca_sha
- conf = config.NewC(l)
- mf = &mockFirewall{}
- conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
- require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
- // Test adding rule with ca_name
- conf = config.NewC(l)
- mf = &mockFirewall{}
- conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
- require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
- // Test single group
- conf = config.NewC(l)
- mf = &mockFirewall{}
- conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
- require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
- // Test single groups
- conf = config.NewC(l)
- mf = &mockFirewall{}
- conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
- require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
- // Test multiple AND groups
- conf = config.NewC(l)
- mf = &mockFirewall{}
- conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
- require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
- // Test Add error
- conf = config.NewC(l)
- mf = &mockFirewall{}
- mf.nextCallReturn = errors.New("test error")
- conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
- require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
- }
- func TestFirewall_convertRule(t *testing.T) {
- l := test.NewLogger()
- ob := &bytes.Buffer{}
- l.SetOutput(ob)
- // Ensure group array of 1 is converted and a warning is printed
- c := map[string]any{
- "group": []any{"group1"},
- }
- r, err := convertRule(l, c, "test", 1)
- assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
- require.NoError(t, err)
- assert.Equal(t, "group1", r.Group)
- // Ensure group array of > 1 is errord
- ob.Reset()
- c = map[string]any{
- "group": []any{"group1", "group2"},
- }
- r, err = convertRule(l, c, "test", 1)
- assert.Empty(t, ob.String())
- require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
- // Make sure a well formed group is alright
- ob.Reset()
- c = map[string]any{
- "group": "group1",
- }
- r, err = convertRule(l, c, "test", 1)
- require.NoError(t, err)
- assert.Equal(t, "group1", r.Group)
- }
- type addRuleCall struct {
- incoming bool
- proto uint8
- startPort int32
- endPort int32
- groups []string
- host string
- ip netip.Prefix
- localIp netip.Prefix
- caName string
- caSha string
- }
- type mockFirewall struct {
- lastCall addRuleCall
- nextCallReturn error
- }
- func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
- mf.lastCall = addRuleCall{
- incoming: incoming,
- proto: proto,
- startPort: startPort,
- endPort: endPort,
- groups: groups,
- host: host,
- ip: ip,
- localIp: localIp,
- caName: caName,
- caSha: caSha,
- }
- err := mf.nextCallReturn
- mf.nextCallReturn = nil
- return err
- }
- func resetConntrack(fw *Firewall) {
- fw.Conntrack.Lock()
- fw.Conntrack.Conns = map[firewall.Packet]*conn{}
- fw.Conntrack.Unlock()
- }
|