Browse Source

tailcfg: import ProtoPortRange for local use

Imported type and parsing, with minor modifications.

Updates tailscale/corp#15043

Signed-off-by: James Tucker <[email protected]>
James Tucker 2 years ago
parent
commit
96f01a73b1

+ 1 - 1
cmd/tailscale/depaware.txt

@@ -158,7 +158,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         tailscale.com/util/singleflight                              from tailscale.com/net/dnscache
         tailscale.com/util/slicesx                                   from tailscale.com/net/dnscache+
         tailscale.com/util/testenv                                   from tailscale.com/cmd/tailscale/cli
-        tailscale.com/util/vizerror                                  from tailscale.com/types/ipproto
+        tailscale.com/util/vizerror                                  from tailscale.com/types/ipproto+
      💣 tailscale.com/util/winutil                                   from tailscale.com/hostinfo+
    W 💣 tailscale.com/util/winutil/authenticode                      from tailscale.com/clientupdate
         tailscale.com/version                                        from tailscale.com/cmd/tailscale/cli+

+ 1 - 1
cmd/tailscaled/depaware.txt

@@ -363,7 +363,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
         tailscale.com/util/systemd                                   from tailscale.com/control/controlclient+
         tailscale.com/util/testenv                                   from tailscale.com/ipn/ipnlocal+
         tailscale.com/util/uniq                                      from tailscale.com/wgengine/magicsock+
-        tailscale.com/util/vizerror                                  from tailscale.com/types/ipproto
+        tailscale.com/util/vizerror                                  from tailscale.com/types/ipproto+
      💣 tailscale.com/util/winutil                                   from tailscale.com/control/controlclient+
    W 💣 tailscale.com/util/winutil/authenticode                      from tailscale.com/util/osdiag+
    W    tailscale.com/util/winutil/policy                            from tailscale.com/ipn/ipnlocal

+ 160 - 0
tailcfg/proto_port_range.go

@@ -0,0 +1,160 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tailcfg
+
+import (
+	"errors"
+	"fmt"
+	"strconv"
+	"strings"
+
+	"tailscale.com/types/ipproto"
+	"tailscale.com/util/vizerror"
+)
+
+// ProtoPortRange is used to encode "proto:port" format.
+// The following formats are supported:
+//
+//	"*" allows all TCP, UDP and ICMP traffic on all ports.
+//	"<ports>" allows all TCP, UDP and ICMP traffic on the specified ports.
+//	"proto:*" allows traffic of the specified proto on all ports.
+//	"proto:<port>" allows traffic of the specified proto on the specified port.
+//
+// Ports are either a single port number or a range of ports (e.g. "80-90").
+// String named protocols support names that ipproto.Proto accepts.
+type ProtoPortRange struct {
+	// Proto is the IP protocol number.
+	// If Proto is 0, it means TCP+UDP+ICMP(4+6).
+	Proto int
+	Ports PortRange
+}
+
+func (ppr ProtoPortRange) String() string {
+	if ppr.Proto == 0 {
+		if ppr.Ports == PortRangeAny {
+			return "*"
+		}
+	}
+	var buf strings.Builder
+	if ppr.Proto != 0 {
+		// Proto.MarshalText is infallible.
+		text, _ := ipproto.Proto(ppr.Proto).MarshalText()
+		buf.Write(text)
+		buf.Write([]byte(":"))
+	}
+	pr := ppr.Ports
+	if pr.First == pr.Last {
+		fmt.Fprintf(&buf, "%d", pr.First)
+	} else if pr == PortRangeAny {
+		buf.WriteByte('*')
+	} else {
+		fmt.Fprintf(&buf, "%d-%d", pr.First, pr.Last)
+	}
+	return buf.String()
+}
+
+// ParseProtoPortRanges parses a slice of IP port range fields.
+func ParseProtoPortRanges(ips []string) ([]ProtoPortRange, error) {
+	var out []ProtoPortRange
+	for _, p := range ips {
+		ppr, err := parseProtoPortRange(p)
+		if err != nil {
+			return nil, err
+		}
+		out = append(out, *ppr)
+	}
+	return out, nil
+}
+
+func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) {
+	if ipProtoPort == "" {
+		return nil, errors.New("empty string")
+	}
+	if ipProtoPort == "*" {
+		return &ProtoPortRange{Ports: PortRangeAny}, nil
+	}
+	if !strings.Contains(ipProtoPort, ":") {
+		ipProtoPort = "*:" + ipProtoPort
+	}
+	protoStr, portRange, err := parseHostPortRange(ipProtoPort)
+	if err != nil {
+		return nil, err
+	}
+	if protoStr == "" {
+		return nil, errors.New("empty protocol")
+	}
+
+	ppr := &ProtoPortRange{
+		Ports: portRange,
+	}
+	if protoStr == "*" {
+		return ppr, nil
+	}
+	var ipProto ipproto.Proto
+	if err := ipProto.UnmarshalText([]byte(protoStr)); err != nil {
+		return nil, err
+	}
+	ppr.Proto = int(ipProto)
+	return ppr, nil
+}
+
+// parseHostPortRange parses hostport as HOST:PORTS where HOST is
+// returned unchanged and PORTS is is either "*" or PORTLOW-PORTHIGH ranges.
+func parseHostPortRange(hostport string) (host string, ports PortRange, err error) {
+	hostport = strings.ToLower(hostport)
+	colon := strings.LastIndexByte(hostport, ':')
+	if colon < 0 {
+		return "", ports, vizerror.New("hostport must contain a colon (\":\")")
+	}
+	host = hostport[:colon]
+	portlist := hostport[colon+1:]
+
+	if strings.Contains(host, ",") {
+		return "", ports, vizerror.New("host cannot contain a comma (\",\")")
+	}
+
+	if portlist == "*" {
+		// Special case: permit hostname:* as a port wildcard.
+		return host, PortRangeAny, nil
+	}
+
+	if len(portlist) == 0 {
+		return "", ports, vizerror.Errorf("invalid port list: %#v", portlist)
+	}
+
+	if strings.Count(portlist, "-") > 1 {
+		return "", ports, vizerror.Errorf("port range %#v: too many dashes(-)", portlist)
+	}
+
+	firstStr, lastStr, isRange := strings.Cut(portlist, "-")
+
+	var first, last uint64
+	first, err = strconv.ParseUint(firstStr, 10, 16)
+	if err != nil {
+		return "", ports, vizerror.Errorf("port range %#v: invalid first integer", portlist)
+	}
+
+	if isRange {
+		last, err = strconv.ParseUint(lastStr, 10, 16)
+		if err != nil {
+			return "", ports, vizerror.Errorf("port range %#v: invalid last integer", portlist)
+		}
+	} else {
+		last = first
+	}
+
+	if first == 0 {
+		return "", ports, vizerror.Errorf("port range %#v: first port must be >0, or use '*' for wildcard", portlist)
+	}
+
+	if first > last {
+		return "", ports, vizerror.Errorf("port range %#v: first port must be >= last port", portlist)
+	}
+
+	return host, newPortRange(uint16(first), uint16(last)), nil
+}
+
+func newPortRange(first, last uint16) PortRange {
+	return PortRange{First: first, Last: last}
+}

+ 90 - 0
tailcfg/proto_port_range_test.go

@@ -0,0 +1,90 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tailcfg
+
+import (
+	"errors"
+	"testing"
+
+	"tailscale.com/types/ipproto"
+)
+
+func TestProtoPortRangeParsing(t *testing.T) {
+	pr := func(s, e uint16) PortRange {
+		return PortRange{First: s, Last: e}
+	}
+	tests := []struct {
+		in  string
+		out ProtoPortRange
+		err error
+	}{
+		{in: "tcp:80", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: pr(80, 80)}},
+		{in: "80", out: ProtoPortRange{Ports: pr(80, 80)}},
+		{in: "*", out: ProtoPortRange{Ports: PortRangeAny}},
+		{in: "*:*", out: ProtoPortRange{Ports: PortRangeAny}},
+		{in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}},
+		{
+			in:  "tcp:",
+			err: errors.New(`invalid port list: ""`),
+		},
+		{
+			in:  ":80",
+			err: errors.New(`empty protocol`),
+		},
+		{
+			in:  "",
+			err: errors.New(`empty string`),
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.in, func(t *testing.T) {
+			ppr, err := parseProtoPortRange(tc.in)
+			if gotErr, wantErr := err != nil, tc.err != nil; gotErr != wantErr {
+				t.Fatalf("got err %v; want %v", err, tc.err)
+			} else if gotErr {
+				if err.Error() != tc.err.Error() {
+					t.Fatalf("got err %q; want %q", err, tc.err)
+				}
+				return
+			}
+			if *ppr != tc.out {
+				t.Fatalf("got %v; want %v", ppr, tc.out)
+			}
+		})
+	}
+}
+
+func TestProtoPortRangeString(t *testing.T) {
+	tests := []struct {
+		input ProtoPortRange
+		want  string
+	}{
+		{ProtoPortRange{}, "0"},
+
+		// Zero protocol.
+		{ProtoPortRange{Ports: PortRangeAny}, "*"},
+		{ProtoPortRange{Ports: PortRange{23, 23}}, "23"},
+		{ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"},
+
+		// Non-zero unnamed protocol.
+		{ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"},
+		{ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"},
+
+		// Non-zero named protocol.
+		{ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"},
+		{ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"},
+		{ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"},
+		{ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"},
+		{ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"},
+		{ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"},
+		{ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"},
+		{ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"},
+	}
+	for _, tc := range tests {
+		if got := tc.input.String(); got != tc.want {
+			t.Errorf("String for %v: got %q, want %q", tc.input, got, tc.want)
+		}
+	}
+}