winutil_windows.go 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package winutil
  4. import (
  5. "errors"
  6. "fmt"
  7. "log"
  8. "math"
  9. "os/exec"
  10. "os/user"
  11. "reflect"
  12. "runtime"
  13. "strings"
  14. "syscall"
  15. "time"
  16. "unsafe"
  17. "golang.org/x/exp/constraints"
  18. "golang.org/x/sys/windows"
  19. "golang.org/x/sys/windows/registry"
  20. )
  // Registry key paths under which Tailscale settings are stored. They are
// opened relative to HKEY_LOCAL_MACHINE or HKEY_CURRENT_USER by the helpers
// in this file.
const (
	// regBase holds Tailscale's own settings; the policy readers fall back to
	// it as the legacy location.
	regBase = `SOFTWARE\Tailscale IPN`

	// regPolicyBase holds Tailscale policy settings.
	regPolicyBase = `SOFTWARE\Policies\Tailscale`
)
  25. // ErrNoShell is returned when the shell process is not found.
  26. var ErrNoShell = errors.New("no Shell process is present")
  27. // ErrNoValue is returned when the value doesn't exist in the registry.
  28. var ErrNoValue = registry.ErrNotExist
  29. // GetDesktopPID searches the PID of the process that's running the
  30. // currently active desktop. Returns ErrNoShell if the shell is not present.
  31. // Usually the PID will be for explorer.exe.
  32. func GetDesktopPID() (uint32, error) {
  33. hwnd := windows.GetShellWindow()
  34. if hwnd == 0 {
  35. return 0, ErrNoShell
  36. }
  37. var pid uint32
  38. windows.GetWindowThreadProcessId(hwnd, &pid)
  39. if pid == 0 {
  40. return 0, fmt.Errorf("invalid PID for HWND %v", hwnd)
  41. }
  42. return pid, nil
  43. }
  44. func getPolicyString(name string) (string, error) {
  45. s, err := getRegStringInternal(registry.LOCAL_MACHINE, regPolicyBase, name)
  46. if err != nil {
  47. // Fall back to the legacy path
  48. return getRegString(name)
  49. }
  50. return s, err
  51. }
  52. func getPolicyStringArray(name string) ([]string, error) {
  53. return getRegStringsInternal(regPolicyBase, name)
  54. }
  55. func getRegString(name string) (string, error) {
  56. s, err := getRegStringInternal(registry.LOCAL_MACHINE, regBase, name)
  57. if err != nil {
  58. return "", err
  59. }
  60. return s, err
  61. }
  62. func getPolicyInteger(name string) (uint64, error) {
  63. i, err := getRegIntegerInternal(regPolicyBase, name)
  64. if err != nil {
  65. // Fall back to the legacy path
  66. return getRegInteger(name)
  67. }
  68. return i, err
  69. }
  70. func getRegInteger(name string) (uint64, error) {
  71. i, err := getRegIntegerInternal(regBase, name)
  72. if err != nil {
  73. return 0, err
  74. }
  75. return i, err
  76. }
  77. func getRegStringInternal(key registry.Key, subKey, name string) (string, error) {
  78. key, err := registry.OpenKey(key, subKey, registry.READ)
  79. if err != nil {
  80. if err != ErrNoValue {
  81. log.Printf("registry.OpenKey(%v): %v", subKey, err)
  82. }
  83. return "", err
  84. }
  85. defer key.Close()
  86. val, _, err := key.GetStringValue(name)
  87. if err != nil {
  88. if err != ErrNoValue {
  89. log.Printf("registry.GetStringValue(%v): %v", name, err)
  90. }
  91. return "", err
  92. }
  93. return val, nil
  94. }
  95. // GetRegUserString looks up a registry path in the current user key, or returns
  96. // an empty string and error.
  97. func GetRegUserString(name string) (string, error) {
  98. return getRegStringInternal(registry.CURRENT_USER, regBase, name)
  99. }
  100. // SetRegUserString sets a SZ value identified by name in the current user key
  101. // to the string specified by value.
  102. func SetRegUserString(name, value string) error {
  103. key, _, err := registry.CreateKey(registry.CURRENT_USER, regBase, registry.SET_VALUE)
  104. if err != nil {
  105. log.Printf("registry.CreateKey(%v): %v", regBase, err)
  106. }
  107. defer key.Close()
  108. return key.SetStringValue(name, value)
  109. }
  110. // GetRegStrings looks up a registry value in the local machine path, or returns
  111. // the given default if it can't.
  112. func GetRegStrings(name string, defval []string) []string {
  113. s, err := getRegStringsInternal(regBase, name)
  114. if err != nil {
  115. return defval
  116. }
  117. return s
  118. }
  119. func getRegStringsInternal(subKey, name string) ([]string, error) {
  120. key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ)
  121. if err != nil {
  122. if err != ErrNoValue {
  123. log.Printf("registry.OpenKey(%v): %v", subKey, err)
  124. }
  125. return nil, err
  126. }
  127. defer key.Close()
  128. val, _, err := key.GetStringsValue(name)
  129. if err != nil {
  130. if err != ErrNoValue {
  131. log.Printf("registry.GetStringValue(%v): %v", name, err)
  132. }
  133. return nil, err
  134. }
  135. return val, nil
  136. }
  137. // SetRegStrings sets a MULTI_SZ value in the in the local machine path
  138. // to the strings specified by values.
  139. func SetRegStrings(name string, values []string) error {
  140. return setRegStringsInternal(regBase, name, values)
  141. }
  142. func setRegStringsInternal(subKey, name string, values []string) error {
  143. key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, subKey, registry.SET_VALUE)
  144. if err != nil {
  145. log.Printf("registry.CreateKey(%v): %v", subKey, err)
  146. }
  147. defer key.Close()
  148. return key.SetStringsValue(name, values)
  149. }
  150. // DeleteRegValue removes a registry value in the local machine path.
  151. func DeleteRegValue(name string) error {
  152. return deleteRegValueInternal(regBase, name)
  153. }
  154. func deleteRegValueInternal(subKey, name string) error {
  155. key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.SET_VALUE)
  156. if err == ErrNoValue {
  157. return nil
  158. }
  159. if err != nil {
  160. log.Printf("registry.OpenKey(%v): %v", subKey, err)
  161. return err
  162. }
  163. defer key.Close()
  164. err = key.DeleteValue(name)
  165. if err == ErrNoValue {
  166. err = nil
  167. }
  168. return err
  169. }
  170. func getRegIntegerInternal(subKey, name string) (uint64, error) {
  171. key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ)
  172. if err != nil {
  173. if err != ErrNoValue {
  174. log.Printf("registry.OpenKey(%v): %v", subKey, err)
  175. }
  176. return 0, err
  177. }
  178. defer key.Close()
  179. val, _, err := key.GetIntegerValue(name)
  180. if err != nil {
  181. if err != ErrNoValue {
  182. log.Printf("registry.GetIntegerValue(%v): %v", name, err)
  183. }
  184. return 0, err
  185. }
  186. return val, nil
  187. }
  188. var (
  189. kernel32 = syscall.NewLazyDLL("kernel32.dll")
  190. procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
  191. )
  192. // TODO(crawshaw): replace with x/sys/windows... one day.
  193. // https://go-review.googlesource.com/c/sys/+/331909
  194. func WTSGetActiveConsoleSessionId() uint32 {
  195. r1, _, _ := procWTSGetActiveConsoleSessionId.Call()
  196. return uint32(r1)
  197. }
  198. func isSIDValidPrincipal(uid string) bool {
  199. usid, err := syscall.StringToSid(uid)
  200. if err != nil {
  201. return false
  202. }
  203. _, _, accType, err := usid.LookupAccount("")
  204. if err != nil {
  205. return false
  206. }
  207. switch accType {
  208. case syscall.SidTypeUser, syscall.SidTypeGroup, syscall.SidTypeDomain, syscall.SidTypeAlias, syscall.SidTypeWellKnownGroup, syscall.SidTypeComputer:
  209. return true
  210. default:
  211. // Reject deleted users, invalid SIDs, unknown SIDs, mandatory label SIDs, etc.
  212. return false
  213. }
  214. }
  215. // EnableCurrentThreadPrivilege enables the named privilege
  216. // in the current thread's access token. The current goroutine is also locked to
  217. // the OS thread (runtime.LockOSThread). Callers must call the returned disable
  218. // function when done with the privileged task.
  219. func EnableCurrentThreadPrivilege(name string) (disable func(), err error) {
  220. return EnableCurrentThreadPrivileges([]string{name})
  221. }
  222. // EnableCurrentThreadPrivileges enables the named privileges
  223. // in the current thread's access token. The current goroutine is also locked to
  224. // the OS thread (runtime.LockOSThread). Callers must call the returned disable
  225. // function when done with the privileged task.
  226. func EnableCurrentThreadPrivileges(names []string) (disable func(), err error) {
  227. runtime.LockOSThread()
  228. if len(names) == 0 {
  229. // Nothing to enable; no-op isn't really an error...
  230. return runtime.UnlockOSThread, nil
  231. }
  232. if err := windows.ImpersonateSelf(windows.SecurityImpersonation); err != nil {
  233. runtime.UnlockOSThread()
  234. return nil, err
  235. }
  236. disable = func() {
  237. defer runtime.UnlockOSThread()
  238. // If RevertToSelf fails, it's not really recoverable and we should panic.
  239. // Failure to do so would leak the privileges we're enabling, which is a
  240. // security issue.
  241. if err := windows.RevertToSelf(); err != nil {
  242. panic(fmt.Sprintf("RevertToSelf failed: %v", err))
  243. }
  244. }
  245. defer func() {
  246. if err != nil {
  247. disable()
  248. }
  249. }()
  250. var t windows.Token
  251. err = windows.OpenThreadToken(windows.CurrentThread(),
  252. windows.TOKEN_QUERY|windows.TOKEN_ADJUST_PRIVILEGES, false, &t)
  253. if err != nil {
  254. return nil, err
  255. }
  256. defer t.Close()
  257. tp := newTokenPrivileges(len(names))
  258. privs := tp.AllPrivileges()
  259. for i := range privs {
  260. var privStr *uint16
  261. privStr, err = windows.UTF16PtrFromString(names[i])
  262. if err != nil {
  263. return nil, err
  264. }
  265. err = windows.LookupPrivilegeValue(nil, privStr, &privs[i].Luid)
  266. if err != nil {
  267. return nil, err
  268. }
  269. privs[i].Attributes = windows.SE_PRIVILEGE_ENABLED
  270. }
  271. err = windows.AdjustTokenPrivileges(t, false, tp, 0, nil, nil)
  272. if err != nil {
  273. return nil, err
  274. }
  275. return disable, nil
  276. }
  277. func newTokenPrivileges(numPrivs int) *windows.Tokenprivileges {
  278. if numPrivs <= 0 {
  279. panic("numPrivs must be > 0")
  280. }
  281. numBytes := unsafe.Sizeof(windows.Tokenprivileges{}) + (uintptr(numPrivs-1) * unsafe.Sizeof(windows.LUIDAndAttributes{}))
  282. buf := make([]byte, numBytes)
  283. result := (*windows.Tokenprivileges)(unsafe.Pointer(unsafe.SliceData(buf)))
  284. result.PrivilegeCount = uint32(numPrivs)
  285. return result
  286. }
  287. // StartProcessAsChild starts exePath process as a child of parentPID.
  288. // StartProcessAsChild copies parentPID's environment variables into
  289. // the new process, along with any optional environment variables in extraEnv.
  290. func StartProcessAsChild(parentPID uint32, exePath string, extraEnv []string) error {
  291. // The rest of this function requires SeDebugPrivilege to be held.
  292. //
  293. // According to https://docs.microsoft.com/en-us/windows/win32/procthread/process-security-and-access-rights
  294. //
  295. // ... To open a handle to another process and obtain full access rights,
  296. // you must enable the SeDebugPrivilege privilege. ...
  297. //
  298. // But we only need PROCESS_CREATE_PROCESS. So perhaps SeDebugPrivilege is too much.
  299. //
  300. // https://devblogs.microsoft.com/oldnewthing/20080314-00/?p=23113
  301. //
  302. // TODO: try look for something less than SeDebugPrivilege
  303. disableSeDebug, err := EnableCurrentThreadPrivilege("SeDebugPrivilege")
  304. if err != nil {
  305. return err
  306. }
  307. defer disableSeDebug()
  308. ph, err := windows.OpenProcess(
  309. windows.PROCESS_CREATE_PROCESS|windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_DUP_HANDLE,
  310. false, parentPID)
  311. if err != nil {
  312. return err
  313. }
  314. defer windows.CloseHandle(ph)
  315. var pt windows.Token
  316. err = windows.OpenProcessToken(ph, windows.TOKEN_QUERY, &pt)
  317. if err != nil {
  318. return err
  319. }
  320. defer pt.Close()
  321. env, err := pt.Environ(false)
  322. if err != nil {
  323. return err
  324. }
  325. env = append(env, extraEnv...)
  326. sys := &syscall.SysProcAttr{ParentProcess: syscall.Handle(ph)}
  327. cmd := exec.Command(exePath)
  328. cmd.Env = env
  329. cmd.SysProcAttr = sys
  330. return cmd.Start()
  331. }
  332. // StartProcessAsCurrentGUIUser is like StartProcessAsChild, but if finds
  333. // current logged in user desktop process (normally explorer.exe),
  334. // and passes found PID to StartProcessAsChild.
  335. func StartProcessAsCurrentGUIUser(exePath string, extraEnv []string) error {
  336. // as described in https://devblogs.microsoft.com/oldnewthing/20190425-00/?p=102443
  337. desktop, err := GetDesktopPID()
  338. if err != nil {
  339. return fmt.Errorf("failed to find desktop: %v", err)
  340. }
  341. err = StartProcessAsChild(desktop, exePath, extraEnv)
  342. if err != nil {
  343. return fmt.Errorf("failed to start executable: %v", err)
  344. }
  345. return nil
  346. }
  347. // CreateAppMutex creates a named Windows mutex, returning nil if the mutex
  348. // is created successfully or an error if the mutex already exists or could not
  349. // be created for some other reason.
  350. func CreateAppMutex(name string) (windows.Handle, error) {
  351. return windows.CreateMutex(nil, false, windows.StringToUTF16Ptr(name))
  352. }
  353. // getTokenInfoFixedLen obtains known fixed-length token information. Use this
  354. // function for information classes that output enumerations, BOOLs, integers etc.
  355. func getTokenInfoFixedLen[T any](token windows.Token, infoClass uint32) (result T, err error) {
  356. var actualLen uint32
  357. p := (*byte)(unsafe.Pointer(&result))
  358. err = windows.GetTokenInformation(token, infoClass, p, uint32(unsafe.Sizeof(result)), &actualLen)
  359. return result, err
  360. }
  // tokenElevationType appears to mirror the Win32 TOKEN_ELEVATION_TYPE
// enumeration, whose values start at 1. It is read via the TokenElevationType
// information class.
type tokenElevationType int32

const (
	tokenElevationTypeDefault tokenElevationType = iota + 1 // 1
	tokenElevationTypeFull                                  // 2
	tokenElevationTypeLimited                               // 3
)
  367. // IsTokenLimited returns whether token is a limited UAC token.
  368. func IsTokenLimited(token windows.Token) (bool, error) {
  369. elevationType, err := getTokenInfoFixedLen[tokenElevationType](token, windows.TokenElevationType)
  370. if err != nil {
  371. return false, err
  372. }
  373. return elevationType == tokenElevationTypeLimited, nil
  374. }
  375. // UserSIDs contains the SIDs for a Windows NT token object's associated user
  376. // as well as its primary group.
  377. type UserSIDs struct {
  378. User *windows.SID
  379. PrimaryGroup *windows.SID
  380. }
  381. // GetCurrentUserSIDs returns a UserSIDs struct containing SIDs for the
  382. // current process' user and primary group.
  383. func GetCurrentUserSIDs() (*UserSIDs, error) {
  384. token, err := windows.OpenCurrentProcessToken()
  385. if err != nil {
  386. return nil, err
  387. }
  388. defer token.Close()
  389. userInfo, err := token.GetTokenUser()
  390. if err != nil {
  391. return nil, err
  392. }
  393. primaryGroup, err := token.GetTokenPrimaryGroup()
  394. if err != nil {
  395. return nil, err
  396. }
  397. return &UserSIDs{userInfo.User.Sid, primaryGroup.PrimaryGroup}, nil
  398. }
  399. // IsCurrentProcessElevated returns true when the current process is
  400. // running with an elevated token, implying Administrator access.
  401. func IsCurrentProcessElevated() bool {
  402. token, err := windows.OpenCurrentProcessToken()
  403. if err != nil {
  404. return false
  405. }
  406. defer token.Close()
  407. return token.IsElevated()
  408. }
  // keyOpenTimeout is the total time OpenKeyWait waits for a registry key to
// appear. For some reason, registry keys tied to ephemeral interfaces can take
// a long while to appear after interface creation, and we can end up racing
// with that.
const keyOpenTimeout time.Duration = 20 * time.Second
  // RegistryPath is a backslash-separated key path relative to some root
// registry.Key.
type RegistryPath string

// RegistryPathPrefix is the leading part of a RegistryPath. It must have a
// suffix appended (see WithSuffix) to form a complete, valid RegistryPath.
type RegistryPathPrefix string

// WithSuffix returns the RegistryPath formed by concatenating p and suf.
func (p RegistryPathPrefix) WithSuffix(suf string) RegistryPath {
	full := string(p) + suf
	return RegistryPath(full)
}
  422. const (
  423. IPv4TCPIPBase RegistryPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters`
  424. IPv6TCPIPBase RegistryPath = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters`
  425. NetBTBase RegistryPath = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters`
  426. IPv4TCPIPInterfacePrefix RegistryPathPrefix = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
  427. IPv6TCPIPInterfacePrefix RegistryPathPrefix = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
  428. NetBTInterfacePrefix RegistryPathPrefix = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces\Tcpip_`
  429. )
  // ErrKeyWaitTimeout is the error OpenKeyWait returns when the requested key
// does not appear before its wait deadline passes.
var ErrKeyWaitTimeout error = errors.New("timeout waiting for registry key")
  // OpenKeyWait opens a registry key, waiting for it to appear if necessary. It
// returns the opened key, or ErrKeyWaitTimeout if the key does not appear
// within 20s. The caller must call Close on the returned key.
//
// path is split on backslashes and opened one component at a time, starting
// from k. Intermediate components are opened with registry.NOTIFY access;
// only the final component is opened with access. For each component, a
// REG_NOTIFY_CHANGE_NAME notification is registered on its parent before the
// open is attempted. If the open fails with ERROR_FILE_NOT_FOUND or
// ERROR_PATH_NOT_FOUND, the function waits on that notification and retries.
// All components share a single deadline of keyOpenTimeout, measured from the
// start of the call.
//
// k itself is never closed here. Intermediate keys and the per-component
// events are closed only when the function returns, because their defers run
// inside the loop.
//
// NOTE(review): notifications are registered on k, so k presumably needs
// NOTIFY access. Confirm with callers.
func OpenKeyWait(k registry.Key, path RegistryPath, access uint32) (registry.Key, error) {
	// NOTE(review): the OS thread is presumably locked because the asynchronous
	// RegNotifyChangeKeyValue registrations below are tied to the calling
	// thread. Confirm against the Win32 documentation.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	deadline := time.Now().Add(keyOpenTimeout)
	pathSpl := strings.Split(string(path), "\\")
	for i := 0; ; i++ {
		keyName := pathSpl[i]
		isLast := i+1 == len(pathSpl)
		// Auto-reset (manualReset=0), initially non-signaled event that receives
		// change notifications for this component's parent key.
		event, err := windows.CreateEvent(nil, 0, 0, nil)
		if err != nil {
			return 0, fmt.Errorf("windows.CreateEvent: %w", err)
		}
		defer windows.CloseHandle(event)
		var key registry.Key
		for {
			// (Re-)arm the notification before trying to open the subkey, so a
			// subkey created after a failed open still signals event.
			err = windows.RegNotifyChangeKeyValue(windows.Handle(k), false, windows.REG_NOTIFY_CHANGE_NAME, event, true)
			if err != nil {
				return 0, fmt.Errorf("windows.RegNotifyChangeKeyValue: %w", err)
			}
			var accessFlags uint32
			if isLast {
				accessFlags = access
			} else {
				accessFlags = registry.NOTIFY
			}
			key, err = registry.OpenKey(k, keyName, accessFlags)
			if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
				// Not there yet: wait for a name change under k or for the
				// deadline. timeout is the remaining time in whole milliseconds,
				// clamped at zero.
				timeout := time.Until(deadline) / time.Millisecond
				if timeout < 0 {
					timeout = 0
				}
				s, err := windows.WaitForSingleObject(event, uint32(timeout))
				if err != nil {
					return 0, fmt.Errorf("windows.WaitForSingleObject: %w", err)
				}
				if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
					return 0, ErrKeyWaitTimeout
				}
			} else if err != nil {
				return 0, fmt.Errorf("registry.OpenKey(%v): %w", path, err)
			} else {
				if isLast {
					return key, nil
				}
				// Intermediate key: keep it open, as the parent for the next
				// component, until the function returns.
				defer key.Close()
				break
			}
		}
		// Descend: the key just opened becomes the parent for the next component.
		k = key
	}
}
  486. func lookupPseudoUser(uid string) (*user.User, error) {
  487. sid, err := windows.StringToSid(uid)
  488. if err != nil {
  489. return nil, err
  490. }
  491. // We're looking for SIDs "S-1-5-x" where 17 <= x <= 20.
  492. // This is checking for the the "5"
  493. if sid.IdentifierAuthority() != windows.SECURITY_NT_AUTHORITY {
  494. return nil, fmt.Errorf(`SID %q does not use "NT AUTHORITY"`, uid)
  495. }
  496. // This is ensuring that there is only one sub-authority.
  497. // In other words, only one value after the "5".
  498. if sid.SubAuthorityCount() != 1 {
  499. return nil, fmt.Errorf("SID %q should have only one subauthority", uid)
  500. }
  501. // Get that sub-authority value (this is "x" above) and check it.
  502. rid := sid.SubAuthority(0)
  503. if rid < 17 || rid > 20 {
  504. return nil, fmt.Errorf("SID %q does not represent a known pseudo-user", uid)
  505. }
  506. // We've got one of the known pseudo-users. Look up the localized name of the
  507. // account.
  508. username, domain, _, err := sid.LookupAccount("")
  509. if err != nil {
  510. return nil, err
  511. }
  512. // This call is best-effort. If it fails, homeDir will be empty.
  513. homeDir, _ := findHomeDirInRegistry(uid)
  514. result := &user.User{
  515. Uid: uid,
  516. Gid: uid, // Gid == Uid with these accounts.
  517. Username: fmt.Sprintf(`%s\%s`, domain, username),
  518. Name: username,
  519. HomeDir: homeDir,
  520. }
  521. return result, nil
  522. }
  523. // findHomeDirInRegistry finds the user home path based on the uid.
  524. // This is borrowed from Go's std lib.
  525. func findHomeDirInRegistry(uid string) (dir string, err error) {
  526. k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion\ProfileList\`+uid, registry.QUERY_VALUE)
  527. if err != nil {
  528. return "", err
  529. }
  530. defer k.Close()
  531. dir, _, err = k.GetStringValue("ProfileImagePath")
  532. if err != nil {
  533. return "", err
  534. }
  535. return dir, nil
  536. }
  537. // ProcessImageName returns the fully-qualified path to the executable image
  538. // associated with process.
  539. func ProcessImageName(process windows.Handle) (string, error) {
  540. var pathBuf [windows.MAX_PATH]uint16
  541. pathBufLen := uint32(len(pathBuf))
  542. if err := windows.QueryFullProcessImageName(process, 0, &pathBuf[0], &pathBufLen); err != nil {
  543. return "", err
  544. }
  545. return windows.UTF16ToString(pathBuf[:pathBufLen]), nil
  546. }
  547. // TSSessionIDToLogonSessionID retrieves the logon session ID associated with
  548. // tsSessionId, which is a Terminal Services / RDP session ID. The calling
  549. // process must be running as LocalSystem.
  550. func TSSessionIDToLogonSessionID(tsSessionID uint32) (logonSessionID windows.LUID, err error) {
  551. var token windows.Token
  552. if err := windows.WTSQueryUserToken(tsSessionID, &token); err != nil {
  553. return logonSessionID, fmt.Errorf("WTSQueryUserToken: %w", err)
  554. }
  555. defer token.Close()
  556. return LogonSessionID(token)
  557. }
  558. // TSSessionID obtains the Terminal Services (RDP) session ID associated with token.
  559. func TSSessionID(token windows.Token) (tsSessionID uint32, err error) {
  560. return getTokenInfoFixedLen[uint32](token, windows.TokenSessionId)
  561. }
  562. type tokenOrigin struct {
  563. originatingLogonSession windows.LUID
  564. }
  565. // LogonSessionID obtains the logon session ID associated with token.
  566. func LogonSessionID(token windows.Token) (logonSessionID windows.LUID, err error) {
  567. origin, err := getTokenInfoFixedLen[tokenOrigin](token, windows.TokenOrigin)
  568. if err != nil {
  569. return logonSessionID, err
  570. }
  571. return origin.originatingLogonSession, nil
  572. }
  // BufUnit constrains the element (code unit) type of the buffers accepted by
// AllocateContiguousBuffer and SetNTString to either byte or uint16.
type BufUnit interface {
	uint16 | byte
}
  // AllocateContiguousBuffer allocates memory to satisfy the Windows idiom where
// some structs contain pointers that are expected to refer to memory within the
// same buffer containing the struct itself. T is the type that contains
// the pointers. values must contain the actual data that is to be copied
// into the buffer after T. AllocateContiguousBuffer returns a pointer to the
// struct, the total length of the buffer in bytes, and a slice containing
// each value within the buffer. The caller may use slcs to populate any
// pointers in t as needed. Each element of slcs corresponds to the element of
// values in the same position.
//
// Layout: [T][padding to sizeof(BU)][values[0]][values[1]]...[padding], where
// the total length is rounded up to max(alignof(T), sizeof(BU)). Zero-length
// values occupy no space, and their slcs entry is nil.
//
// It is the responsibility of the caller to ensure that any values expected
// to contain null-terminated strings are in fact null-terminated!
//
// AllocateContiguousBuffer panics if no values are passed in, as there are
// better alternatives for allocating a struct in that case. It also panics if
// the computed alignment is not 1, 2, 4 or 8.
//
// NOTE(review): the buffer is allocated as a pointer-free slice ([]byte,
// []uint16, []uint32 or []uint64), so the garbage collector will not scan T.
// Pointers in T that refer into this same buffer are fine, since the buffer
// stays alive through t. Pointers from T to any other Go-allocated memory
// would not keep that memory alive. Confirm that callers never store such
// pointers.
//
// NOTE(review): each slcs element is a subslice of one shared buffer, and its
// capacity runs to the end of that buffer. Appending to an element would
// overwrite the values that follow it. The uint32 length arithmetic is also
// not checked for overflow.
func AllocateContiguousBuffer[T any, BU BufUnit](values ...[]BU) (t *T, tLenBytes uint32, slcs [][]BU) {
	if len(values) == 0 {
		panic("len(values) must be > 0")
	}
	// Get the sizes of T and BU, then compute a preferred alignment for T.
	tT := reflect.TypeFor[T]()
	szT := tT.Size()
	szBU := int(unsafe.Sizeof(BU(0)))
	alignment := max(tT.Align(), szBU)
	// Our buffers for values will start at the next szBU boundary.
	tLenBytes = alignUp(uint32(szT), szBU)
	firstValueOffset := tLenBytes
	// Accumulate the length of each value into tLenBytes
	for _, v := range values {
		tLenBytes += uint32(len(v) * szBU)
	}
	// Now that we know the final length, align up to our preferred boundary.
	tLenBytes = alignUp(tLenBytes, alignment)
	// Allocate the buffer. We choose a type for the slice that is appropriate
	// for the desired alignment. Note that we do not have a strict requirement
	// that T contain pointer fields; we could just be appending more data
	// within the same buffer.
	// bufLen counts elements of the chosen slice type, not bytes.
	bufLen := tLenBytes / uint32(alignment)
	var pt unsafe.Pointer
	switch alignment {
	case 1:
		pt = unsafe.Pointer(unsafe.SliceData(make([]byte, bufLen)))
	case 2:
		pt = unsafe.Pointer(unsafe.SliceData(make([]uint16, bufLen)))
	case 4:
		pt = unsafe.Pointer(unsafe.SliceData(make([]uint32, bufLen)))
	case 8:
		pt = unsafe.Pointer(unsafe.SliceData(make([]uint64, bufLen)))
	default:
		panic(fmt.Sprintf("bad alignment %d", alignment))
	}
	t = (*T)(pt)
	slcs = make([][]BU, 0, len(values))
	// Use the limits of the buffer area after t to construct a slice representing the remaining buffer.
	firstValuePtr := unsafe.Pointer(uintptr(pt) + uintptr(firstValueOffset))
	buf := unsafe.Slice((*BU)(firstValuePtr), (tLenBytes-firstValueOffset)/uint32(szBU))
	// Copy each value into the buffer and record a slice describing each value's limits into slcs.
	// index is the next free position in buf, counted in BU units.
	var index int
	for _, v := range values {
		if len(v) == 0 {
			// We allow zero-length values; we simply append a nil slice.
			slcs = append(slcs, nil)
			continue
		}
		valueSlice := buf[index : index+len(v)]
		copy(valueSlice, v)
		slcs = append(slcs, valueSlice)
		index += len(v)
	}
	return t, tLenBytes, slcs
}
  649. // alignment must be a power of 2
  650. func alignUp[V constraints.Integer](v V, alignment int) V {
  651. return v + ((-v) & (V(alignment) - 1))
  652. }
  653. // NTStr is a type constraint requiring the type to be either a
  654. // windows.NTString or a windows.NTUnicodeString.
  655. type NTStr interface {
  656. windows.NTString | windows.NTUnicodeString
  657. }
  658. // SetNTString sets the value of nts in-place to point to the string contained
  659. // within buf. A nul terminator is optional in buf.
  660. func SetNTString[NTS NTStr, BU BufUnit](nts *NTS, buf []BU) {
  661. isEmpty := len(buf) == 0
  662. codeUnitSize := uint16(unsafe.Sizeof(BU(0)))
  663. lenBytes := len(buf) * int(codeUnitSize)
  664. if lenBytes > math.MaxUint16 {
  665. panic("buffer length must fit into uint16")
  666. }
  667. lenBytes16 := uint16(lenBytes)
  668. switch p := any(nts).(type) {
  669. case *windows.NTString:
  670. if isEmpty {
  671. *p = windows.NTString{}
  672. break
  673. }
  674. p.Buffer = unsafe.SliceData(any(buf).([]byte))
  675. p.MaximumLength = lenBytes16
  676. p.Length = lenBytes16
  677. // account for nul terminator when present
  678. if buf[len(buf)-1] == 0 {
  679. p.Length -= codeUnitSize
  680. }
  681. case *windows.NTUnicodeString:
  682. if isEmpty {
  683. *p = windows.NTUnicodeString{}
  684. break
  685. }
  686. p.Buffer = unsafe.SliceData(any(buf).([]uint16))
  687. p.MaximumLength = lenBytes16
  688. p.Length = lenBytes16
  689. // account for nul terminator when present
  690. if buf[len(buf)-1] == 0 {
  691. p.Length -= codeUnitSize
  692. }
  693. default:
  694. panic("unknown type")
  695. }
  696. }
  // domainControllerAddressType is the type of the DomainControllerAddressType
// field of _DOMAIN_CONTROLLER_INFO.
type domainControllerAddressType uint32

// Values for domainControllerAddressType. The names suggest they mirror the
// Win32 DS_INET_ADDRESS / DS_NETBIOS_ADDRESS constants. Confirm against
// DsGetDC.h before changing them.
const (
	//lint:ignore U1000 maps to a win32 API
	_DS_INET_ADDRESS    domainControllerAddressType = 1
	_DS_NETBIOS_ADDRESS domainControllerAddressType = 2
)
  // domainControllerFlag is the type of the Flags field of
// _DOMAIN_CONTROLLER_INFO.
type domainControllerFlag uint32

// Bit flags for domainControllerFlag. The names suggest they mirror the Win32
// DS_*_FLAG constants for DOMAIN_CONTROLLER_INFO.Flags; keep them in sync with
// DsGetDC.h. _DS_PING_FLAGS is a mask covering the low 20 bits rather than a
// single flag.
const (
	//lint:ignore U1000 maps to a win32 API
	_DS_PDC_FLAG                    domainControllerFlag = 0x00000001
	_DS_GC_FLAG                     domainControllerFlag = 0x00000004
	_DS_LDAP_FLAG                   domainControllerFlag = 0x00000008
	_DS_DS_FLAG                     domainControllerFlag = 0x00000010
	_DS_KDC_FLAG                    domainControllerFlag = 0x00000020
	_DS_TIMESERV_FLAG               domainControllerFlag = 0x00000040
	_DS_CLOSEST_FLAG                domainControllerFlag = 0x00000080
	_DS_WRITABLE_FLAG               domainControllerFlag = 0x00000100
	_DS_GOOD_TIMESERV_FLAG          domainControllerFlag = 0x00000200
	_DS_NDNC_FLAG                   domainControllerFlag = 0x00000400
	_DS_SELECT_SECRET_DOMAIN_6_FLAG domainControllerFlag = 0x00000800
	_DS_FULL_SECRET_DOMAIN_6_FLAG   domainControllerFlag = 0x00001000
	_DS_WS_FLAG                     domainControllerFlag = 0x00002000
	_DS_DS_8_FLAG                   domainControllerFlag = 0x00004000
	_DS_DS_9_FLAG                   domainControllerFlag = 0x00008000
	_DS_DS_10_FLAG                  domainControllerFlag = 0x00010000
	_DS_KEY_LIST_FLAG               domainControllerFlag = 0x00020000
	_DS_PING_FLAGS                  domainControllerFlag = 0x000FFFFF
	_DS_DNS_CONTROLLER_FLAG         domainControllerFlag = 0x20000000
	_DS_DNS_DOMAIN_FLAG             domainControllerFlag = 0x40000000
	_DS_DNS_FOREST_FLAG             domainControllerFlag = 0x80000000
)
  728. type _DOMAIN_CONTROLLER_INFO struct {
  729. DomainControllerName *uint16
  730. DomainControllerAddress *uint16
  731. DomainControllerAddressType domainControllerAddressType
  732. DomainGuid windows.GUID
  733. DomainName *uint16
  734. DnsForestName *uint16
  735. Flags domainControllerFlag
  736. DcSiteName *uint16
  737. ClientSiteName *uint16
  738. }
  739. func (dci *_DOMAIN_CONTROLLER_INFO) Close() error {
  740. if dci == nil {
  741. return nil
  742. }
  743. return windows.NetApiBufferFree((*byte)(unsafe.Pointer(dci)))
  744. }
  745. type dsGetDcNameFlag uint32
  746. const (
  747. //lint:ignore U1000 maps to a win32 API
  748. _DS_FORCE_REDISCOVERY dsGetDcNameFlag = 0x00000001
  749. _DS_DIRECTORY_SERVICE_REQUIRED dsGetDcNameFlag = 0x00000010
  750. _DS_DIRECTORY_SERVICE_PREFERRED dsGetDcNameFlag = 0x00000020
  751. _DS_GC_SERVER_REQUIRED dsGetDcNameFlag = 0x00000040
  752. _DS_PDC_REQUIRED dsGetDcNameFlag = 0x00000080
  753. _DS_BACKGROUND_ONLY dsGetDcNameFlag = 0x00000100
  754. _DS_IP_REQUIRED dsGetDcNameFlag = 0x00000200
  755. _DS_KDC_REQUIRED dsGetDcNameFlag = 0x00000400
  756. _DS_TIMESERV_REQUIRED dsGetDcNameFlag = 0x00000800
  757. _DS_WRITABLE_REQUIRED dsGetDcNameFlag = 0x00001000
  758. _DS_GOOD_TIMESERV_PREFERRED dsGetDcNameFlag = 0x00002000
  759. _DS_AVOID_SELF dsGetDcNameFlag = 0x00004000
  760. _DS_ONLY_LDAP_NEEDED dsGetDcNameFlag = 0x00008000
  761. _DS_IS_FLAT_NAME dsGetDcNameFlag = 0x00010000
  762. _DS_IS_DNS_NAME dsGetDcNameFlag = 0x00020000
  763. _DS_TRY_NEXTCLOSEST_SITE dsGetDcNameFlag = 0x00040000
  764. _DS_DIRECTORY_SERVICE_6_REQUIRED dsGetDcNameFlag = 0x00080000
  765. _DS_WEB_SERVICE_REQUIRED dsGetDcNameFlag = 0x00100000
  766. _DS_DIRECTORY_SERVICE_8_REQUIRED dsGetDcNameFlag = 0x00200000
  767. _DS_DIRECTORY_SERVICE_9_REQUIRED dsGetDcNameFlag = 0x00400000
  768. _DS_DIRECTORY_SERVICE_10_REQUIRED dsGetDcNameFlag = 0x00800000
  769. _DS_KEY_LIST_SUPPORT_REQUIRED dsGetDcNameFlag = 0x01000000
  770. _DS_RETURN_DNS_NAME dsGetDcNameFlag = 0x40000000
  771. _DS_RETURN_FLAT_NAME dsGetDcNameFlag = 0x80000000
  772. )
  773. func resolveDomainController(domainName *uint16, domainGUID *windows.GUID) (*_DOMAIN_CONTROLLER_INFO, error) {
  774. const flags = _DS_DIRECTORY_SERVICE_REQUIRED | _DS_IS_FLAT_NAME | _DS_RETURN_DNS_NAME
  775. var dcInfo *_DOMAIN_CONTROLLER_INFO
  776. if err := dsGetDcName(nil, domainName, domainGUID, nil, flags, &dcInfo); err != nil {
  777. return nil, err
  778. }
  779. return dcInfo, nil
  780. }
  781. // ResolveDomainController resolves the DNS name of the nearest available
  782. // domain controller for the domain specified by domainName.
  783. func ResolveDomainController(domainName string) (string, error) {
  784. domainName16, err := windows.UTF16PtrFromString(domainName)
  785. if err != nil {
  786. return "", err
  787. }
  788. dcInfo, err := resolveDomainController(domainName16, nil)
  789. if err != nil {
  790. return "", err
  791. }
  792. defer dcInfo.Close()
  793. return windows.UTF16PtrToString(dcInfo.DomainControllerName), nil
  794. }
  795. type _NETSETUP_NAME_TYPE int32
  796. const (
  797. _NetSetupUnknown _NETSETUP_NAME_TYPE = 0
  798. _NetSetupMachine _NETSETUP_NAME_TYPE = 1
  799. _NetSetupWorkgroup _NETSETUP_NAME_TYPE = 2
  800. _NetSetupDomain _NETSETUP_NAME_TYPE = 3
  801. _NetSetupNonExistentDomain _NETSETUP_NAME_TYPE = 4
  802. _NetSetupDnsMachine _NETSETUP_NAME_TYPE = 5
  803. )
  804. func isDomainName(name *uint16) (bool, error) {
  805. err := netValidateName(nil, name, nil, nil, _NetSetupDomain)
  806. switch err {
  807. case nil:
  808. return true, nil
  809. case windows.ERROR_NO_SUCH_DOMAIN:
  810. return false, nil
  811. default:
  812. return false, err
  813. }
  814. }
  815. // IsDomainName checks whether name represents an existing domain reachable by
  816. // the current machine.
  817. func IsDomainName(name string) (bool, error) {
  818. name16, err := windows.UTF16PtrFromString(name)
  819. if err != nil {
  820. return false, err
  821. }
  822. return isDomainName(name16)
  823. }