|
|
@@ -23,17 +23,25 @@ import (
|
|
|
"tailscale.com/types/logger"
|
|
|
)
|
|
|
|
|
|
-func newFilter(logf logger.Logf) *Filter {
|
|
|
- m := func(srcs []netaddr.IPPrefix, dsts []NetPortRange, protos ...ipproto.Proto) Match {
|
|
|
- if protos == nil {
|
|
|
- protos = defaultProtos
|
|
|
- }
|
|
|
- return Match{
|
|
|
- IPProto: protos,
|
|
|
- Srcs: srcs,
|
|
|
- Dsts: dsts,
|
|
|
- }
|
|
|
+// testAllowedProto is an IP protocol number we treat as allowed for
|
|
|
+// these tests.
|
|
|
+const (
|
|
|
+ testAllowedProto ipproto.Proto = 116
|
|
|
+ testDeniedProto ipproto.Proto = 127 // CRUDP, appropriately cruddy
|
|
|
+)
|
|
|
+
|
|
|
+func m(srcs []netaddr.IPPrefix, dsts []NetPortRange, protos ...ipproto.Proto) Match {
|
|
|
+ if protos == nil {
|
|
|
+ protos = defaultProtos
|
|
|
}
|
|
|
+ return Match{
|
|
|
+ IPProto: protos,
|
|
|
+ Srcs: srcs,
|
|
|
+ Dsts: dsts,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func newFilter(logf logger.Logf) *Filter {
|
|
|
matches := []Match{
|
|
|
m(nets("8.1.1.1", "8.2.2.2"), netports("1.2.3.4:22", "5.6.7.8:23-24")),
|
|
|
m(nets("9.1.1.1", "9.2.2.2"), netports("1.2.3.4:22", "5.6.7.8:23-24"), ipproto.SCTP),
|
|
|
@@ -44,6 +52,8 @@ func newFilter(logf logger.Logf) *Filter {
|
|
|
m(nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), netports("1.2.3.4:999")),
|
|
|
m(nets("::1", "::2"), netports("2001::1:22", "2001::2:22")),
|
|
|
m(nets("::/0"), netports("::/0:443")),
|
|
|
+ m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
|
|
|
+ m(nets("::/0"), netports("::/0:*"), testAllowedProto),
|
|
|
}
|
|
|
|
|
|
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
|
|
|
@@ -112,6 +122,12 @@ func TestFilter(t *testing.T) {
|
|
|
{Drop, parsed(ipproto.SCTP, "8.1.1.1", "1.2.3.4", 999, 22)},
|
|
|
// But SCTP is allowed for 9.1.1.1
|
|
|
{Accept, parsed(ipproto.SCTP, "9.1.1.1", "1.2.3.4", 999, 22)},
|
|
|
+
|
|
|
+ // Unknown protocol is allowed if all its ports are allowed.
|
|
|
+ {Accept, parsed(testAllowedProto, "1.2.3.4", "5.6.7.8", 0, 0)},
|
|
|
+ {Accept, parsed(testAllowedProto, "2001::1", "2001::2", 0, 0)},
|
|
|
+ {Drop, parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0)},
|
|
|
+ {Drop, parsed(testDeniedProto, "2001::1", "2001::2", 0, 0)},
|
|
|
}
|
|
|
for i, test := range tests {
|
|
|
aclFunc := acl.runIn4
|
|
|
@@ -534,13 +550,7 @@ func TestLoggingPrivacy(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func mustIP(s string) netaddr.IP {
|
|
|
- ip, err := netaddr.ParseIP(s)
|
|
|
- if err != nil {
|
|
|
- panic(err)
|
|
|
- }
|
|
|
- return ip
|
|
|
-}
|
|
|
+var mustIP = netaddr.MustParseIP
|
|
|
|
|
|
func parsed(proto ipproto.Proto, src, dst string, sport, dport uint16) packet.Parsed {
|
|
|
sip, dip := mustIP(src), mustIP(dst)
|
|
|
@@ -689,7 +699,7 @@ func nets(nets ...string) (ret []netaddr.IPPrefix) {
|
|
|
|
|
|
func ports(s string) PortRange {
|
|
|
if s == "*" {
|
|
|
- return PortRange{First: 0, Last: 65535}
|
|
|
+ return allPorts
|
|
|
}
|
|
|
|
|
|
var fs, ls string
|
|
|
@@ -825,3 +835,40 @@ func TestNewAllowAllForTest(t *testing.T) {
|
|
|
t.Fatalf("unexpected drop verdict: %v", res)
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func TestMatchesMatchProtoAndIPsOnlyIfAllPorts(t *testing.T) {
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ m Match
|
|
|
+ p packet.Parsed
|
|
|
+ want bool
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "all_ports_okay",
|
|
|
+ m: m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
|
|
|
+ p: parsed(testAllowedProto, "1.2.3.4", "5.6.7.8", 0, 0),
|
|
|
+ want: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "all_ports_match_but_packet_wrong_proto",
|
|
|
+ m: m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
|
|
|
+ p: parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0),
|
|
|
+ want: false,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "ports_requirements_dont_match_unknown_proto",
|
|
|
+ m: m(nets("0.0.0.0/0"), netports("0.0.0.0/0:12345"), testAllowedProto),
|
|
|
+ p: parsed(testAllowedProto, "1.2.3.4", "5.6.7.8", 0, 0),
|
|
|
+ want: false,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ matches := matches{tt.m}
|
|
|
+ got := matches.matchProtoAndIPsOnlyIfAllPorts(&tt.p)
|
|
|
+ if got != tt.want {
|
|
|
+ t.Errorf("got = %v; want %v", got, tt.want)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|