|
|
@@ -6,6 +6,8 @@ package filter
|
|
|
|
|
|
import (
|
|
|
"fmt"
|
|
|
+ "math/bits"
|
|
|
+ "net"
|
|
|
"strings"
|
|
|
|
|
|
"tailscale.com/wgengine/packet"
|
|
|
@@ -13,9 +15,42 @@ import (
|
|
|
|
|
|
type IP = packet.IP
|
|
|
|
|
|
-const IPAny = IP(0)
|
|
|
+func NewIP(ip net.IP) IP {
|
|
|
+ return packet.NewIP(ip)
|
|
|
+}
|
|
|
+
|
|
|
+type Net struct {
|
|
|
+ IP IP
|
|
|
+ Mask IP
|
|
|
+}
|
|
|
+
|
|
|
+func (n Net) Includes(ip IP) bool {
|
|
|
+ return (n.IP & n.Mask) == (ip & n.Mask)
|
|
|
+}
|
|
|
+
|
|
|
+func (n Net) Bits() int {
|
|
|
+ return 32 - bits.TrailingZeros32(uint32(n.Mask))
|
|
|
+}
|
|
|
|
|
|
-var NewIP = packet.NewIP
|
|
|
+func (n Net) String() string {
|
|
|
+ b := n.Bits()
|
|
|
+ if b == 32 {
|
|
|
+ return n.IP.String()
|
|
|
+ } else if b == 0 {
|
|
|
+ return "*"
|
|
|
+ } else {
|
|
|
+ return fmt.Sprintf("%s/%d", n.IP, b)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+var NetAny = Net{0, 0}
|
|
|
+var NetNone = Net{^IP(0), ^IP(0)}
|
|
|
+
|
|
|
+func Netmask(bits int) IP {
|
|
|
+ var b uint32
|
|
|
+ b = ^uint32((1 << (32 - bits)) - 1)
|
|
|
+ return IP(b)
|
|
|
+}
|
|
|
|
|
|
type PortRange struct {
|
|
|
First, Last uint16
|
|
|
@@ -33,39 +68,39 @@ func (pr PortRange) String() string {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-type IPPortRange struct {
|
|
|
- IP IP
|
|
|
+type NetPortRange struct {
|
|
|
+ Net Net
|
|
|
Ports PortRange
|
|
|
}
|
|
|
|
|
|
-var IPPortRangeAny = IPPortRange{IPAny, PortRangeAny}
|
|
|
+var NetPortRangeAny = NetPortRange{NetAny, PortRangeAny}
|
|
|
|
|
|
-func (ipr IPPortRange) String() string {
|
|
|
- return fmt.Sprintf("%v:%v", ipr.IP, ipr.Ports)
|
|
|
+func (ipr NetPortRange) String() string {
|
|
|
+ return fmt.Sprintf("%v:%v", ipr.Net, ipr.Ports)
|
|
|
}
|
|
|
|
|
|
type Match struct {
|
|
|
- DstPorts []IPPortRange
|
|
|
- SrcIPs []IP
|
|
|
+ Dsts []NetPortRange
|
|
|
+ Srcs []Net
|
|
|
}
|
|
|
|
|
|
func (m Match) Clone() (res Match) {
|
|
|
- if m.DstPorts != nil {
|
|
|
- res.DstPorts = append([]IPPortRange{}, m.DstPorts...)
|
|
|
+ if m.Dsts != nil {
|
|
|
+ res.Dsts = append([]NetPortRange{}, m.Dsts...)
|
|
|
}
|
|
|
- if m.SrcIPs != nil {
|
|
|
- res.SrcIPs = append([]IP{}, m.SrcIPs...)
|
|
|
+ if m.Srcs != nil {
|
|
|
+ res.Srcs = append([]Net{}, m.Srcs...)
|
|
|
}
|
|
|
return res
|
|
|
}
|
|
|
|
|
|
func (m Match) String() string {
|
|
|
srcs := []string{}
|
|
|
- for _, srcip := range m.SrcIPs {
|
|
|
- srcs = append(srcs, srcip.String())
|
|
|
+ for _, src := range m.Srcs {
|
|
|
+ srcs = append(srcs, src.String())
|
|
|
}
|
|
|
dsts := []string{}
|
|
|
- for _, dst := range m.DstPorts {
|
|
|
+ for _, dst := range m.Dsts {
|
|
|
dsts = append(dsts, dst.String())
|
|
|
}
|
|
|
|
|
|
@@ -92,9 +127,9 @@ func (m Matches) Clone() (res Matches) {
|
|
|
return res
|
|
|
}
|
|
|
|
|
|
-func ipInList(ip IP, iplist []IP) bool {
|
|
|
- for _, ipp := range iplist {
|
|
|
- if ipp == IPAny || ipp == ip {
|
|
|
+func ipInList(ip IP, netlist []Net) bool {
|
|
|
+ for _, net := range netlist {
|
|
|
+ if net.Includes(ip) {
|
|
|
return true
|
|
|
}
|
|
|
}
|
|
|
@@ -103,14 +138,14 @@ func ipInList(ip IP, iplist []IP) bool {
|
|
|
|
|
|
func matchIPPorts(mm Matches, q *packet.QDecode) bool {
|
|
|
for _, acl := range mm {
|
|
|
- for _, dst := range acl.DstPorts {
|
|
|
- if dst.IP != IPAny && dst.IP != q.DstIP {
|
|
|
+ for _, dst := range acl.Dsts {
|
|
|
+ if !dst.Net.Includes(q.DstIP) {
|
|
|
continue
|
|
|
}
|
|
|
if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last {
|
|
|
continue
|
|
|
}
|
|
|
- if !ipInList(q.SrcIP, acl.SrcIPs) {
|
|
|
+ if !ipInList(q.SrcIP, acl.Srcs) {
|
|
|
// Skip other dests in this acl, since
|
|
|
// the src will never match.
|
|
|
break
|
|
|
@@ -123,11 +158,11 @@ func matchIPPorts(mm Matches, q *packet.QDecode) bool {
|
|
|
|
|
|
func matchIPWithoutPorts(mm Matches, q *packet.QDecode) bool {
|
|
|
for _, acl := range mm {
|
|
|
- for _, dst := range acl.DstPorts {
|
|
|
- if dst.IP != IPAny && dst.IP != q.DstIP {
|
|
|
+ for _, dst := range acl.Dsts {
|
|
|
+ if !dst.Net.Includes(q.DstIP) {
|
|
|
continue
|
|
|
}
|
|
|
- if !ipInList(q.SrcIP, acl.SrcIPs) {
|
|
|
+ if !ipInList(q.SrcIP, acl.Srcs) {
|
|
|
// Skip other dests in this acl, since
|
|
|
// the src will never match.
|
|
|
break
|