Просмотр исходного кода

util/syspolicy: add ReadStringArray interface (#11857)

Fixes tailscale/corp#19459

This PR adds the ability for users of the syspolicy handler to read string arrays from the MDM solution configured on the system.

Signed-off-by: Andrea Gottardo <[email protected]>
Andrea Gottardo 1 год назад
Родитель
Сommit
1d3e77f373

+ 12 - 0
ipn/ipnlocal/local_test.go

@@ -1568,6 +1568,11 @@ func (h *errorSyspolicyHandler) ReadBoolean(key string) (bool, error) {
 	return false, syspolicy.ErrNoSuchKey
 }
 
+func (h *errorSyspolicyHandler) ReadStringArray(key string) ([]string, error) {
+	h.t.Errorf("ReadStringArray(%q) unexpectedly called", key)
+	return nil, syspolicy.ErrNoSuchKey
+}
+
 type mockSyspolicyHandler struct {
 	t *testing.T
 	// stringPolicies is the collection of policies that we expect to see
@@ -1607,6 +1612,13 @@ func (h *mockSyspolicyHandler) ReadBoolean(key string) (bool, error) {
 	return false, syspolicy.ErrNoSuchKey
 }
 
+func (h *mockSyspolicyHandler) ReadStringArray(key string) ([]string, error) {
+	if h.failUnknownPolicies {
+		h.t.Errorf("ReadStringArray(%q) unexpectedly called", key)
+	}
+	return nil, syspolicy.ErrNoSuchKey
+}
+
 func TestSetExitNodeIDPolicy(t *testing.T) {
 	pfx := netip.MustParsePrefix
 	tests := []struct {

+ 24 - 0
util/syspolicy/caching_handler.go

@@ -16,6 +16,7 @@ type CachingHandler struct {
 	strings  map[string]string
 	uint64s  map[string]uint64
 	bools    map[string]bool
+	strArrs  map[string][]string
 	notFound map[string]bool
 	handler  Handler
 }
@@ -27,6 +28,7 @@ func NewCachingHandler(handler Handler) *CachingHandler {
 		strings:  make(map[string]string),
 		uint64s:  make(map[string]uint64),
 		bools:    make(map[string]bool),
+		strArrs:  make(map[string][]string),
 		notFound: make(map[string]bool),
 	}
 }
@@ -96,3 +98,25 @@ func (ch *CachingHandler) ReadBoolean(key string) (bool, error) {
 	ch.bools[key] = val
 	return val, nil
 }
+
+// ReadBoolean reads the policy settings boolean value given the key.
+// ReadBoolean first reads from the handler's cache before resorting to using the handler.
+func (ch *CachingHandler) ReadStringArray(key string) ([]string, error) {
+	ch.mu.Lock()
+	defer ch.mu.Unlock()
+	if val, ok := ch.strArrs[key]; ok {
+		return val, nil
+	}
+	if notFound := ch.notFound[key]; notFound {
+		return nil, ErrNoSuchKey
+	}
+	val, err := ch.handler.ReadStringArray(key)
+	if errors.Is(err, ErrNoSuchKey) {
+		ch.notFound[key] = true
+		return nil, err
+	} else if err != nil {
+		return nil, err
+	}
+	ch.strArrs[key] = val
+	return val, nil
+}

+ 7 - 0
util/syspolicy/handler.go

@@ -25,6 +25,9 @@ type Handler interface {
 	// ReadBool reads the policy setting's boolean value for the given key.
 	// It should return ErrNoSuchKey if the key does not have a value set.
 	ReadBoolean(key string) (bool, error)
+	// ReadStringArray reads the policy setting's string array value for the given key.
+	// It should return ErrNoSuchKey if the key does not have a value set.
+	ReadStringArray(key string) ([]string, error)
 }
 
 // ErrNoSuchKey is returned by a Handler when the specified key does not have a
@@ -46,6 +49,10 @@ func (defaultHandler) ReadBoolean(_ string) (bool, error) {
 	return false, ErrNoSuchKey
 }
 
+func (defaultHandler) ReadStringArray(_ string) ([]string, error) {
+	return nil, ErrNoSuchKey
+}
+
 // markHandlerInUse is called before handler methods are called.
 func markHandlerInUse() {
 	handlerUsed.Store(true)

+ 10 - 0
util/syspolicy/handler_windows.go

@@ -93,3 +93,13 @@ func (windowsHandler) ReadBoolean(key string) (bool, error) {
 	}
 	return value != 0, err
 }
+
+func (windowsHandler) ReadStringArray(key string) ([]string, error) {
+	value, err := winutil.GetPolicyStringArray(key)
+	if errors.Is(err, winutil.ErrNoValue) {
+		err = ErrNoSuchKey
+	} else if err != nil {
+		windowsErrors.Add(1)
+	}
+	return value, err
+}

+ 9 - 0
util/syspolicy/syspolicy.go

@@ -36,6 +36,15 @@ func GetBoolean(key Key, defaultValue bool) (bool, error) {
 	return v, err
 }
 
+func GetStringArray(key Key, defaultValue []string) ([]string, error) {
+	markHandlerInUse()
+	v, err := handler.ReadStringArray(string(key))
+	if errors.Is(err, ErrNoSuchKey) {
+		return defaultValue, nil
+	}
+	return v, err
+}
+
 // PreferenceOption is a policy that governs whether a boolean variable
 // is forcibly assigned an administrator-defined value, or allowed to receive
 // a user-defined value.

+ 9 - 0
util/syspolicy/syspolicy_test.go

@@ -18,6 +18,7 @@ type testHandler struct {
 	s     string
 	u64   uint64
 	b     bool
+	sArr  []string
 	err   error
 	calls int // used for testing reads from cache vs. handler
 }
@@ -48,6 +49,14 @@ func (th *testHandler) ReadBoolean(key string) (bool, error) {
 	return th.b, th.err
 }
 
+func (th *testHandler) ReadStringArray(key string) ([]string, error) {
+	if key != string(th.key) {
+		th.t.Errorf("ReadStringArray(%q) want %q", key, th.key)
+	}
+	th.calls++
+	return th.sArr, th.err
+}
+
 func TestGetString(t *testing.T) {
 	tests := []struct {
 		name         string

+ 4 - 0
util/winutil/winutil.go

@@ -45,6 +45,10 @@ func GetPolicyInteger(name string) (uint64, error) {
 	return getPolicyInteger(name)
 }
 
+func GetPolicyStringArray(name string) ([]string, error) {
+	return getPolicyStringArray(name)
+}
+
 // GetRegString looks up a registry path in the local machine path, or returns
 // an empty string and error.
 //

+ 2 - 0
util/winutil/winutil_notwindows.go

@@ -21,6 +21,8 @@ func getPolicyString(name string) (string, error) { return "", ErrNoValue }
 
 func getPolicyInteger(name string) (uint64, error) { return 0, ErrNoValue }
 
+func getPolicyStringArray(name string) ([]string, error) { return nil, ErrNoValue }
+
 func getRegString(name string) (string, error) { return "", ErrNoValue }
 
 func getRegInteger(name string) (uint64, error) { return 0, ErrNoValue }

+ 4 - 0
util/winutil/winutil_windows.go

@@ -57,6 +57,10 @@ func getPolicyString(name string) (string, error) {
 	return s, err
 }
 
+func getPolicyStringArray(name string) ([]string, error) {
+	return getRegStringsInternal(regPolicyBase, name)
+}
+
 func getRegString(name string) (string, error) {
 	s, err := getRegStringInternal(regBase, name)
 	if err != nil {