Ver Fonte

types/opt: add BoolFlag for setting Bool value as a flag

Updates tailscale/corp#22578

Signed-off-by: Will Norris <[email protected]>
Will Norris há 1 ano atrás
pai
commit
cccacff564
2 ficheiros alterados com 63 adições e 0 exclusões
  1. 26 0
      types/opt/bool.go
  2. 37 0
      types/opt/bool_test.go

+ 26 - 0
types/opt/bool.go

@@ -105,3 +105,29 @@ func (b *Bool) UnmarshalJSON(j []byte) error {
 	}
 	return nil
 }
+
+// BoolFlag is a wrapper for Bool that implements [flag.Value].
+type BoolFlag struct {
+	*Bool
+}
+
+// Set the value of b, using any value supported by [strconv.ParseBool].
+func (b *BoolFlag) Set(s string) error {
+	v, err := strconv.ParseBool(s)
+	if err != nil {
+		return err
+	}
+	b.Bool.Set(v)
+	return nil
+}
+
+// String returns "true" or "false" if the value is set, or an empty string otherwise.
+func (b *BoolFlag) String() string {
+	if b == nil || b.Bool == nil {
+		return ""
+	}
+	if v, ok := b.Bool.Get(); ok {
+		return strconv.FormatBool(v)
+	}
+	return ""
+}

+ 37 - 0
types/opt/bool_test.go

@@ -5,7 +5,9 @@ package opt
 
 import (
 	"encoding/json"
+	"flag"
 	"reflect"
+	"strings"
 	"testing"
 )
 
@@ -127,3 +129,38 @@ func TestUnmarshalAlloc(t *testing.T) {
 		t.Errorf("got %v allocs, want 0", n)
 	}
 }
+
+func TestBoolFlag(t *testing.T) {
+	tests := []struct {
+		arguments      string
+		wantParseError bool // expect flag.Parse to error
+		want           Bool
+	}{
+		{"", false, Bool("")},
+		{"-test", true, Bool("")},
+		{`-test=""`, true, Bool("")},
+		{"-test invalid", true, Bool("")},
+
+		{"-test true", false, NewBool(true)},
+		{"-test 1", false, NewBool(true)},
+
+		{"-test false", false, NewBool(false)},
+		{"-test 0", false, NewBool(false)},
+	}
+
+	for _, tt := range tests {
+		var got Bool
+		fs := flag.NewFlagSet(t.Name(), flag.ContinueOnError)
+		fs.Var(&BoolFlag{&got}, "test", "test flag")
+
+		arguments := strings.Split(tt.arguments, " ")
+		err := fs.Parse(arguments)
+		if (err != nil) != tt.wantParseError {
+			t.Errorf("flag.Parse(%q) returned error %v, want %v", arguments, err, tt.wantParseError)
+		}
+
+		if got != tt.want {
+			t.Errorf("flag.Parse(%q) got %q, want %q", arguments, got, tt.want)
+		}
+	}
+}