Browse Source

tailcfg: implement text encoding for ProtoPortRange

Updates tailscale/corp#15043
Signed-off-by: James Tucker <[email protected]>
James Tucker 2 years ago
parent
commit
4abd470322
2 changed files with 82 additions and 14 deletions
  1. 29 2
      tailcfg/proto_port_range.go
  2. 53 12
      tailcfg/proto_port_range_test.go

+ 29 - 2
tailcfg/proto_port_range.go

@@ -13,6 +13,11 @@ import (
 	"tailscale.com/util/vizerror"
 )
 
+var (
+	errEmptyProtocol = errors.New("empty protocol")
+	errEmptyString   = errors.New("empty string")
+)
+
 // ProtoPortRange is used to encode "proto:port" format.
 // The following formats are supported:
 //
@@ -30,6 +35,28 @@ type ProtoPortRange struct {
 	Ports PortRange
 }
 
+// UnmarshalText implements the encoding.TextUnmarshaler interface. See
+// ProtoPortRange for the format.
+func (ppr *ProtoPortRange) UnmarshalText(text []byte) error {
+	ppr2, err := parseProtoPortRange(string(text))
+	if err != nil {
+		return err
+	}
+	*ppr = *ppr2
+	return nil
+}
+
+// MarshalText implements the encoding.TextMarshaler interface. See
+// ProtoPortRange for the format.
+func (ppr *ProtoPortRange) MarshalText() ([]byte, error) {
+	if ppr.Proto == 0 && ppr.Ports == (PortRange{}) {
+		return []byte{}, nil
+	}
+	return []byte(ppr.String()), nil
+}
+
+// String implements the stringer interface. See ProtoPortRange for the
+// format.
 func (ppr ProtoPortRange) String() string {
 	if ppr.Proto == 0 {
 		if ppr.Ports == PortRangeAny {
@@ -69,7 +96,7 @@ func ParseProtoPortRanges(ips []string) ([]ProtoPortRange, error) {
 
 func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) {
 	if ipProtoPort == "" {
-		return nil, errors.New("empty string")
+		return nil, errEmptyString
 	}
 	if ipProtoPort == "*" {
 		return &ProtoPortRange{Ports: PortRangeAny}, nil
@@ -82,7 +109,7 @@ func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) {
 		return nil, err
 	}
 	if protoStr == "" {
-		return nil, errors.New("empty protocol")
+		return nil, errEmptyProtocol
 	}
 
 	ppr := &ProtoPortRange{

+ 53 - 12
tailcfg/proto_port_range_test.go

@@ -4,12 +4,15 @@
 package tailcfg
 
 import (
-	"errors"
+	"encoding"
 	"testing"
 
 	"tailscale.com/types/ipproto"
+	"tailscale.com/util/vizerror"
 )
 
+var _ encoding.TextUnmarshaler = (*ProtoPortRange)(nil)
+
 func TestProtoPortRangeParsing(t *testing.T) {
 	pr := func(s, e uint16) PortRange {
 		return PortRange{First: s, Last: e}
@@ -26,30 +29,28 @@ func TestProtoPortRangeParsing(t *testing.T) {
 		{in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}},
 		{
 			in:  "tcp:",
-			err: errors.New(`invalid port list: ""`),
+			err: vizerror.Errorf("invalid port list: %#v", ""),
 		},
 		{
 			in:  ":80",
-			err: errors.New(`empty protocol`),
+			err: errEmptyProtocol,
 		},
 		{
 			in:  "",
-			err: errors.New(`empty string`),
+			err: errEmptyString,
 		},
 	}
 
 	for _, tc := range tests {
 		t.Run(tc.in, func(t *testing.T) {
-			ppr, err := parseProtoPortRange(tc.in)
-			if gotErr, wantErr := err != nil, tc.err != nil; gotErr != wantErr {
-				t.Fatalf("got err %v; want %v", err, tc.err)
-			} else if gotErr {
-				if err.Error() != tc.err.Error() {
-					t.Fatalf("got err %q; want %q", err, tc.err)
+			var ppr ProtoPortRange
+			err := ppr.UnmarshalText([]byte(tc.in))
+			if tc.err != err {
+				if err == nil || tc.err.Error() != err.Error() {
+					t.Fatalf("want err=%v, got %v", tc.err, err)
 				}
-				return
 			}
-			if *ppr != tc.out {
+			if ppr != tc.out {
 				t.Fatalf("got %v; want %v", ppr, tc.out)
 			}
 		})
@@ -88,3 +89,43 @@ func TestProtoPortRangeString(t *testing.T) {
 		}
 	}
 }
+
+func TestProtoPortRangeRoundTrip(t *testing.T) {
+	tests := []struct {
+		input ProtoPortRange
+		text  string
+	}{
+		{ProtoPortRange{Ports: PortRangeAny}, "*"},
+		{ProtoPortRange{Ports: PortRange{23, 23}}, "23"},
+		{ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"},
+		{ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"},
+		{ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"},
+		{ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"},
+		{ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"},
+		{ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"},
+		{ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"},
+		{ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"},
+		{ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"},
+		{ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"},
+		{ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"},
+	}
+
+	for _, tc := range tests {
+		out, err := tc.input.MarshalText()
+		if err != nil {
+			t.Errorf("MarshalText for %v: %v", tc.input, err)
+			continue
+		}
+		if got := string(out); got != tc.text {
+			t.Errorf("MarshalText for %#v: got %q, want %q", tc.input, got, tc.text)
+		}
+		var ppr ProtoPortRange
+		if err := ppr.UnmarshalText(out); err != nil {
+			t.Errorf("UnmarshalText for %q: err=%v", tc.text, err)
+			continue
+		}
+		if ppr != tc.input {
+			t.Errorf("round trip error for %q: got %v, want %#v", tc.text, ppr, tc.input)
+		}
+	}
+}