Просмотр исходного кода

util/winutil: update UserProfile to ensure any environment variables in the roaming profile path are expanded

Updates #12383

Signed-off-by: Aaron Klotz <[email protected]>
Aaron Klotz 1 год назад
Родитель
Сommit
7354547bd8
3 измененных файлов с 30 добавлено и 21 удалено
  1. 1 0
      util/winutil/mksyscall.go
  2. 10 11
      util/winutil/userprofile_windows.go
  3. 19 10
      util/winutil/zsyscall_windows.go

+ 1 - 0
util/winutil/mksyscall.go

@@ -6,6 +6,7 @@ package winutil
 //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
 //go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go
 
+//sys expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) [int32(failretval)==0] = userenv.ExpandEnvironmentStringsForUserW
 //sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings
 //sys loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) [int32(failretval)==0] = userenv.LoadUserProfileW
 //sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W

+ 10 - 11
util/winutil/userprofile_windows.go

@@ -80,7 +80,7 @@ func LoadUserProfile(token windows.Token, u *user.User) (up *UserProfile, err er
 
 	var roamingProfilePath *uint16
 	if winenv.IsDomainJoined() {
-		roamingProfilePath, err = getRoamingProfilePath(nil, computerName, userName)
+		roamingProfilePath, err = getRoamingProfilePath(nil, token, computerName, userName)
 		if err != nil {
 			return nil, err
 		}
@@ -134,7 +134,7 @@ func (up *UserProfile) Close() error {
 	return nil
 }
 
-func getRoamingProfilePath(logf logger.Logf, computerName, userName *uint16) (path *uint16, err error) {
+func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName, userName *uint16) (path *uint16, err error) {
 	// logf is for debugging/testing.
 	if logf == nil {
 		logf = logger.Discard
@@ -152,19 +152,18 @@ func getRoamingProfilePath(logf logger.Logf, computerName, userName *uint16) (pa
 	if profilePath == nil {
 		return nil, nil
 	}
-
-	var sz int
-	for ptr := unsafe.Pointer(profilePath); *(*uint16)(ptr) != 0; sz++ {
-		ptr = unsafe.Pointer(uintptr(ptr) + unsafe.Sizeof(*profilePath))
+	if *profilePath == 0 {
+		// Empty string
+		return nil, nil
 	}
 
-	if sz == 0 {
-		return nil, nil
+	var expanded [windows.MAX_PATH + 1]uint16
+	if err := expandEnvironmentStringsForUser(token, profilePath, &expanded[0], uint32(len(expanded))); err != nil {
+		return nil, err
 	}
 
-	buf := unsafe.Slice(profilePath, sz+1)
-	cp := append([]uint16{}, buf...)
-	return unsafe.SliceData(cp), nil
+	// This buffer is only used briefly, so we don't bother copying it into a shorter slice.
+	return &expanded[0], nil
 }
 
 func getComputerAndUserName(token windows.Token, u *user.User) (computerName *uint16, userName *uint16, err error) {

+ 19 - 10
util/winutil/zsyscall_windows.go

@@ -45,16 +45,17 @@ var (
 	modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll")
 	moduserenv  = windows.NewLazySystemDLL("userenv.dll")
 
-	procQueryServiceConfig2W          = modadvapi32.NewProc("QueryServiceConfig2W")
-	procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings")
-	procRegisterApplicationRestart    = modkernel32.NewProc("RegisterApplicationRestart")
-	procRmEndSession                  = modrstrtmgr.NewProc("RmEndSession")
-	procRmGetList                     = modrstrtmgr.NewProc("RmGetList")
-	procRmJoinSession                 = modrstrtmgr.NewProc("RmJoinSession")
-	procRmRegisterResources           = modrstrtmgr.NewProc("RmRegisterResources")
-	procRmStartSession                = modrstrtmgr.NewProc("RmStartSession")
-	procLoadUserProfileW              = moduserenv.NewProc("LoadUserProfileW")
-	procUnloadUserProfile             = moduserenv.NewProc("UnloadUserProfile")
+	procQueryServiceConfig2W             = modadvapi32.NewProc("QueryServiceConfig2W")
+	procGetApplicationRestartSettings    = modkernel32.NewProc("GetApplicationRestartSettings")
+	procRegisterApplicationRestart       = modkernel32.NewProc("RegisterApplicationRestart")
+	procRmEndSession                     = modrstrtmgr.NewProc("RmEndSession")
+	procRmGetList                        = modrstrtmgr.NewProc("RmGetList")
+	procRmJoinSession                    = modrstrtmgr.NewProc("RmJoinSession")
+	procRmRegisterResources              = modrstrtmgr.NewProc("RmRegisterResources")
+	procRmStartSession                   = modrstrtmgr.NewProc("RmStartSession")
+	procExpandEnvironmentStringsForUserW = moduserenv.NewProc("ExpandEnvironmentStringsForUserW")
+	procLoadUserProfileW                 = moduserenv.NewProc("LoadUserProfileW")
+	procUnloadUserProfile                = moduserenv.NewProc("UnloadUserProfile")
 )
 
 func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) {
@@ -117,6 +118,14 @@ func rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret
 	return
 }
 
+func expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) {
+	r1, _, e1 := syscall.Syscall6(procExpandEnvironmentStringsForUserW.Addr(), 4, uintptr(token), uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(dstLen), 0, 0)
+	if int32(r1) == 0 {
+		err = errnoErr(e1)
+	}
+	return
+}
+
 func loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) {
 	r1, _, e1 := syscall.Syscall(procLoadUserProfileW.Addr(), 2, uintptr(token), uintptr(unsafe.Pointer(profileInfo)), 0)
 	if int32(r1) == 0 {