manager_windows_test.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package dns
  4. import (
  5. "bytes"
  6. "context"
  7. "fmt"
  8. "math/rand"
  9. "net/netip"
  10. "strings"
  11. "testing"
  12. "time"
  13. "golang.org/x/sys/windows"
  14. "golang.org/x/sys/windows/registry"
  15. "tailscale.com/types/logger"
  16. "tailscale.com/util/dnsname"
  17. "tailscale.com/util/syspolicy/policyclient"
  18. "tailscale.com/util/winutil"
  19. "tailscale.com/util/winutil/gp"
  20. )
  // testGPRuleID names the registry subkey beneath nrptBaseGP under which
  // createFakeGPKey writes a fake NRPT rule, imitating one deployed by
  // group policy.
  const (
  	testGPRuleID = "{7B1B6151-84E6-41A3-8967-62F7F7B45687}"
  )
  22. func TestHostFileNewLines(t *testing.T) {
  23. in := []byte("#foo\r\n#bar\n#baz\n")
  24. want := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron\r\n\r\n# TailscaleHostsSectionEnd\r\n")
  25. he := []*HostEntry{
  26. &HostEntry{
  27. Addr: netip.MustParseAddr("192.168.1.1"),
  28. Hosts: []string{"aaron"},
  29. },
  30. }
  31. got, err := setTailscaleHosts(logger.Discard, in, he)
  32. if err != nil {
  33. t.Fatal(err)
  34. }
  35. if !bytes.Equal(got, want) {
  36. t.Errorf("got %q, want %q\n", got, want)
  37. }
  38. }
  39. func TestHostFileUnchanged(t *testing.T) {
  40. in := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron\r\n\r\n# TailscaleHostsSectionEnd\r\n")
  41. he := []*HostEntry{
  42. &HostEntry{
  43. Addr: netip.MustParseAddr("192.168.1.1"),
  44. Hosts: []string{"aaron"},
  45. },
  46. }
  47. got, err := setTailscaleHosts(logger.Discard, in, he)
  48. if err != nil {
  49. t.Fatal(err)
  50. }
  51. if got != nil {
  52. t.Errorf("got %q, want nil\n", got)
  53. }
  54. }
  55. func TestHostFileChanged(t *testing.T) {
  56. in := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron1\r\n\r\n# TailscaleHostsSectionEnd\r\n")
  57. want := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron1\r\n192.168.1.2 aaron2\r\n\r\n# TailscaleHostsSectionEnd\r\n")
  58. he := []*HostEntry{
  59. &HostEntry{
  60. Addr: netip.MustParseAddr("192.168.1.1"),
  61. Hosts: []string{"aaron1"},
  62. },
  63. &HostEntry{
  64. Addr: netip.MustParseAddr("192.168.1.2"),
  65. Hosts: []string{"aaron2"},
  66. },
  67. }
  68. got, err := setTailscaleHosts(logger.Discard, in, he)
  69. if err != nil {
  70. t.Fatal(err)
  71. }
  72. if !bytes.Equal(got, want) {
  73. t.Errorf("got %q, want %q\n", got, want)
  74. }
  75. }
  76. func TestManagerWindowsLocal(t *testing.T) {
  77. if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
  78. t.Skipf("test requires running as elevated user on Windows 10+")
  79. }
  80. runTest(t, true)
  81. }
  82. func TestManagerWindowsGP(t *testing.T) {
  83. if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
  84. t.Skipf("test requires running as elevated user on Windows 10+")
  85. }
  86. checkGPNotificationsWork(t)
  87. // Make sure group policy is refreshed before this test exits but after we've
  88. // cleaned everything else up.
  89. defer gp.RefreshMachinePolicy(true)
  90. err := createFakeGPKey()
  91. if err != nil {
  92. t.Fatalf("Creating fake GP key: %v\n", err)
  93. }
  94. defer deleteFakeGPKey(t)
  95. runTest(t, false)
  96. }
  97. func TestManagerWindowsGPCopy(t *testing.T) {
  98. if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
  99. t.Skipf("test requires running as elevated user on Windows 10+")
  100. }
  101. checkGPNotificationsWork(t)
  102. logf := func(format string, args ...any) {
  103. t.Logf(format, args...)
  104. }
  105. fakeInterface, err := windows.GenerateGUID()
  106. if err != nil {
  107. t.Fatalf("windows.GenerateGUID: %v\n", err)
  108. }
  109. delIfKey, err := createFakeInterfaceKey(t, fakeInterface)
  110. if err != nil {
  111. t.Fatalf("createFakeInterfaceKey: %v\n", err)
  112. }
  113. defer delIfKey()
  114. cfg, err := NewOSConfigurator(logf, nil, nil, policyclient.NoPolicyClient{}, nil, fakeInterface.String())
  115. if err != nil {
  116. t.Fatalf("NewOSConfigurator: %v\n", err)
  117. }
  118. mgr := cfg.(*windowsManager)
  119. defer mgr.Close()
  120. usingGP := mgr.nrptDB.writeAsGP
  121. if usingGP {
  122. t.Fatalf("usingGP %v, want %v\n", usingGP, false)
  123. }
  124. regWatcher, err := newRegKeyWatcher()
  125. if err != nil {
  126. t.Fatalf("newRegKeyWatcher error %v\n", err)
  127. }
  128. // Upon initialization of cfg, we should not have any NRPT rules
  129. ensureNoRules(t)
  130. resolvers := []netip.Addr{netip.MustParseAddr("1.1.1.1")}
  131. domains := genRandomSubdomains(t, 1)
  132. // 1. Populate local NRPT
  133. err = mgr.setSplitDNS(resolvers, domains)
  134. if err != nil {
  135. t.Fatalf("setSplitDNS: %v\n", err)
  136. }
  137. t.Logf("Validating that local NRPT is populated...\n")
  138. validateRegistry(t, nrptBaseLocal, domains)
  139. ensureNoRulesInSubkey(t, nrptBaseGP)
  140. // 2. Create fake GP key and refresh
  141. t.Logf("Creating fake group policy key and refreshing...\n")
  142. err = createFakeGPKey()
  143. if err != nil {
  144. t.Fatalf("createFakeGPKey: %v\n", err)
  145. }
  146. err = regWatcher.watch()
  147. if err != nil {
  148. t.Fatalf("regWatcher.watch: %v\n", err)
  149. }
  150. err = gp.RefreshMachinePolicy(true)
  151. if err != nil {
  152. t.Fatalf("testDoRefresh: %v\n", err)
  153. }
  154. err = regWatcher.wait()
  155. if err != nil {
  156. t.Fatalf("regWatcher.wait: %v\n", err)
  157. }
  158. // 3. Check that both local NRPT and GP NRPT are populated
  159. t.Logf("Validating that group policy NRPT is populated...\n")
  160. validateRegistry(t, nrptBaseLocal, domains)
  161. validateRegistry(t, nrptBaseGP, domains)
  162. // 4. Delete fake GP key and refresh
  163. t.Logf("Deleting fake group policy key and refreshing...\n")
  164. deleteFakeGPKey(t)
  165. err = regWatcher.watch()
  166. if err != nil {
  167. t.Fatalf("regWatcher.watch: %v\n", err)
  168. }
  169. err = gp.RefreshMachinePolicy(true)
  170. if err != nil {
  171. t.Fatalf("testDoRefresh: %v\n", err)
  172. }
  173. err = regWatcher.wait()
  174. if err != nil {
  175. t.Fatalf("regWatcher.wait: %v\n", err)
  176. }
  177. // 5. Check that local NRPT is populated and GP is empty
  178. t.Logf("Validating that local NRPT is populated...\n")
  179. validateRegistry(t, nrptBaseLocal, domains)
  180. ensureNoRulesInSubkey(t, nrptBaseGP)
  181. // 6. Cleanup
  182. t.Logf("Cleaning up...\n")
  183. err = mgr.setSplitDNS(nil, domains)
  184. if err != nil {
  185. t.Fatalf("setSplitDNS: %v\n", err)
  186. }
  187. ensureNoRules(t)
  188. }
  189. func checkGPNotificationsWork(t *testing.T) {
  190. // Test to ensure that RegisterGPNotification work on this machine,
  191. // otherwise this test will fail.
  192. trk, err := newGPNotificationTracker()
  193. if err != nil {
  194. t.Skipf("newGPNotificationTracker error: %v\n", err)
  195. }
  196. defer trk.Close()
  197. err = gp.RefreshMachinePolicy(true)
  198. if err != nil {
  199. t.Fatalf("RefreshPolicyEx error: %v\n", err)
  200. }
  201. timeout := uint32(10000) // Milliseconds
  202. if !trk.DidRefreshTimeout(timeout) {
  203. t.Skipf("GP notifications are not working on this machine\n")
  204. }
  205. }
  206. func runTest(t *testing.T, isLocal bool) {
  207. logf := func(format string, args ...any) {
  208. t.Logf(format, args...)
  209. }
  210. fakeInterface, err := windows.GenerateGUID()
  211. if err != nil {
  212. t.Fatalf("windows.GenerateGUID: %v\n", err)
  213. }
  214. delIfKey, err := createFakeInterfaceKey(t, fakeInterface)
  215. if err != nil {
  216. t.Fatalf("createFakeInterfaceKey: %v\n", err)
  217. }
  218. defer delIfKey()
  219. cfg, err := NewOSConfigurator(logf, nil, nil, policyclient.NoPolicyClient{}, nil, fakeInterface.String())
  220. if err != nil {
  221. t.Fatalf("NewOSConfigurator: %v\n", err)
  222. }
  223. mgr := cfg.(*windowsManager)
  224. defer mgr.Close()
  225. usingGP := mgr.nrptDB.writeAsGP
  226. if isLocal == usingGP {
  227. t.Fatalf("usingGP %v, want %v\n", usingGP, !usingGP)
  228. }
  229. // Upon initialization of cfg, we should not have any NRPT rules
  230. ensureNoRules(t)
  231. resolvers := []netip.Addr{netip.MustParseAddr("1.1.1.1")}
  232. domains := genRandomSubdomains(t, 2*nrptMaxDomainsPerRule+1)
  233. cases := []int{
  234. 1,
  235. 50,
  236. 51,
  237. 100,
  238. 101,
  239. 100,
  240. 50,
  241. 1,
  242. 51,
  243. }
  244. var regBaseValidate string
  245. var regBaseEnsure string
  246. if isLocal {
  247. regBaseValidate = nrptBaseLocal
  248. regBaseEnsure = nrptBaseGP
  249. } else {
  250. regBaseValidate = nrptBaseGP
  251. regBaseEnsure = nrptBaseLocal
  252. }
  253. var trk *gpNotificationTracker
  254. if isLocal {
  255. // (dblohm7) When isLocal == true, we keep trk active through the entire
  256. // sequence of test cases, and then we verify that no policy notifications
  257. // occurred. Because policy notifications are scoped to the entire computer,
  258. // this check could potentially fail if another process concurrently modifies
  259. // group policies while this test is running. I don't expect this to be an
  260. // issue on any computer on which we run this test, but something to keep in
  261. // mind if we start seeing flakiness around these GP notifications.
  262. trk, err = newGPNotificationTracker()
  263. if err != nil {
  264. t.Fatalf("newGPNotificationTracker: %v\n", err)
  265. }
  266. defer trk.Close()
  267. }
  268. runCase := func(n int) {
  269. t.Logf("Test case: %d domains\n", n)
  270. if !isLocal {
  271. // When !isLocal, we want to check that a GP notification occurred for
  272. // every single test case.
  273. trk, err = newGPNotificationTracker()
  274. if err != nil {
  275. t.Fatalf("newGPNotificationTracker: %v\n", err)
  276. }
  277. defer trk.Close()
  278. }
  279. caseDomains := domains[:n]
  280. err = mgr.setSplitDNS(resolvers, caseDomains)
  281. if err != nil {
  282. t.Fatalf("setSplitDNS: %v\n", err)
  283. }
  284. validateRegistry(t, regBaseValidate, caseDomains)
  285. ensureNoRulesInSubkey(t, regBaseEnsure)
  286. if !isLocal && !trk.DidRefresh(true) {
  287. t.Fatalf("DidRefresh false, want true\n")
  288. }
  289. }
  290. for _, n := range cases {
  291. runCase(n)
  292. }
  293. if isLocal && trk.DidRefresh(false) {
  294. t.Errorf("DidRefresh true, want false\n")
  295. }
  296. t.Logf("Test case: nil resolver\n")
  297. err = mgr.setSplitDNS(nil, domains)
  298. if err != nil {
  299. t.Fatalf("setSplitDNS: %v\n", err)
  300. }
  301. ensureNoRules(t)
  302. }
  303. func createFakeGPKey() error {
  304. keyStr := nrptBaseGP + `\` + testGPRuleID
  305. key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyStr, registry.SET_VALUE)
  306. if err != nil {
  307. return fmt.Errorf("opening %s: %w", keyStr, err)
  308. }
  309. defer key.Close()
  310. if err := key.SetDWordValue("Version", 1); err != nil {
  311. return err
  312. }
  313. if err := key.SetStringsValue("Name", []string{"._setbygp_.example.com"}); err != nil {
  314. return err
  315. }
  316. if err := key.SetStringValue("GenericDNSServers", "1.1.1.1"); err != nil {
  317. return err
  318. }
  319. if err := key.SetDWordValue("ConfigOptions", nrptOverrideDNS); err != nil {
  320. return err
  321. }
  322. return nil
  323. }
  324. func deleteFakeGPKey(t *testing.T) {
  325. keyName := nrptBaseGP + `\` + testGPRuleID
  326. if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyName); err != nil && err != registry.ErrNotExist {
  327. t.Fatalf("Error deleting NRPT rule key %q: %v\n", keyName, err)
  328. }
  329. isEmpty, err := isPolicyConfigSubkeyEmpty()
  330. if err != nil {
  331. t.Fatalf("isPolicyConfigSubkeyEmpty: %v", err)
  332. }
  333. if !isEmpty {
  334. return
  335. }
  336. if err := registry.DeleteKey(registry.LOCAL_MACHINE, nrptBaseGP); err != nil {
  337. t.Fatalf("Deleting DnsPolicyKey Subkey: %v", err)
  338. }
  339. }
  340. func createFakeInterfaceKey(t *testing.T, guid windows.GUID) (func(), error) {
  341. basePaths := []winutil.RegistryPathPrefix{winutil.IPv4TCPIPInterfacePrefix, winutil.IPv6TCPIPInterfacePrefix}
  342. keyPaths := make([]string, 0, len(basePaths))
  343. guidStr := guid.String()
  344. for _, basePath := range basePaths {
  345. keyPath := string(basePath.WithSuffix(guidStr))
  346. key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE)
  347. if err != nil {
  348. return nil, err
  349. }
  350. key.Close()
  351. keyPaths = append(keyPaths, keyPath)
  352. }
  353. result := func() {
  354. for _, keyPath := range keyPaths {
  355. if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyPath); err != nil {
  356. t.Fatalf("deleting fake interface key \"%s\": %v\n", keyPath, err)
  357. }
  358. }
  359. }
  360. return result, nil
  361. }
  362. func ensureNoRules(t *testing.T) {
  363. ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
  364. if ruleIDs != nil {
  365. t.Errorf("%s: %v, want nil\n", nrptRuleIDValueName, ruleIDs)
  366. }
  367. for _, base := range []string{nrptBaseLocal, nrptBaseGP} {
  368. ensureNoSingleRule(t, base)
  369. }
  370. }
  371. func ensureNoRulesInSubkey(t *testing.T, base string) {
  372. ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
  373. if ruleIDs == nil {
  374. for _, base := range []string{nrptBaseLocal, nrptBaseGP} {
  375. ensureNoSingleRule(t, base)
  376. }
  377. return
  378. }
  379. for _, ruleID := range ruleIDs {
  380. keyName := base + `\` + ruleID
  381. key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyName, registry.READ)
  382. if err == nil {
  383. key.Close()
  384. } else if err != registry.ErrNotExist {
  385. t.Fatalf("%s: %q, want %q\n", keyName, err, registry.ErrNotExist)
  386. }
  387. }
  388. if base == nrptBaseGP {
  389. // When dealing with the group policy subkey, we want the base key to
  390. // also be absent.
  391. key, err := registry.OpenKey(registry.LOCAL_MACHINE, base, registry.READ)
  392. if err == nil {
  393. key.Close()
  394. isEmpty, err := isPolicyConfigSubkeyEmpty()
  395. if err != nil {
  396. t.Fatalf("isPolicyConfigSubkeyEmpty: %v", err)
  397. }
  398. if isEmpty {
  399. t.Errorf("Unexpectedly found group policy key\n")
  400. }
  401. } else if err != registry.ErrNotExist {
  402. t.Errorf("Group policy key error: %q, want %q\n", err, registry.ErrNotExist)
  403. }
  404. }
  405. }
  406. func ensureNoSingleRule(t *testing.T, base string) {
  407. singleKeyPath := base + `\` + nrptSingleRuleID
  408. key, err := registry.OpenKey(registry.LOCAL_MACHINE, singleKeyPath, registry.READ)
  409. if err == nil {
  410. key.Close()
  411. }
  412. if err != registry.ErrNotExist {
  413. t.Fatalf("%s: %q, want %q\n", singleKeyPath, err, registry.ErrNotExist)
  414. }
  415. }
  416. func validateRegistry(t *testing.T, nrptBase string, domains []dnsname.FQDN) {
  417. q := len(domains) / nrptMaxDomainsPerRule
  418. r := len(domains) % nrptMaxDomainsPerRule
  419. numRules := q
  420. if r > 0 {
  421. numRules++
  422. }
  423. ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
  424. if ruleIDs == nil {
  425. ruleIDs = []string{nrptSingleRuleID}
  426. } else if len(ruleIDs) != numRules {
  427. t.Errorf("%s for %d domains: %d, want %d\n", nrptRuleIDValueName, len(domains), len(ruleIDs), numRules)
  428. }
  429. for i, ruleID := range ruleIDs {
  430. savedDomains, err := getSavedDomainsForRule(nrptBase, ruleID)
  431. if err != nil {
  432. t.Fatalf("getSavedDomainsForRule(%q, %q): %v\n", nrptBase, ruleID, err)
  433. }
  434. start := i * nrptMaxDomainsPerRule
  435. end := start + nrptMaxDomainsPerRule
  436. if i == len(ruleIDs)-1 && r > 0 {
  437. end = start + r
  438. }
  439. checkDomains := domains[start:end]
  440. if len(checkDomains) != len(savedDomains) {
  441. t.Errorf("len(checkDomains) != len(savedDomains): %d, want %d\n", len(savedDomains), len(checkDomains))
  442. }
  443. for j, cd := range checkDomains {
  444. sd := strings.TrimPrefix(savedDomains[j], ".")
  445. if string(cd.WithoutTrailingDot()) != sd {
  446. t.Errorf("checkDomain differs savedDomain: %s, want %s\n", sd, cd.WithoutTrailingDot())
  447. }
  448. }
  449. }
  450. }
  451. func getSavedDomainsForRule(base, ruleID string) ([]string, error) {
  452. keyPath := base + `\` + ruleID
  453. key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.READ)
  454. if err != nil {
  455. return nil, err
  456. }
  457. defer key.Close()
  458. result, _, err := key.GetStringsValue("Name")
  459. return result, err
  460. }
  461. func genRandomSubdomains(t *testing.T, n int) []dnsname.FQDN {
  462. domains := make([]dnsname.FQDN, 0, n)
  463. seed := time.Now().UnixNano()
  464. t.Logf("genRandomSubdomains(%d) seed: %v\n", n, seed)
  465. r := rand.New(rand.NewSource(seed))
  466. const charset = "abcdefghijklmnopqrstuvwxyz"
  467. for len(domains) < cap(domains) {
  468. ln := r.Intn(19) + 1
  469. b := make([]byte, ln)
  470. for i := range b {
  471. b[i] = charset[r.Intn(len(charset))]
  472. }
  473. d := string(b) + ".example.com"
  474. fqdn, err := dnsname.ToFQDN(d)
  475. if err != nil {
  476. t.Fatalf("dnsname.ToFQDN: %v\n", err)
  477. }
  478. domains = append(domains, fqdn)
  479. }
  480. return domains
  481. }
  482. var (
  483. libUserenv = windows.NewLazySystemDLL("userenv.dll")
  484. procRegisterGPNotification = libUserenv.NewProc("RegisterGPNotification")
  485. procUnregisterGPNotification = libUserenv.NewProc("UnregisterGPNotification")
  486. )
  487. // gpNotificationTracker registers with the Windows policy engine and receives
  488. // notifications when policy refreshes occur.
  489. type gpNotificationTracker struct {
  490. event windows.Handle
  491. }
  492. func newGPNotificationTracker() (*gpNotificationTracker, error) {
  493. var err error
  494. evt, err := windows.CreateEvent(nil, 0, 0, nil)
  495. if err != nil {
  496. return nil, err
  497. }
  498. defer func() {
  499. if err != nil {
  500. windows.CloseHandle(evt)
  501. }
  502. }()
  503. ok, _, e := procRegisterGPNotification.Call(
  504. uintptr(evt),
  505. uintptr(1), // We want computer policy changes, not user policy changes.
  506. )
  507. if ok == 0 {
  508. err = e
  509. return nil, err
  510. }
  511. return &gpNotificationTracker{evt}, nil
  512. }
  513. func (trk *gpNotificationTracker) DidRefresh(isExpected bool) bool {
  514. // If we're not expecting a refresh event, then we need to use a timeout.
  515. timeout := uint32(1000) // 1 second (in milliseconds)
  516. if isExpected {
  517. // Otherwise, since it is imperative that we see an event, we wait infinitely.
  518. timeout = windows.INFINITE
  519. }
  520. return trk.DidRefreshTimeout(timeout)
  521. }
  522. func (trk *gpNotificationTracker) DidRefreshTimeout(timeout uint32) bool {
  523. waitCode, _ := windows.WaitForSingleObject(trk.event, timeout)
  524. return waitCode == windows.WAIT_OBJECT_0
  525. }
  526. func (trk *gpNotificationTracker) Close() error {
  527. procUnregisterGPNotification.Call(uintptr(trk.event))
  528. windows.CloseHandle(trk.event)
  529. trk.event = 0
  530. return nil
  531. }
  532. type regKeyWatcher struct {
  533. keyGP registry.Key
  534. evtGP windows.Handle
  535. }
  536. func newRegKeyWatcher() (result *regKeyWatcher, err error) {
  537. // Monitor dnsBaseGP instead of nrptBaseGP, since the latter will be
  538. // repeatedly created and destroyed throughout the course of the test.
  539. keyGP, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsBaseGP, registry.READ)
  540. if err != nil {
  541. return nil, err
  542. }
  543. defer func() {
  544. if err != nil {
  545. keyGP.Close()
  546. }
  547. }()
  548. evtGP, err := windows.CreateEvent(nil, 0, 0, nil)
  549. if err != nil {
  550. return nil, err
  551. }
  552. return &regKeyWatcher{
  553. keyGP: keyGP,
  554. evtGP: evtGP,
  555. }, nil
  556. }
  557. func (rw *regKeyWatcher) watch() error {
  558. // We can make these waits thread-agnostic because the tests that use this code must already run on Windows 10+
  559. return windows.RegNotifyChangeKeyValue(windows.Handle(rw.keyGP), true,
  560. windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_THREAD_AGNOSTIC, rw.evtGP, true)
  561. }
  562. func (rw *regKeyWatcher) wait() error {
  563. waitCode, err := windows.WaitForSingleObject(
  564. rw.evtGP,
  565. 10000, // 10 seconds (as milliseconds)
  566. )
  567. switch waitCode {
  568. case uint32(windows.WAIT_TIMEOUT):
  569. return context.DeadlineExceeded
  570. case windows.WAIT_FAILED:
  571. return err
  572. default:
  573. return nil
  574. }
  575. }
  576. func (rw *regKeyWatcher) Close() error {
  577. rw.keyGP.Close()
  578. windows.CloseHandle(rw.evtGP)
  579. return nil
  580. }