service_windows.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. package service
  2. import (
  3. "fmt"
  4. "os"
  5. "path/filepath"
  6. "strings"
  7. "time"
  8. "golang.org/x/sys/windows/svc"
  9. "golang.org/x/sys/windows/svc/eventlog"
  10. "golang.org/x/sys/windows/svc/mgr"
  11. "github.com/drakkan/sftpgo/dataprovider"
  12. "github.com/drakkan/sftpgo/ftpd"
  13. "github.com/drakkan/sftpgo/httpd"
  14. "github.com/drakkan/sftpgo/logger"
  15. "github.com/drakkan/sftpgo/webdavd"
  16. )
  17. const (
  18. serviceName = "SFTPGo"
  19. serviceDesc = "Fully featured and highly configurable SFTP server with optional FTP/S and WebDAV support"
  20. rotateLogCmd = svc.Cmd(128)
  21. acceptRotateLog = svc.Accepted(rotateLogCmd)
  22. )
  23. // Status defines service status
  24. type Status uint8
  25. // Supported values for service status
  26. const (
  27. StatusUnknown Status = iota
  28. StatusRunning
  29. StatusStopped
  30. StatusPaused
  31. StatusStartPending
  32. StatusPausePending
  33. StatusContinuePending
  34. StatusStopPending
  35. )
// WindowsService adapts a Service to the Windows service control manager
// (SCM). Its Execute method is the handler passed to svc.Run. The other
// methods install, uninstall, start, stop, reload, rotate logs for and query
// the service registered under serviceName.
type WindowsService struct {
	// Service is the wrapped service that Execute starts and stops.
	Service Service
	// isInteractive is set by RunService from svc.IsAnInteractiveSession.
	isInteractive bool
}
  40. func (s Status) String() string {
  41. switch s {
  42. case StatusRunning:
  43. return "running"
  44. case StatusStopped:
  45. return "stopped"
  46. case StatusStartPending:
  47. return "start pending"
  48. case StatusPausePending:
  49. return "pause pending"
  50. case StatusPaused:
  51. return "paused"
  52. case StatusContinuePending:
  53. return "continue pending"
  54. case StatusStopPending:
  55. return "stop pending"
  56. default:
  57. return "unknown"
  58. }
  59. }
  60. func (s *WindowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) {
  61. const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown | svc.AcceptParamChange | acceptRotateLog
  62. changes <- svc.Status{State: svc.StartPending}
  63. if err := s.Service.Start(); err != nil {
  64. return true, 1
  65. }
  66. changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
  67. loop:
  68. for {
  69. c := <-r
  70. switch c.Cmd {
  71. case svc.Interrogate:
  72. logger.Debug(logSender, "", "Received service interrogate request, current status: %v", c.CurrentStatus)
  73. changes <- c.CurrentStatus
  74. case svc.Stop, svc.Shutdown:
  75. logger.Debug(logSender, "", "Received service stop request")
  76. changes <- svc.Status{State: svc.StopPending}
  77. s.Service.Stop()
  78. break loop
  79. case svc.ParamChange:
  80. logger.Debug(logSender, "", "Received reload request")
  81. err := dataprovider.ReloadConfig()
  82. if err != nil {
  83. logger.Warn(logSender, "", "error reloading dataprovider configuration: %v", err)
  84. }
  85. err = httpd.ReloadTLSCertificate()
  86. if err != nil {
  87. logger.Warn(logSender, "", "error reloading TLS certificate: %v", err)
  88. }
  89. err = ftpd.ReloadTLSCertificate()
  90. if err != nil {
  91. logger.Warn(logSender, "", "error reloading FTPD TLS certificate: %v", err)
  92. }
  93. err = webdavd.ReloadTLSCertificate()
  94. if err != nil {
  95. logger.Warn(logSender, "", "error reloading WebDav TLS certificate: %v", err)
  96. }
  97. case rotateLogCmd:
  98. logger.Debug(logSender, "", "Received log file rotation request")
  99. err := logger.RotateLogFile()
  100. if err != nil {
  101. logger.Warn(logSender, "", "error rotating log file: %v", err)
  102. }
  103. default:
  104. continue loop
  105. }
  106. }
  107. return false, 0
  108. }
  109. func (s *WindowsService) RunService() error {
  110. exePath, err := s.getExePath()
  111. if err != nil {
  112. return err
  113. }
  114. isIntSess, err := svc.IsAnInteractiveSession()
  115. if err != nil {
  116. return err
  117. }
  118. s.isInteractive = isIntSess
  119. dir := filepath.Dir(exePath)
  120. if err = os.Chdir(dir); err != nil {
  121. return err
  122. }
  123. if s.isInteractive {
  124. return s.Start()
  125. }
  126. return svc.Run(serviceName, s)
  127. }
  128. func (s *WindowsService) Start() error {
  129. m, err := mgr.Connect()
  130. if err != nil {
  131. return err
  132. }
  133. defer m.Disconnect()
  134. service, err := m.OpenService(serviceName)
  135. if err != nil {
  136. return fmt.Errorf("could not access service: %v", err)
  137. }
  138. defer service.Close()
  139. err = service.Start()
  140. if err != nil {
  141. return fmt.Errorf("could not start service: %v", err)
  142. }
  143. return nil
  144. }
  145. func (s *WindowsService) Reload() error {
  146. m, err := mgr.Connect()
  147. if err != nil {
  148. return err
  149. }
  150. defer m.Disconnect()
  151. service, err := m.OpenService(serviceName)
  152. if err != nil {
  153. return fmt.Errorf("could not access service: %v", err)
  154. }
  155. defer service.Close()
  156. _, err = service.Control(svc.ParamChange)
  157. if err != nil {
  158. return fmt.Errorf("could not send control=%d: %v", svc.ParamChange, err)
  159. }
  160. return nil
  161. }
  162. func (s *WindowsService) RotateLogFile() error {
  163. m, err := mgr.Connect()
  164. if err != nil {
  165. return err
  166. }
  167. defer m.Disconnect()
  168. service, err := m.OpenService(serviceName)
  169. if err != nil {
  170. return fmt.Errorf("could not access service: %v", err)
  171. }
  172. defer service.Close()
  173. _, err = service.Control(rotateLogCmd)
  174. if err != nil {
  175. return fmt.Errorf("could not send control=%d: %v", rotateLogCmd, err)
  176. }
  177. return nil
  178. }
  179. func (s *WindowsService) Install(args ...string) error {
  180. exePath, err := s.getExePath()
  181. if err != nil {
  182. return err
  183. }
  184. m, err := mgr.Connect()
  185. if err != nil {
  186. return err
  187. }
  188. defer m.Disconnect()
  189. service, err := m.OpenService(serviceName)
  190. if err == nil {
  191. service.Close()
  192. return fmt.Errorf("service %s already exists", serviceName)
  193. }
  194. config := mgr.Config{
  195. DisplayName: serviceName,
  196. Description: serviceDesc,
  197. StartType: mgr.StartAutomatic}
  198. service, err = m.CreateService(serviceName, exePath, config, args...)
  199. if err != nil {
  200. return err
  201. }
  202. defer service.Close()
  203. err = eventlog.InstallAsEventCreate(serviceName, eventlog.Error|eventlog.Warning|eventlog.Info)
  204. if err != nil {
  205. if !strings.Contains(err.Error(), "exists") {
  206. service.Delete()
  207. return fmt.Errorf("SetupEventLogSource() failed: %s", err)
  208. }
  209. }
  210. recoveryActions := []mgr.RecoveryAction{
  211. {
  212. Type: mgr.ServiceRestart,
  213. Delay: 0,
  214. },
  215. {
  216. Type: mgr.ServiceRestart,
  217. Delay: 60 * time.Second,
  218. },
  219. {
  220. Type: mgr.NoAction,
  221. },
  222. }
  223. err = service.SetRecoveryActions(recoveryActions, uint32(3600))
  224. if err != nil {
  225. service.Delete()
  226. return fmt.Errorf("unable to set recovery actions: %v", err)
  227. }
  228. return nil
  229. }
  230. func (s *WindowsService) Uninstall() error {
  231. m, err := mgr.Connect()
  232. if err != nil {
  233. return err
  234. }
  235. defer m.Disconnect()
  236. service, err := m.OpenService(serviceName)
  237. if err != nil {
  238. return fmt.Errorf("service %s is not installed", serviceName)
  239. }
  240. defer service.Close()
  241. err = service.Delete()
  242. if err != nil {
  243. return err
  244. }
  245. err = eventlog.Remove(serviceName)
  246. if err != nil {
  247. return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
  248. }
  249. return nil
  250. }
  251. func (s *WindowsService) Stop() error {
  252. m, err := mgr.Connect()
  253. if err != nil {
  254. return err
  255. }
  256. defer m.Disconnect()
  257. service, err := m.OpenService(serviceName)
  258. if err != nil {
  259. return fmt.Errorf("could not access service: %v", err)
  260. }
  261. defer service.Close()
  262. status, err := service.Control(svc.Stop)
  263. if err != nil {
  264. return fmt.Errorf("could not send control=%d: %v", svc.Stop, err)
  265. }
  266. timeout := time.Now().Add(10 * time.Second)
  267. for status.State != svc.Stopped {
  268. if timeout.Before(time.Now()) {
  269. return fmt.Errorf("timeout waiting for service to go to state=%d", svc.Stopped)
  270. }
  271. time.Sleep(300 * time.Millisecond)
  272. status, err = service.Query()
  273. if err != nil {
  274. return fmt.Errorf("could not retrieve service status: %v", err)
  275. }
  276. }
  277. return nil
  278. }
  279. func (s *WindowsService) Status() (Status, error) {
  280. m, err := mgr.Connect()
  281. if err != nil {
  282. return StatusUnknown, err
  283. }
  284. defer m.Disconnect()
  285. service, err := m.OpenService(serviceName)
  286. if err != nil {
  287. return StatusUnknown, fmt.Errorf("could not access service: %v", err)
  288. }
  289. defer service.Close()
  290. status, err := service.Query()
  291. if err != nil {
  292. return StatusUnknown, fmt.Errorf("could not query service status: %v", err)
  293. }
  294. switch status.State {
  295. case svc.StartPending:
  296. return StatusStartPending, nil
  297. case svc.Running:
  298. return StatusRunning, nil
  299. case svc.PausePending:
  300. return StatusPausePending, nil
  301. case svc.Paused:
  302. return StatusPaused, nil
  303. case svc.ContinuePending:
  304. return StatusContinuePending, nil
  305. case svc.StopPending:
  306. return StatusStopPending, nil
  307. case svc.Stopped:
  308. return StatusStopped, nil
  309. default:
  310. return StatusUnknown, fmt.Errorf("unknown status %v", status)
  311. }
  312. }
  313. func (s *WindowsService) getExePath() (string, error) {
  314. return os.Executable()
  315. }