atomicfile_windows_test.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package atomicfile
import (
	"os"
	"runtime"
	"testing"
	"unsafe"

	"golang.org/x/sys/windows"
)
  10. var _SECURITY_RESOURCE_MANAGER_AUTHORITY = windows.SidIdentifierAuthority{[6]byte{0, 0, 0, 0, 0, 9}}
  11. // makeRandomSID generates a SID derived from a v4 GUID.
  12. // This is basically the same algorithm used by browser sandboxes for generating
  13. // random SIDs.
  14. func makeRandomSID() (*windows.SID, error) {
  15. guid, err := windows.GenerateGUID()
  16. if err != nil {
  17. return nil, err
  18. }
  19. rids := *((*[4]uint32)(unsafe.Pointer(&guid)))
  20. var pSID *windows.SID
  21. if err := windows.AllocateAndInitializeSid(&_SECURITY_RESOURCE_MANAGER_AUTHORITY, 4, rids[0], rids[1], rids[2], rids[3], 0, 0, 0, 0, &pSID); err != nil {
  22. return nil, err
  23. }
  24. defer windows.FreeSid(pSID)
  25. // Make a copy that lives on the Go heap
  26. return pSID.Copy()
  27. }
  28. func getExistingFileSD(name string) (*windows.SECURITY_DESCRIPTOR, error) {
  29. const infoFlags = windows.DACL_SECURITY_INFORMATION
  30. return windows.GetNamedSecurityInfo(name, windows.SE_FILE_OBJECT, infoFlags)
  31. }
  32. func getExistingFileDACL(name string) (*windows.ACL, error) {
  33. sd, err := getExistingFileSD(name)
  34. if err != nil {
  35. return nil, err
  36. }
  37. dacl, _, err := sd.DACL()
  38. return dacl, err
  39. }
  40. func addDenyACEForRandomSID(dacl *windows.ACL) (*windows.ACL, error) {
  41. randomSID, err := makeRandomSID()
  42. if err != nil {
  43. return nil, err
  44. }
  45. randomSIDTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE,
  46. windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_UNKNOWN,
  47. windows.TrusteeValueFromSID(randomSID)}
  48. entries := []windows.EXPLICIT_ACCESS{
  49. {
  50. windows.GENERIC_ALL,
  51. windows.DENY_ACCESS,
  52. windows.NO_INHERITANCE,
  53. randomSIDTrustee,
  54. },
  55. }
  56. return windows.ACLFromEntries(entries, dacl)
  57. }
  58. func setExistingFileDACL(name string, dacl *windows.ACL) error {
  59. return windows.SetNamedSecurityInfo(name, windows.SE_FILE_OBJECT,
  60. windows.DACL_SECURITY_INFORMATION, nil, nil, dacl, nil)
  61. }
  62. // makeOrigFileWithCustomDACL creates a new, temporary file with a custom
  63. // DACL that we can check for later. It returns the name of the temporary
  64. // file and the security descriptor for the file in SDDL format.
  65. func makeOrigFileWithCustomDACL() (name, sddl string, err error) {
  66. f, err := os.CreateTemp("", "foo*.tmp")
  67. if err != nil {
  68. return "", "", err
  69. }
  70. name = f.Name()
  71. if err := f.Close(); err != nil {
  72. return "", "", err
  73. }
  74. f = nil
  75. defer func() {
  76. if err != nil {
  77. os.Remove(name)
  78. }
  79. }()
  80. dacl, err := getExistingFileDACL(name)
  81. if err != nil {
  82. return "", "", err
  83. }
  84. // Add a harmless, deny-only ACE for a random SID that isn't used for anything
  85. // (but that we can check for later).
  86. dacl, err = addDenyACEForRandomSID(dacl)
  87. if err != nil {
  88. return "", "", err
  89. }
  90. if err := setExistingFileDACL(name, dacl); err != nil {
  91. return "", "", err
  92. }
  93. sd, err := getExistingFileSD(name)
  94. if err != nil {
  95. return "", "", err
  96. }
  97. return name, sd.String(), nil
  98. }
  99. func TestPreserveSecurityInfo(t *testing.T) {
  100. // Make a test file with a custom ACL.
  101. origFileName, want, err := makeOrigFileWithCustomDACL()
  102. if err != nil {
  103. t.Fatalf("makeOrigFileWithCustomDACL returned %v", err)
  104. }
  105. t.Cleanup(func() {
  106. os.Remove(origFileName)
  107. })
  108. if err := WriteFile(origFileName, []byte{}, 0); err != nil {
  109. t.Fatalf("WriteFile returned %v", err)
  110. }
  111. // We expect origFileName's security descriptor to be unchanged despite
  112. // the WriteFile call.
  113. sd, err := getExistingFileSD(origFileName)
  114. if err != nil {
  115. t.Fatalf("getExistingFileSD(%q) returned %v", origFileName, err)
  116. }
  117. if got := sd.String(); got != want {
  118. t.Errorf("security descriptor comparison failed: got %q, want %q", got, want)
  119. }
  120. }