Просмотр исходного кода

util/winutil: ensure domain controller address is used when retrieving remote profile information

We cannot directly pass a flat domain name into NetUserGetInfo; we must
resolve the address of a domain controller first.

This PR implements the appropriate resolution mechanisms to do that, and
also exposes a couple of new utility APIs for future needs.

Fixes #12627

Signed-off-by: Aaron Klotz <[email protected]>
Aaron Klotz 1 год назад
Родитель
Сommit
5f177090e3

+ 2 - 0
util/winutil/mksyscall.go

@@ -6,9 +6,11 @@ package winutil
 //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
 //go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go
 
+//sys dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) = netapi32.DsGetDcNameW
 //sys expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) [int32(failretval)==0] = userenv.ExpandEnvironmentStringsForUserW
 //sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings
 //sys loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) [int32(failretval)==0] = userenv.LoadUserProfileW
+//sys netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) = netapi32.NetValidateName
 //sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W
 //sys registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) = kernel32.RegisterApplicationRestart
 //sys rmEndSession(session _RMHANDLE) (ret error) = rstrtmgr.RmEndSession

+ 37 - 4
util/winutil/userprofile_windows.go

@@ -135,9 +135,36 @@ func (up *UserProfile) Close() error {
 }
 
 func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName, userName *uint16) (path *uint16, err error) {
-	// logf is for debugging/testing.
-	if logf == nil {
-		logf = logger.Discard
+	// logf is for debugging/testing. While we would normally replace a nil logf
+	// with logger.Discard, we're using explicit checks within this func so that
+	// we don't waste time allocating and converting UTF-16 strings unnecessarily.
+	var comp string
+	if logf != nil {
+		comp = windows.UTF16PtrToString(computerName)
+		user := windows.UTF16PtrToString(userName)
+		logf("BEGIN getRoamingProfilePath(%q, %q)", comp, user)
+		defer logf("END getRoamingProfilePath(%q, %q)", comp, user)
+	}
+
+	isDomainName, err := isDomainName(computerName)
+	if err != nil {
+		return nil, err
+	}
+	if isDomainName {
+		if logf != nil {
+			logf("computerName %q is a domain, resolving...", comp)
+		}
+		dcInfo, err := resolveDomainController(computerName, nil)
+		if err != nil {
+			return nil, err
+		}
+		defer dcInfo.Close()
+
+		computerName = dcInfo.DomainControllerName
+		if logf != nil {
+			dom := windows.UTF16PtrToString(computerName)
+			logf("%q resolved to %q", comp, dom)
+		}
 	}
 
 	var pbuf *byte
@@ -147,7 +174,9 @@ func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName,
 	defer windows.NetApiBufferFree(pbuf)
 
 	ui4 := (*_USER_INFO_4)(unsafe.Pointer(pbuf))
-	logf("getRoamingProfilePath: got %#v", *ui4)
+	if logf != nil {
+		logf("getRoamingProfilePath: got %#v", *ui4)
+	}
 	profilePath := ui4.Profile
 	if profilePath == nil {
 		return nil, nil
@@ -162,6 +191,10 @@ func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName,
 		return nil, err
 	}
 
+	if logf != nil {
+		logf("returning %q", windows.UTF16ToString(expanded[:]))
+	}
+
 	// This buffer is only used briefly, so we don't bother copying it into a shorter slice.
 	return &expanded[0], nil
 }

+ 24 - 0
util/winutil/userprofile_windows_test.go

@@ -0,0 +1,24 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package winutil
+
+import (
+	"testing"
+
+	"golang.org/x/sys/windows"
+)
+
+func TestGetRoamingProfilePath(t *testing.T) {
+	token := windows.GetCurrentProcessToken()
+	computerName, userName, err := getComputerAndUserName(token, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := getRoamingProfilePath(t.Logf, token, computerName, userName); err != nil {
+		t.Error(err)
+	}
+
+	// TODO(aaron): Flesh out better once can run tests under domain accounts.
+}

+ 144 - 0
util/winutil/winutil_windows.go

@@ -784,3 +784,147 @@ func SetNTString[NTS NTStr, BU BufUnit](nts *NTS, buf []BU) {
 		panic("unknown type")
 	}
 }
+
+type domainControllerAddressType uint32
+
+const (
+	//lint:ignore U1000 maps to a win32 API
+	_DS_INET_ADDRESS    domainControllerAddressType = 1
+	_DS_NETBIOS_ADDRESS domainControllerAddressType = 2
+)
+
+type domainControllerFlag uint32
+
+const (
+	//lint:ignore U1000 maps to a win32 API
+	_DS_PDC_FLAG                    domainControllerFlag = 0x00000001
+	_DS_GC_FLAG                     domainControllerFlag = 0x00000004
+	_DS_LDAP_FLAG                   domainControllerFlag = 0x00000008
+	_DS_DS_FLAG                     domainControllerFlag = 0x00000010
+	_DS_KDC_FLAG                    domainControllerFlag = 0x00000020
+	_DS_TIMESERV_FLAG               domainControllerFlag = 0x00000040
+	_DS_CLOSEST_FLAG                domainControllerFlag = 0x00000080
+	_DS_WRITABLE_FLAG               domainControllerFlag = 0x00000100
+	_DS_GOOD_TIMESERV_FLAG          domainControllerFlag = 0x00000200
+	_DS_NDNC_FLAG                   domainControllerFlag = 0x00000400
+	_DS_SELECT_SECRET_DOMAIN_6_FLAG domainControllerFlag = 0x00000800
+	_DS_FULL_SECRET_DOMAIN_6_FLAG   domainControllerFlag = 0x00001000
+	_DS_WS_FLAG                     domainControllerFlag = 0x00002000
+	_DS_DS_8_FLAG                   domainControllerFlag = 0x00004000
+	_DS_DS_9_FLAG                   domainControllerFlag = 0x00008000
+	_DS_DS_10_FLAG                  domainControllerFlag = 0x00010000
+	_DS_KEY_LIST_FLAG               domainControllerFlag = 0x00020000
+	_DS_PING_FLAGS                  domainControllerFlag = 0x000FFFFF
+	_DS_DNS_CONTROLLER_FLAG         domainControllerFlag = 0x20000000
+	_DS_DNS_DOMAIN_FLAG             domainControllerFlag = 0x40000000
+	_DS_DNS_FOREST_FLAG             domainControllerFlag = 0x80000000
+)
+
+type _DOMAIN_CONTROLLER_INFO struct {
+	DomainControllerName        *uint16
+	DomainControllerAddress     *uint16
+	DomainControllerAddressType domainControllerAddressType
+	DomainGuid                  windows.GUID
+	DomainName                  *uint16
+	DnsForestName               *uint16
+	Flags                       domainControllerFlag
+	DcSiteName                  *uint16
+	ClientSiteName              *uint16
+}
+
+func (dci *_DOMAIN_CONTROLLER_INFO) Close() error {
+	if dci == nil {
+		return nil
+	}
+	return windows.NetApiBufferFree((*byte)(unsafe.Pointer(dci)))
+}
+
+type dsGetDcNameFlag uint32
+
+const (
+	//lint:ignore U1000 maps to a win32 API
+	_DS_FORCE_REDISCOVERY             dsGetDcNameFlag = 0x00000001
+	_DS_DIRECTORY_SERVICE_REQUIRED    dsGetDcNameFlag = 0x00000010
+	_DS_DIRECTORY_SERVICE_PREFERRED   dsGetDcNameFlag = 0x00000020
+	_DS_GC_SERVER_REQUIRED            dsGetDcNameFlag = 0x00000040
+	_DS_PDC_REQUIRED                  dsGetDcNameFlag = 0x00000080
+	_DS_BACKGROUND_ONLY               dsGetDcNameFlag = 0x00000100
+	_DS_IP_REQUIRED                   dsGetDcNameFlag = 0x00000200
+	_DS_KDC_REQUIRED                  dsGetDcNameFlag = 0x00000400
+	_DS_TIMESERV_REQUIRED             dsGetDcNameFlag = 0x00000800
+	_DS_WRITABLE_REQUIRED             dsGetDcNameFlag = 0x00001000
+	_DS_GOOD_TIMESERV_PREFERRED       dsGetDcNameFlag = 0x00002000
+	_DS_AVOID_SELF                    dsGetDcNameFlag = 0x00004000
+	_DS_ONLY_LDAP_NEEDED              dsGetDcNameFlag = 0x00008000
+	_DS_IS_FLAT_NAME                  dsGetDcNameFlag = 0x00010000
+	_DS_IS_DNS_NAME                   dsGetDcNameFlag = 0x00020000
+	_DS_TRY_NEXTCLOSEST_SITE          dsGetDcNameFlag = 0x00040000
+	_DS_DIRECTORY_SERVICE_6_REQUIRED  dsGetDcNameFlag = 0x00080000
+	_DS_WEB_SERVICE_REQUIRED          dsGetDcNameFlag = 0x00100000
+	_DS_DIRECTORY_SERVICE_8_REQUIRED  dsGetDcNameFlag = 0x00200000
+	_DS_DIRECTORY_SERVICE_9_REQUIRED  dsGetDcNameFlag = 0x00400000
+	_DS_DIRECTORY_SERVICE_10_REQUIRED dsGetDcNameFlag = 0x00800000
+	_DS_KEY_LIST_SUPPORT_REQUIRED     dsGetDcNameFlag = 0x01000000
+	_DS_RETURN_DNS_NAME               dsGetDcNameFlag = 0x40000000
+	_DS_RETURN_FLAT_NAME              dsGetDcNameFlag = 0x80000000
+)
+
+func resolveDomainController(domainName *uint16, domainGUID *windows.GUID) (*_DOMAIN_CONTROLLER_INFO, error) {
+	const flags = _DS_DIRECTORY_SERVICE_REQUIRED | _DS_IS_FLAT_NAME | _DS_RETURN_DNS_NAME
+	var dcInfo *_DOMAIN_CONTROLLER_INFO
+	if err := dsGetDcName(nil, domainName, domainGUID, nil, flags, &dcInfo); err != nil {
+		return nil, err
+	}
+	return dcInfo, nil
+}
+
+// ResolveDomainController resolves the DNS name of the nearest available
+// domain controller for the domain specified by domainName.
+func ResolveDomainController(domainName string) (string, error) {
+	domainName16, err := windows.UTF16PtrFromString(domainName)
+	if err != nil {
+		return "", err
+	}
+
+	dcInfo, err := resolveDomainController(domainName16, nil)
+	if err != nil {
+		return "", err
+	}
+	defer dcInfo.Close()
+
+	return windows.UTF16PtrToString(dcInfo.DomainControllerName), nil
+}
+
+type _NETSETUP_NAME_TYPE int32
+
+const (
+	_NetSetupUnknown           _NETSETUP_NAME_TYPE = 0
+	_NetSetupMachine           _NETSETUP_NAME_TYPE = 1
+	_NetSetupWorkgroup         _NETSETUP_NAME_TYPE = 2
+	_NetSetupDomain            _NETSETUP_NAME_TYPE = 3
+	_NetSetupNonExistentDomain _NETSETUP_NAME_TYPE = 4
+	_NetSetupDnsMachine        _NETSETUP_NAME_TYPE = 5
+)
+
+func isDomainName(name *uint16) (bool, error) {
+	err := netValidateName(nil, name, nil, nil, _NetSetupDomain)
+	switch err {
+	case nil:
+		return true, nil
+	case windows.ERROR_NO_SUCH_DOMAIN:
+		return false, nil
+	default:
+		return false, err
+	}
+}
+
+// IsDomainName checks whether name represents an existing domain reachable by
+// the current machine.
+func IsDomainName(name string) (bool, error) {
+	name16, err := windows.UTF16PtrFromString(name)
+	if err != nil {
+		return false, err
+	}
+
+	return isDomainName(name16)
+}

+ 19 - 0
util/winutil/zsyscall_windows.go

@@ -42,12 +42,15 @@ func errnoErr(e syscall.Errno) error {
 var (
 	modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
 	modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
+	modnetapi32 = windows.NewLazySystemDLL("netapi32.dll")
 	modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll")
 	moduserenv  = windows.NewLazySystemDLL("userenv.dll")
 
 	procQueryServiceConfig2W             = modadvapi32.NewProc("QueryServiceConfig2W")
 	procGetApplicationRestartSettings    = modkernel32.NewProc("GetApplicationRestartSettings")
 	procRegisterApplicationRestart       = modkernel32.NewProc("RegisterApplicationRestart")
+	procDsGetDcNameW                     = modnetapi32.NewProc("DsGetDcNameW")
+	procNetValidateName                  = modnetapi32.NewProc("NetValidateName")
 	procRmEndSession                     = modrstrtmgr.NewProc("RmEndSession")
 	procRmGetList                        = modrstrtmgr.NewProc("RmGetList")
 	procRmJoinSession                    = modrstrtmgr.NewProc("RmJoinSession")
@@ -78,6 +81,22 @@ func registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret w
 	return
 }
 
+func dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) {
+	r0, _, _ := syscall.Syscall6(procDsGetDcNameW.Addr(), 6, uintptr(unsafe.Pointer(computerName)), uintptr(unsafe.Pointer(domainName)), uintptr(unsafe.Pointer(domainGuid)), uintptr(unsafe.Pointer(siteName)), uintptr(flags), uintptr(unsafe.Pointer(dcInfo)))
+	if r0 != 0 {
+		ret = syscall.Errno(r0)
+	}
+	return
+}
+
+func netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) {
+	r0, _, _ := syscall.Syscall6(procNetValidateName.Addr(), 5, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(account)), uintptr(unsafe.Pointer(password)), uintptr(nameType), 0)
+	if r0 != 0 {
+		ret = syscall.Errno(r0)
+	}
+	return
+}
+
 func rmEndSession(session _RMHANDLE) (ret error) {
 	r0, _, _ := syscall.Syscall(procRmEndSession.Addr(), 1, uintptr(session), 0, 0)
 	if r0 != 0 {