Просмотр исходного кода

wgengine/filter: add and clean up documentation.

Signed-off-by: David Anderson <[email protected]>
David Anderson 5 лет назад
Родитель
Сommit
b950bd60bf
3 измененных файлов с 131 добавлено и 114 удалено
  1. 42 114
      wgengine/filter/filter.go
  2. 2 0
      wgengine/filter/match.go
  3. 87 0
      wgengine/filter/tailcfg.go

+ 42 - 114
wgengine/filter/filter.go

@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 // license that can be found in the LICENSE file.
 
 
-// Package filter contains a stateful packet filter.
+// Package filter is a stateful packet filter.
 package filter
 package filter
 
 
 import (
 import (
@@ -14,15 +14,9 @@ import (
 	"golang.org/x/time/rate"
 	"golang.org/x/time/rate"
 	"inet.af/netaddr"
 	"inet.af/netaddr"
 	"tailscale.com/net/packet"
 	"tailscale.com/net/packet"
-	"tailscale.com/tailcfg"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/logger"
 )
 )
 
 
-type filterState struct {
-	mu  sync.Mutex
-	lru *lru.Cache // of tuple
-}
-
 // Filter is a stateful packet filter.
 // Filter is a stateful packet filter.
 type Filter struct {
 type Filter struct {
 	logf logger.Logf
 	logf logger.Logf
@@ -45,14 +39,31 @@ type Filter struct {
 	state *filterState
 	state *filterState
 }
 }
 
 
-// Response is a verdict: either a Drop, Accept, or noVerdict skip to
-// continue processing.
+// tuple is a 4-tuple of source and destination IPv4 and port. It's
+// used as a lookup key in filterState.
+type tuple struct {
+	SrcIP   packet.IP4
+	DstIP   packet.IP4
+	SrcPort uint16
+	DstPort uint16
+}
+
+// filterState is a state cache of past seen packets.
+type filterState struct {
+	mu  sync.Mutex
+	lru *lru.Cache // of tuple
+}
+
+// lruMax is the size of the LRU cache in filterState.
+const lruMax = 512
+
+// Response is a verdict from the packet filter.
 type Response int
 type Response int
 
 
 const (
 const (
-	Drop Response = iota
-	Accept
-	noVerdict // Returned from subfilters to continue processing.
+	Drop      Response = iota // do not continue processing packet.
+	Accept                    // continue processing packet.
+	noVerdict                 // no verdict yet, continue running filter
 )
 )
 
 
 func (r Response) String() string {
 func (r Response) String() string {
@@ -72,30 +83,16 @@ func (r Response) String() string {
 type RunFlags int
 type RunFlags int
 
 
 const (
 const (
-	LogDrops RunFlags = 1 << iota
-	LogAccepts
-	HexdumpDrops
-	HexdumpAccepts
+	LogDrops       RunFlags = 1 << iota // write dropped packet info to logf
+	LogAccepts                          // write accepted packet info to logf
+	HexdumpDrops                        // print packet hexdump when logging drops
+	HexdumpAccepts                      // print packet hexdump when logging accepts
 )
 )
 
 
-type tuple struct {
-	SrcIP   packet.IP4
-	DstIP   packet.IP4
-	SrcPort uint16
-	DstPort uint16
-}
-
-const lruMax = 512 // max entries in UDP LRU cache
-
-// MatchAllowAll matches all packets.
-var MatchAllowAll = Matches{
-	Match{NetPortRangeAny, NetAny},
-}
-
 // NewAllowAll returns a packet filter that accepts everything to and
 // NewAllowAll returns a packet filter that accepts everything to and
 // from localNets.
 // from localNets.
 func NewAllowAll(localNets []netaddr.IPPrefix, logf logger.Logf) *Filter {
 func NewAllowAll(localNets []netaddr.IPPrefix, logf logger.Logf) *Filter {
-	return New(MatchAllowAll, localNets, nil, logf)
+	return New(Matches{Match{NetPortRangeAny, NetAny}}, localNets, nil, logf)
 }
 }
 
 
 // NewAllowNone returns a packet filter that rejects everything.
 // NewAllowNone returns a packet filter that rejects everything.
@@ -106,8 +103,8 @@ func NewAllowNone(logf logger.Logf) *Filter {
 // New creates a new packet filter. The filter enforces that incoming
 // New creates a new packet filter. The filter enforces that incoming
 // packets must be destined to an IP in localNets, and must be allowed
 // packets must be destined to an IP in localNets, and must be allowed
 // by matches. If shareStateWith is non-nil, the returned filter
 // by matches. If shareStateWith is non-nil, the returned filter
-// shares state with the previous one, to enable rules to be changed
-// at runtime without breaking existing flows.
+// shares state with the previous one, to enable changing rules at
+// runtime without breaking existing stateful flows.
 func New(matches Matches, localNets []netaddr.IPPrefix, shareStateWith *Filter, logf logger.Logf) *Filter {
 func New(matches Matches, localNets []netaddr.IPPrefix, shareStateWith *Filter, logf logger.Logf) *Filter {
 	var state *filterState
 	var state *filterState
 	if shareStateWith != nil {
 	if shareStateWith != nil {
@@ -133,82 +130,6 @@ func maybeHexdump(flag RunFlags, b []byte) string {
 	return packet.Hexdump(b) + "\n"
 	return packet.Hexdump(b) + "\n"
 }
 }
 
 
-// MatchesFromFilterRules parse a number of wire-format FilterRule values into
-// the Matches format.
-// If an error is returned, the Matches result is still valid, containing the rules that
-// were successfully converted.
-func MatchesFromFilterRules(pf []tailcfg.FilterRule) (Matches, error) {
-	mm := make([]Match, 0, len(pf))
-	var erracc error
-
-	for _, r := range pf {
-		m := Match{}
-
-		for i, s := range r.SrcIPs {
-			bits := 32
-			if len(r.SrcBits) > i {
-				bits = r.SrcBits[i]
-			}
-			net, err := parseIP(s, bits)
-			if err != nil && erracc == nil {
-				erracc = err
-				continue
-			}
-			m.Srcs = append(m.Srcs, net)
-		}
-
-		for _, d := range r.DstPorts {
-			bits := 32
-			if d.Bits != nil {
-				bits = *d.Bits
-			}
-			net, err := parseIP(d.IP, bits)
-			if err != nil && erracc == nil {
-				erracc = err
-				continue
-			}
-			m.Dsts = append(m.Dsts, NetPortRange{
-				Net: net,
-				Ports: PortRange{
-					First: d.Ports.First,
-					Last:  d.Ports.Last,
-				},
-			})
-		}
-
-		mm = append(mm, m)
-	}
-	return mm, erracc
-}
-
-func parseIP(host string, defaultBits int) (netaddr.IPPrefix, error) {
-	if host == "*" {
-		// User explicitly requested wildcard dst ip.
-		// TODO: ipv6
-		return netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}, nil
-	}
-
-	ip, err := netaddr.ParseIP(host)
-	if err != nil {
-		return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IP address", host)
-	}
-	if ip == netaddr.IPv4(0, 0, 0, 0) {
-		// For clarity, reject 0.0.0.0 as an input
-		return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
-	}
-	if !ip.Is4() {
-		// TODO: ipv6
-		return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IPv4 address", host)
-	}
-	if defaultBits < 0 || defaultBits > 32 {
-		return netaddr.IPPrefix{}, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host)
-	}
-	return netaddr.IPPrefix{
-		IP:   ip,
-		Bits: uint8(defaultBits),
-	}, nil
-}
-
 // TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging?
 // TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging?
 //   Logging is a quick way to record every newly opened TCP connection, but
 //   Logging is a quick way to record every newly opened TCP connection, but
 //   we have to be cautious about flooding the logs vs letting people use
 //   we have to be cautious about flooding the logs vs letting people use
@@ -240,7 +161,8 @@ func (f *Filter) logRateLimit(runflags RunFlags, q *packet.ParsedPacket, dir dir
 	}
 	}
 }
 }
 
 
-// RunIn determines whether this node is allowed to receive q from a Tailscale peer.
+// RunIn determines whether this node is allowed to receive q from a
+// Tailscale peer.
 func (f *Filter) RunIn(q *packet.ParsedPacket, rf RunFlags) Response {
 func (f *Filter) RunIn(q *packet.ParsedPacket, rf RunFlags) Response {
 	dir := in
 	dir := in
 	r := f.pre(q, rf, dir)
 	r := f.pre(q, rf, dir)
@@ -254,7 +176,8 @@ func (f *Filter) RunIn(q *packet.ParsedPacket, rf RunFlags) Response {
 	return r
 	return r
 }
 }
 
 
-// RunOut determines whether this node is allowed to send q to a Tailscale peer.
+// RunOut determines whether this node is allowed to send q to a
+// Tailscale peer.
 func (f *Filter) RunOut(q *packet.ParsedPacket, rf RunFlags) Response {
 func (f *Filter) RunOut(q *packet.ParsedPacket, rf RunFlags) Response {
 	dir := out
 	dir := out
 	r := f.pre(q, rf, dir)
 	r := f.pre(q, rf, dir)
@@ -267,6 +190,7 @@ func (f *Filter) RunOut(q *packet.ParsedPacket, rf RunFlags) Response {
 	return r
 	return r
 }
 }
 
 
+// runIn runs the input-specific part of the filter logic.
 func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
 func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
 	// A compromised peer could try to send us packets for
 	// A compromised peer could try to send us packets for
 	// destinations we didn't explicitly advertise. This check is to
 	// destinations we didn't explicitly advertise. This check is to
@@ -327,6 +251,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
 	return Drop, "no rules matched"
 	return Drop, "no rules matched"
 }
 }
 
 
+// runIn runs the output-specific part of the filter logic.
 func (f *Filter) runOut(q *packet.ParsedPacket) (r Response, why string) {
 func (f *Filter) runOut(q *packet.ParsedPacket) (r Response, why string) {
 	if q.IPProto == packet.UDP {
 	if q.IPProto == packet.UDP {
 		t := tuple{q.DstIP, q.SrcIP, q.DstPort, q.SrcPort}
 		t := tuple{q.DstIP, q.SrcIP, q.DstPort, q.SrcPort}
@@ -339,12 +264,13 @@ func (f *Filter) runOut(q *packet.ParsedPacket) (r Response, why string) {
 	return Accept, "ok out"
 	return Accept, "ok out"
 }
 }
 
 
-// direction is whether a packet was flowing in to this machine, or flowing out.
+// direction is whether a packet was flowing in to this machine, or
+// flowing out.
 type direction int
 type direction int
 
 
 const (
 const (
-	in direction = iota
-	out
+	in  direction = iota // from Tailscale peer to local machine
+	out                  // from local machine to Tailscale peer
 )
 )
 
 
 func (d direction) String() string {
 func (d direction) String() string {
@@ -358,6 +284,8 @@ func (d direction) String() string {
 	}
 	}
 }
 }
 
 
+// pre runs the direction-agnostic filter logic. dir is only used for
+// logging.
 func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags, dir direction) Response {
 func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags, dir direction) Response {
 	if len(q.Buffer()) == 0 {
 	if len(q.Buffer()) == 0 {
 		// wireguard keepalive packet, always permit.
 		// wireguard keepalive packet, always permit.

+ 2 - 0
wgengine/filter/match.go

@@ -29,6 +29,7 @@ func (pr PortRange) String() string {
 	}
 	}
 }
 }
 
 
+// contains returns whether port is in pr.
 func (pr PortRange) contains(port uint16) bool {
 func (pr PortRange) contains(port uint16) bool {
 	return port >= pr.First && port <= pr.Last
 	return port >= pr.First && port <= pr.Last
 }
 }
@@ -47,6 +48,7 @@ func (npr NetPortRange) String() string {
 	return fmt.Sprintf("%v:%v", npr.Net, npr.Ports)
 	return fmt.Sprintf("%v:%v", npr.Net, npr.Ports)
 }
 }
 
 
+// NetPortRangeAny matches any IP and port.
 var NetPortRangeAny = []NetPortRange{{Net: NetAny[0], Ports: PortRangeAny}}
 var NetPortRangeAny = []NetPortRange{{Net: NetAny[0], Ports: PortRangeAny}}
 
 
 // Match matches packets from any IP address in Srcs to any ip:port in
 // Match matches packets from any IP address in Srcs to any ip:port in

+ 87 - 0
wgengine/filter/tailcfg.go

@@ -0,0 +1,87 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package filter
+
+import (
+	"fmt"
+
+	"inet.af/netaddr"
+	"tailscale.com/tailcfg"
+)
+
+// MatchesFromFilterRules converts tailcfg FilterRules into Matches.
+// If an error is returned, the Matches result is still valid,
+// containing the rules that were successfully converted.
+func MatchesFromFilterRules(pf []tailcfg.FilterRule) (Matches, error) {
+	mm := make([]Match, 0, len(pf))
+	var erracc error
+
+	for _, r := range pf {
+		m := Match{}
+
+		for i, s := range r.SrcIPs {
+			bits := 32
+			if len(r.SrcBits) > i {
+				bits = r.SrcBits[i]
+			}
+			net, err := parseIP(s, bits)
+			if err != nil && erracc == nil {
+				erracc = err
+				continue
+			}
+			m.Srcs = append(m.Srcs, net)
+		}
+
+		for _, d := range r.DstPorts {
+			bits := 32
+			if d.Bits != nil {
+				bits = *d.Bits
+			}
+			net, err := parseIP(d.IP, bits)
+			if err != nil && erracc == nil {
+				erracc = err
+				continue
+			}
+			m.Dsts = append(m.Dsts, NetPortRange{
+				Net: net,
+				Ports: PortRange{
+					First: d.Ports.First,
+					Last:  d.Ports.Last,
+				},
+			})
+		}
+
+		mm = append(mm, m)
+	}
+	return mm, erracc
+}
+
+func parseIP(host string, defaultBits int) (netaddr.IPPrefix, error) {
+	if host == "*" {
+		// User explicitly requested wildcard dst ip.
+		// TODO: ipv6
+		return netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}, nil
+	}
+
+	ip, err := netaddr.ParseIP(host)
+	if err != nil {
+		return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IP address", host)
+	}
+	if ip == netaddr.IPv4(0, 0, 0, 0) {
+		// For clarity, reject 0.0.0.0 as an input
+		return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
+	}
+	if !ip.Is4() {
+		// TODO: ipv6
+		return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IPv4 address", host)
+	}
+	if defaultBits < 0 || defaultBits > 32 {
+		return netaddr.IPPrefix{}, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host)
+	}
+	return netaddr.IPPrefix{
+		IP:   ip,
+		Bits: uint8(defaultBits),
+	}, nil
+}