Преглед изворни кода

cmd/tailscaled, util/winutil: changes to process and token APIs in winutil

This PR changes the internal getTokenInfo function to use generics.
I also removed our own implementations for obtaining a token's user
and primary group in favour of calling the ones now available in
x/sys/windows.

Furthermore, I added two new functions for working with tokens, logon
session IDs, and Terminal Services / RDP session IDs.

I modified our privilege enabling code to allow enabling of multiple
privileges via one single function call.

Finally, I added the ProcessImageName function and updated the code in
tailscaled_windows.go to use that instead of directly calling the
underlying API.

All of these changes will be utilized by subsequent PRs pertaining to
this issue.

Updates https://github.com/tailscale/corp/issues/13998

Signed-off-by: Aaron Klotz <[email protected]>
Aaron Klotz пре 2 година
родитељ
комит
855f79fad7
2 измењених фајлова са 103 додато и 42 уклоњено
  1. 4 5
      cmd/tailscaled/tailscaled_windows.go
  2. 99 37
      util/winutil/winutil_windows.go

+ 4 - 5
cmd/tailscaled/tailscaled_windows.go

@@ -529,12 +529,11 @@ func uninstallWinTun(logf logger.Logf) {
 
 func fullyQualifiedWintunPath(logf logger.Logf) string {
 	var dir string
-	var buf [windows.MAX_PATH]uint16
-	length := uint32(len(buf))
-	if err := windows.QueryFullProcessImageName(windows.CurrentProcess(), 0, &buf[0], &length); err != nil {
-		logf("QueryFullProcessImageName failed: %v", err)
+	imgName, err := winutil.ProcessImageName(windows.CurrentProcess())
+	if err != nil {
+		logf("ProcessImageName failed: %v", err)
 	} else {
-		dir = filepath.Dir(windows.UTF16ToString(buf[:length]))
+		dir = filepath.Dir(imgName)
 	}
 
 	return filepath.Join(dir, "wintun.dll")

+ 99 - 37
util/winutil/winutil_windows.go

@@ -225,21 +225,40 @@ func isSIDValidPrincipal(uid string) bool {
 	}
 }
 
-// EnableCurrentThreadPrivilege enables the named privilege in the current
-// thread access token. The current goroutine is also locked to the OS thread
-// (runtime.LockOSThread). Callers must call the returned disable function when
-// done with the privileged task.
-func EnableCurrentThreadPrivilege(name string) (disable func() error, err error) {
+// EnableCurrentThreadPrivilege enables the named privilege
+// in the current thread's access token. The current goroutine is also locked to
+// the OS thread (runtime.LockOSThread). Callers must call the returned disable
+// function when done with the privileged task.
+func EnableCurrentThreadPrivilege(name string) (disable func(), err error) {
+	return EnableCurrentThreadPrivileges([]string{name})
+}
+
+// EnableCurrentThreadPrivileges enables the named privileges
+// in the current thread's access token. The current goroutine is also locked to
+// the OS thread (runtime.LockOSThread). Callers must call the returned disable
+// function when done with the privileged task.
+func EnableCurrentThreadPrivileges(names []string) (disable func(), err error) {
 	runtime.LockOSThread()
+	if len(names) == 0 {
+		// Nothing to enable; no-op isn't really an error...
+		return runtime.UnlockOSThread, nil
+	}
 
 	if err := windows.ImpersonateSelf(windows.SecurityImpersonation); err != nil {
 		runtime.UnlockOSThread()
 		return nil, err
 	}
-	disable = func() error {
+
+	disable = func() {
 		defer runtime.UnlockOSThread()
-		return windows.RevertToSelf()
+		// If RevertToSelf fails, it's not really recoverable and we should panic.
+		// Failure to do so would leak the privileges we're enabling, which is a
+		// security issue.
+		if err := windows.RevertToSelf(); err != nil {
+			panic(fmt.Sprintf("RevertToSelf failed: %v", err))
+		}
 	}
+
 	defer func() {
 		if err != nil {
 			disable()
@@ -254,19 +273,38 @@ func EnableCurrentThreadPrivilege(name string) (disable func() error, err error)
 	}
 	defer t.Close()
 
-	var tp windows.Tokenprivileges
+	tp := newTokenPrivileges(len(names))
+	privs := tp.AllPrivileges()
+	for i := range privs {
+		var privStr *uint16
+		privStr, err = windows.UTF16PtrFromString(names[i])
+		if err != nil {
+			return nil, err
+		}
+		err = windows.LookupPrivilegeValue(nil, privStr, &privs[i].Luid)
+		if err != nil {
+			return nil, err
+		}
+		privs[i].Attributes = windows.SE_PRIVILEGE_ENABLED
+	}
 
-	privStr, err := syscall.UTF16PtrFromString(name)
+	err = windows.AdjustTokenPrivileges(t, false, tp, 0, nil, nil)
 	if err != nil {
 		return nil, err
 	}
-	err = windows.LookupPrivilegeValue(nil, privStr, &tp.Privileges[0].Luid)
-	if err != nil {
-		return nil, err
+
+	return disable, nil
+}
+
+func newTokenPrivileges(numPrivs int) *windows.Tokenprivileges {
+	if numPrivs <= 0 {
+		panic("numPrivs must be > 0")
 	}
-	tp.PrivilegeCount = 1
-	tp.Privileges[0].Attributes = windows.SE_PRIVILEGE_ENABLED
-	return disable, windows.AdjustTokenPrivileges(t, false, &tp, 0, nil, nil)
+	numBytes := unsafe.Sizeof(windows.Tokenprivileges{}) + (uintptr(numPrivs-1) * unsafe.Sizeof(windows.LUIDAndAttributes{}))
+	buf := make([]byte, numBytes)
+	result := (*windows.Tokenprivileges)(unsafe.Pointer(unsafe.SliceData(buf)))
+	result.PrivilegeCount = uint32(numPrivs)
+	return result
 }
 
 // StartProcessAsChild starts exePath process as a child of parentPID.
@@ -346,35 +384,22 @@ func CreateAppMutex(name string) (windows.Handle, error) {
 	return windows.CreateMutex(nil, false, windows.StringToUTF16Ptr(name))
 }
 
-func getTokenInfo(token windows.Token, infoClass uint32) ([]byte, error) {
+func getTokenInfo[T any](token windows.Token, infoClass uint32) (*T, error) {
+	var buf []byte
 	var desiredLen uint32
-	err := windows.GetTokenInformation(token, infoClass, nil, 0, &desiredLen)
-	if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER {
-		return nil, err
-	}
 
-	buf := make([]byte, desiredLen)
-	actualLen := desiredLen
-	err = windows.GetTokenInformation(token, infoClass, &buf[0], desiredLen, &actualLen)
-	return buf, err
-}
+	err := windows.GetTokenInformation(token, infoClass, nil, 0, &desiredLen)
 
-func getTokenUserInfo(token windows.Token) (*windows.Tokenuser, error) {
-	buf, err := getTokenInfo(token, windows.TokenUser)
-	if err != nil {
-		return nil, err
+	for err == windows.ERROR_INSUFFICIENT_BUFFER {
+		buf = make([]byte, desiredLen)
+		err = windows.GetTokenInformation(token, infoClass, unsafe.SliceData(buf), desiredLen, &desiredLen)
 	}
 
-	return (*windows.Tokenuser)(unsafe.Pointer(&buf[0])), nil
-}
-
-func getTokenPrimaryGroupInfo(token windows.Token) (*windows.Tokenprimarygroup, error) {
-	buf, err := getTokenInfo(token, windows.TokenPrimaryGroup)
 	if err != nil {
 		return nil, err
 	}
 
-	return (*windows.Tokenprimarygroup)(unsafe.Pointer(&buf[0])), nil
+	return (*T)(unsafe.Pointer(unsafe.SliceData(buf))), nil
 }
 
 type tokenElevationType int32
@@ -417,12 +442,12 @@ func GetCurrentUserSIDs() (*UserSIDs, error) {
 	}
 	defer token.Close()
 
-	userInfo, err := getTokenUserInfo(token)
+	userInfo, err := token.GetTokenUser()
 	if err != nil {
 		return nil, err
 	}
 
-	primaryGroup, err := getTokenPrimaryGroupInfo(token)
+	primaryGroup, err := token.GetTokenPrimaryGroup()
 	if err != nil {
 		return nil, err
 	}
@@ -645,3 +670,40 @@ func registerForRestart(opts RegisterForRestartOpts) error {
 
 	return nil
 }
+
+// ProcessImageName returns the fully-qualified path to the executable image
+// associated with process.
+func ProcessImageName(process windows.Handle) (string, error) {
+	var pathBuf [windows.MAX_PATH]uint16
+	pathBufLen := uint32(len(pathBuf))
+	if err := windows.QueryFullProcessImageName(process, 0, &pathBuf[0], &pathBufLen); err != nil {
+		return "", err
+	}
+	return windows.UTF16ToString(pathBuf[:pathBufLen]), nil
+}
+
+// TSSessionIDToLogonSessionID retrieves the logon session ID associated with
+// tsSessionId, which is a Terminal Services / RDP session ID. The calling
+// process must be running as LocalSystem.
+func TSSessionIDToLogonSessionID(tsSessionID uint32) (logonSessionID windows.LUID, err error) {
+	var token windows.Token
+	if err := windows.WTSQueryUserToken(tsSessionID, &token); err != nil {
+		return logonSessionID, fmt.Errorf("WTSQueryUserToken: %w", err)
+	}
+	defer token.Close()
+	return LogonSessionID(token)
+}
+
+type tokenOrigin struct {
+	originatingLogonSession windows.LUID
+}
+
+// LogonSessionID obtains the logon session ID associated with token.
+func LogonSessionID(token windows.Token) (logonSessionID windows.LUID, err error) {
+	origin, err := getTokenInfo[tokenOrigin](token, windows.TokenOrigin)
+	if err != nil {
+		return logonSessionID, err
+	}
+
+	return origin.originatingLogonSession, nil
+}