ソースを参照

wgengine/filter: let unknown IPProto match if IP okay & match allows all ports

RELNOTE=yes

Change-Id: I96eaf3cf550cee7bb6cdb4ad81fc761e280a1b2a
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 4 年 前
コミット
69de3bf7bf
3 ファイル変更96 行追加18 行削除
  1. 6 0
      wgengine/filter/filter.go
  2. 65 18
      wgengine/filter/filter_test.go
  3. 25 0
      wgengine/filter/match.go

+ 6 - 0
wgengine/filter/filter.go

@@ -384,6 +384,9 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
 	case ipproto.TSMP:
 		return Accept, "tsmp ok"
 	default:
+		if f.matches4.matchProtoAndIPsOnlyIfAllPorts(q) {
+			return Accept, "otherproto ok"
+		}
 		return Drop, "Unknown proto"
 	}
 	return Drop, "no rules matched"
@@ -441,6 +444,9 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
 	case ipproto.TSMP:
 		return Accept, "tsmp ok"
 	default:
+		if f.matches6.matchProtoAndIPsOnlyIfAllPorts(q) {
+			return Accept, "otherproto ok"
+		}
 		return Drop, "Unknown proto"
 	}
 	return Drop, "no rules matched"

+ 65 - 18
wgengine/filter/filter_test.go

@@ -23,17 +23,25 @@ import (
 	"tailscale.com/types/logger"
 )
 
-func newFilter(logf logger.Logf) *Filter {
-	m := func(srcs []netaddr.IPPrefix, dsts []NetPortRange, protos ...ipproto.Proto) Match {
-		if protos == nil {
-			protos = defaultProtos
-		}
-		return Match{
-			IPProto: protos,
-			Srcs:    srcs,
-			Dsts:    dsts,
-		}
+// testAllowedProto is an IP protocol number we treat as allowed for
+// these tests.
+const (
+	testAllowedProto ipproto.Proto = 116
+	testDeniedProto  ipproto.Proto = 127 // CRUDP, appropriately cruddy
+)
+
+func m(srcs []netaddr.IPPrefix, dsts []NetPortRange, protos ...ipproto.Proto) Match {
+	if protos == nil {
+		protos = defaultProtos
 	}
+	return Match{
+		IPProto: protos,
+		Srcs:    srcs,
+		Dsts:    dsts,
+	}
+}
+
+func newFilter(logf logger.Logf) *Filter {
 	matches := []Match{
 		m(nets("8.1.1.1", "8.2.2.2"), netports("1.2.3.4:22", "5.6.7.8:23-24")),
 		m(nets("9.1.1.1", "9.2.2.2"), netports("1.2.3.4:22", "5.6.7.8:23-24"), ipproto.SCTP),
@@ -44,6 +52,8 @@ func newFilter(logf logger.Logf) *Filter {
 		m(nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), netports("1.2.3.4:999")),
 		m(nets("::1", "::2"), netports("2001::1:22", "2001::2:22")),
 		m(nets("::/0"), netports("::/0:443")),
+		m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
+		m(nets("::/0"), netports("::/0:*"), testAllowedProto),
 	}
 
 	// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
@@ -112,6 +122,12 @@ func TestFilter(t *testing.T) {
 		{Drop, parsed(ipproto.SCTP, "8.1.1.1", "1.2.3.4", 999, 22)},
 		// But SCTP is allowed for 9.1.1.1
 		{Accept, parsed(ipproto.SCTP, "9.1.1.1", "1.2.3.4", 999, 22)},
+
+		// Unknown protocol is allowed if all its ports are allowed.
+		{Accept, parsed(testAllowedProto, "1.2.3.4", "5.6.7.8", 0, 0)},
+		{Accept, parsed(testAllowedProto, "2001::1", "2001::2", 0, 0)},
+		{Drop, parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0)},
+		{Drop, parsed(testDeniedProto, "2001::1", "2001::2", 0, 0)},
 	}
 	for i, test := range tests {
 		aclFunc := acl.runIn4
@@ -534,13 +550,7 @@ func TestLoggingPrivacy(t *testing.T) {
 	}
 }
 
-func mustIP(s string) netaddr.IP {
-	ip, err := netaddr.ParseIP(s)
-	if err != nil {
-		panic(err)
-	}
-	return ip
-}
+var mustIP = netaddr.MustParseIP
 
 func parsed(proto ipproto.Proto, src, dst string, sport, dport uint16) packet.Parsed {
 	sip, dip := mustIP(src), mustIP(dst)
@@ -689,7 +699,7 @@ func nets(nets ...string) (ret []netaddr.IPPrefix) {
 
 func ports(s string) PortRange {
 	if s == "*" {
-		return PortRange{First: 0, Last: 65535}
+		return allPorts
 	}
 
 	var fs, ls string
@@ -825,3 +835,40 @@ func TestNewAllowAllForTest(t *testing.T) {
 		t.Fatalf("unexpected drop verdict: %v", res)
 	}
 }
+
+func TestMatchesMatchProtoAndIPsOnlyIfAllPorts(t *testing.T) {
+	tests := []struct {
+		name string
+		m    Match
+		p    packet.Parsed
+		want bool
+	}{
+		{
+			name: "all_ports_okay",
+			m:    m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
+			p:    parsed(testAllowedProto, "1.2.3.4", "5.6.7.8", 0, 0),
+			want: true,
+		},
+		{
+			name: "all_ports_match_but_packet_wrong_proto",
+			m:    m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
+			p:    parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0),
+			want: false,
+		},
+		{
+			name: "ports_requirements_dont_match_unknown_proto",
+			m:    m(nets("0.0.0.0/0"), netports("0.0.0.0/0:12345"), testAllowedProto),
+			p:    parsed(testAllowedProto, "1.2.3.4", "5.6.7.8", 0, 0),
+			want: false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			matches := matches{tt.m}
+			got := matches.matchProtoAndIPsOnlyIfAllPorts(&tt.p)
+			if got != tt.want {
+				t.Errorf("got = %v; want %v", got, tt.want)
+			}
+		})
+	}
+}

+ 25 - 0
wgengine/filter/match.go

@@ -20,6 +20,8 @@ type PortRange struct {
 	First, Last uint16 // inclusive
 }
 
+var allPorts = PortRange{0, 0xffff}
+
 func (pr PortRange) String() string {
 	if pr.First == 0 && pr.Last == 65535 {
 		return "*"
@@ -115,6 +117,29 @@ func (ms matches) matchIPsOnly(q *packet.Parsed) bool {
 	return false
 }
 
+// matchProtoAndIPsOnlyIfAllPorts reports q matches any Match in ms where the
+// Match if for the right IP Protocol and IP address, but ports are
+// ignored, as long as the match is for the entire uint16 port range.
+func (ms matches) matchProtoAndIPsOnlyIfAllPorts(q *packet.Parsed) bool {
+	for _, m := range ms {
+		if !protoInList(q.IPProto, m.IPProto) {
+			continue
+		}
+		if !ipInList(q.Src.IP(), m.Srcs) {
+			continue
+		}
+		for _, dst := range m.Dsts {
+			if dst.Ports != allPorts {
+				continue
+			}
+			if dst.Net.Contains(q.Dst.IP()) {
+				return true
+			}
+		}
+	}
+	return false
+}
+
 func ipInList(ip netaddr.IP, netlist []netaddr.IPPrefix) bool {
 	for _, net := range netlist {
 		if net.Contains(ip) {