driver_windows.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. //go:build windows
  2. package windivert
  3. import (
  4. "errors"
  5. "os"
  6. "path/filepath"
  7. "runtime"
  8. "strconv"
  9. "sync"
  10. E "github.com/sagernet/sing/common/exceptions"
  11. "golang.org/x/sys/windows"
  12. )
const (
	// driverServiceName is the Service Control Manager name under which the
	// driver is opened or created; installDriver also uses it as the display name.
	driverServiceName = "WinDivert"
	// driverDeviceName is the device path that driverDevName holds in UTF-16 form.
	driverDeviceName = `\\.\WinDivert`
)
var (
	// driverOnce guards the single installDriver call made by ensureDriver.
	driverOnce sync.Once
	// driverErr stores the result of that one installDriver call; ensureDriver
	// returns it on every invocation.
	driverErr error
	// driverDevName is ASCII-safe and must be available before ensureDriver
	// so Open can try CreateFile first and only install on FILE_NOT_FOUND.
	// The conversion error is discarded because the source is a constant
	// without an embedded NUL.
	driverDevName, _ = windows.UTF16PtrFromString(driverDeviceName)
)
  24. // Requires SeLoadDriverPrivilege (Administrator). Running the 386 build
  25. // under WOW64 on a 64-bit kernel is rejected — use the amd64 build.
  26. func ensureDriver() error {
  27. driverOnce.Do(func() {
  28. driverErr = installDriver()
  29. })
  30. return driverErr
  31. }
  32. func installDriver() error {
  33. if runtime.GOARCH == "386" {
  34. var isWow64 bool
  35. err := windows.IsWow64Process(windows.CurrentProcess(), &isWow64)
  36. if err == nil && isWow64 {
  37. return E.New("windivert: 386 build detected running under WOW64 on a 64-bit kernel; use the amd64 build")
  38. }
  39. }
  40. dir, err := ensureExtracted()
  41. if err != nil {
  42. return err
  43. }
  44. sysPath := filepath.Join(dir, driverSysName())
  45. sysPathW, err := windows.UTF16PtrFromString(sysPath)
  46. if err != nil {
  47. return E.Cause(err, "windivert: utf16 driver path")
  48. }
  49. // Serialize driver install across concurrent processes.
  50. mutexName, _ := windows.UTF16PtrFromString("WinDivertDriverInstallMutex")
  51. mutex, err := windows.CreateMutex(nil, false, mutexName)
  52. if err != nil {
  53. return E.Cause(err, "windivert: create install mutex")
  54. }
  55. defer windows.CloseHandle(mutex)
  56. _, err = windows.WaitForSingleObject(mutex, windows.INFINITE)
  57. if err != nil {
  58. return E.Cause(err, "windivert: wait install mutex")
  59. }
  60. defer windows.ReleaseMutex(mutex)
  61. manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS)
  62. if err != nil {
  63. return E.Cause(err, "windivert: open SCM")
  64. }
  65. defer windows.CloseServiceHandle(manager)
  66. serviceNameW, _ := windows.UTF16PtrFromString(driverServiceName)
  67. service, err := windows.OpenService(manager, serviceNameW, windows.SERVICE_ALL_ACCESS)
  68. if err != nil {
  69. service, err = windows.CreateService(
  70. manager,
  71. serviceNameW,
  72. serviceNameW,
  73. windows.SERVICE_ALL_ACCESS,
  74. windows.SERVICE_KERNEL_DRIVER,
  75. windows.SERVICE_DEMAND_START,
  76. windows.SERVICE_ERROR_NORMAL,
  77. sysPathW,
  78. nil, nil, nil, nil, nil,
  79. )
  80. if err != nil {
  81. if errors.Is(err, windows.ERROR_SERVICE_EXISTS) {
  82. service, err = windows.OpenService(manager, serviceNameW, windows.SERVICE_ALL_ACCESS)
  83. }
  84. if err != nil {
  85. return wrapDriverInstallError(err)
  86. }
  87. }
  88. }
  89. defer windows.CloseServiceHandle(service)
  90. err = windows.StartService(service, 0, nil)
  91. if err != nil && errors.Is(err, windows.ERROR_SERVICE_DISABLED) {
  92. // A prior process called DeleteService on a still-running kernel
  93. // driver: SCM marks the record for deletion and flips START_TYPE
  94. // to DISABLED until the last handle closes. Re-enable so we can
  95. // start it instead of waiting for a reboot.
  96. err = windows.ChangeServiceConfig(
  97. service,
  98. windows.SERVICE_NO_CHANGE,
  99. windows.SERVICE_DEMAND_START,
  100. windows.SERVICE_NO_CHANGE,
  101. nil, nil, nil, nil, nil, nil, nil,
  102. )
  103. if err != nil {
  104. return E.Cause(err, "windivert: re-enable disabled service")
  105. }
  106. err = windows.StartService(service, 0, nil)
  107. }
  108. if err == nil {
  109. // Mark for deletion so the driver unregisters when the last handle
  110. // closes or on next reboot. Matches the upstream DLL's behavior:
  111. // only the process that actually started the service takes on the
  112. // cleanup responsibility. If another process already started it,
  113. // we leave DeleteService to them.
  114. _ = windows.DeleteService(service)
  115. } else if !errors.Is(err, windows.ERROR_SERVICE_ALREADY_RUNNING) {
  116. return E.Cause(err, "windivert: start service")
  117. }
  118. return nil
  119. }
  120. func wrapDriverInstallError(err error) error {
  121. if errors.Is(err, windows.ERROR_ACCESS_DENIED) {
  122. return E.Cause(err, "windivert: installing the kernel driver requires Administrator privileges")
  123. }
  124. return E.Cause(err, "windivert: create service")
  125. }
// assetFile is one file that ensureAsset materializes inside the extraction
// directory.
type assetFile struct {
	// name is the file name joined onto the extraction directory.
	name string
	// data is the content written to disk unchanged.
	data []byte
}
var (
	// extractOnce guards the single extractImpl call made by ensureExtracted.
	extractOnce sync.Once
	// extractErr and extractDir hold the result of that one call and are
	// returned by ensureExtracted on every invocation.
	extractErr error
	extractDir string
)
  135. // The on-disk copy is protected by Windows Authenticode signature
  136. // enforcement, which rejects any tampered .sys at StartService time.
  137. func ensureExtracted() (string, error) {
  138. extractOnce.Do(func() {
  139. extractDir, extractErr = extractImpl()
  140. })
  141. return extractDir, extractErr
  142. }
  143. func extractImpl() (string, error) {
  144. files := assetFiles()
  145. if len(files) == 0 {
  146. return "", E.New("windivert: unsupported architecture ", runtime.GOARCH)
  147. }
  148. base, err := os.UserCacheDir()
  149. if err != nil {
  150. return "", E.Cause(err, "windivert: locate user cache dir")
  151. }
  152. dir := filepath.Join(base, "sing-box", "windivert", "v"+AssetVersion)
  153. err = os.MkdirAll(dir, 0o755)
  154. if err != nil {
  155. return "", E.Cause(err, "windivert: mkdir ", dir)
  156. }
  157. for _, asset := range files {
  158. err = ensureAsset(dir, asset)
  159. if err != nil {
  160. return "", err
  161. }
  162. }
  163. return dir, nil
  164. }
  165. // Concurrent sing-box processes race on os.Rename (atomic on NTFS);
  166. // whichever wins creates the final file. Writers that lose the race
  167. // silently discard their temp copy.
  168. func ensureAsset(dir string, asset assetFile) error {
  169. target := filepath.Join(dir, asset.name)
  170. _, err := os.Stat(target)
  171. if err == nil {
  172. return nil
  173. }
  174. if !os.IsNotExist(err) {
  175. return E.Cause(err, "windivert: stat ", asset.name)
  176. }
  177. tmp := target + ".tmp-" + strconv.Itoa(os.Getpid())
  178. err = os.WriteFile(tmp, asset.data, 0o644)
  179. if err != nil {
  180. return E.Cause(err, "windivert: write ", asset.name)
  181. }
  182. err = os.Rename(tmp, target)
  183. if err != nil {
  184. os.Remove(tmp)
  185. if _, statErr := os.Stat(target); statErr == nil {
  186. return nil
  187. }
  188. return E.Cause(err, "windivert: rename ", asset.name)
  189. }
  190. return nil
  191. }