actions_test.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. package common
  2. import (
  3. "errors"
  4. "fmt"
  5. "os"
  6. "os/exec"
  7. "path/filepath"
  8. "runtime"
  9. "testing"
  10. "github.com/lithammer/shortuuid/v3"
  11. "github.com/rs/xid"
  12. "github.com/sftpgo/sdk"
  13. "github.com/sftpgo/sdk/plugin/notifier"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/drakkan/sftpgo/v2/dataprovider"
  16. "github.com/drakkan/sftpgo/v2/plugin"
  17. "github.com/drakkan/sftpgo/v2/vfs"
  18. )
  19. func TestNewActionNotification(t *testing.T) {
  20. user := &dataprovider.User{
  21. BaseUser: sdk.BaseUser{
  22. Username: "username",
  23. },
  24. }
  25. user.FsConfig.Provider = sdk.LocalFilesystemProvider
  26. user.FsConfig.S3Config = vfs.S3FsConfig{
  27. BaseS3FsConfig: sdk.BaseS3FsConfig{
  28. Bucket: "s3bucket",
  29. Endpoint: "endpoint",
  30. },
  31. }
  32. user.FsConfig.GCSConfig = vfs.GCSFsConfig{
  33. BaseGCSFsConfig: sdk.BaseGCSFsConfig{
  34. Bucket: "gcsbucket",
  35. },
  36. }
  37. user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{
  38. BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{
  39. Container: "azcontainer",
  40. Endpoint: "azendpoint",
  41. },
  42. }
  43. user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
  44. BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
  45. Endpoint: "sftpendpoint",
  46. },
  47. }
  48. sessionID := xid.New().String()
  49. a := newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSFTP, "", sessionID,
  50. 123, 0, errors.New("fake error"))
  51. assert.Equal(t, user.Username, a.Username)
  52. assert.Equal(t, 0, len(a.Bucket))
  53. assert.Equal(t, 0, len(a.Endpoint))
  54. assert.Equal(t, 2, a.Status)
  55. user.FsConfig.Provider = sdk.S3FilesystemProvider
  56. a = newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSSH, "", sessionID,
  57. 123, 0, nil)
  58. assert.Equal(t, "s3bucket", a.Bucket)
  59. assert.Equal(t, "endpoint", a.Endpoint)
  60. assert.Equal(t, 1, a.Status)
  61. user.FsConfig.Provider = sdk.GCSFilesystemProvider
  62. a = newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSCP, "", sessionID,
  63. 123, 0, ErrQuotaExceeded)
  64. assert.Equal(t, "gcsbucket", a.Bucket)
  65. assert.Equal(t, 0, len(a.Endpoint))
  66. assert.Equal(t, 3, a.Status)
  67. user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider
  68. a = newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSCP, "", sessionID,
  69. 123, 0, nil)
  70. assert.Equal(t, "azcontainer", a.Bucket)
  71. assert.Equal(t, "azendpoint", a.Endpoint)
  72. assert.Equal(t, 1, a.Status)
  73. a = newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSCP, "", sessionID,
  74. 123, os.O_APPEND, nil)
  75. assert.Equal(t, "azcontainer", a.Bucket)
  76. assert.Equal(t, "azendpoint", a.Endpoint)
  77. assert.Equal(t, 1, a.Status)
  78. assert.Equal(t, os.O_APPEND, a.OpenFlags)
  79. user.FsConfig.Provider = sdk.SFTPFilesystemProvider
  80. a = newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSFTP, "", sessionID,
  81. 123, 0, nil)
  82. assert.Equal(t, "sftpendpoint", a.Endpoint)
  83. }
  84. func TestActionHTTP(t *testing.T) {
  85. actionsCopy := Config.Actions
  86. Config.Actions = ProtocolActions{
  87. ExecuteOn: []string{operationDownload},
  88. Hook: fmt.Sprintf("http://%v", httpAddr),
  89. }
  90. user := &dataprovider.User{
  91. BaseUser: sdk.BaseUser{
  92. Username: "username",
  93. },
  94. }
  95. a := newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSFTP, "",
  96. xid.New().String(), 123, 0, nil)
  97. err := actionHandler.Handle(a)
  98. assert.NoError(t, err)
  99. Config.Actions.Hook = "http://invalid:1234"
  100. err = actionHandler.Handle(a)
  101. assert.Error(t, err)
  102. Config.Actions.Hook = fmt.Sprintf("http://%v/404", httpAddr)
  103. err = actionHandler.Handle(a)
  104. if assert.Error(t, err) {
  105. assert.EqualError(t, err, errUnexpectedHTTResponse.Error())
  106. }
  107. Config.Actions = actionsCopy
  108. }
  109. func TestActionCMD(t *testing.T) {
  110. if runtime.GOOS == osWindows {
  111. t.Skip("this test is not available on Windows")
  112. }
  113. actionsCopy := Config.Actions
  114. hookCmd, err := exec.LookPath("true")
  115. assert.NoError(t, err)
  116. Config.Actions = ProtocolActions{
  117. ExecuteOn: []string{operationDownload},
  118. Hook: hookCmd,
  119. }
  120. user := &dataprovider.User{
  121. BaseUser: sdk.BaseUser{
  122. Username: "username",
  123. },
  124. }
  125. sessionID := shortuuid.New()
  126. a := newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSFTP, "", sessionID,
  127. 123, 0, nil)
  128. err = actionHandler.Handle(a)
  129. assert.NoError(t, err)
  130. c := NewBaseConnection("id", ProtocolSFTP, "", "", *user)
  131. ExecuteActionNotification(c, OperationSSHCmd, "path", "vpath", "target", "vtarget", "sha1sum", 0, nil)
  132. ExecuteActionNotification(c, operationDownload, "path", "vpath", "", "", "", 0, nil)
  133. Config.Actions = actionsCopy
  134. }
  135. func TestWrongActions(t *testing.T) {
  136. actionsCopy := Config.Actions
  137. badCommand := "/bad/command"
  138. if runtime.GOOS == osWindows {
  139. badCommand = "C:\\bad\\command"
  140. }
  141. Config.Actions = ProtocolActions{
  142. ExecuteOn: []string{operationUpload},
  143. Hook: badCommand,
  144. }
  145. user := &dataprovider.User{
  146. BaseUser: sdk.BaseUser{
  147. Username: "username",
  148. },
  149. }
  150. a := newActionNotification(user, operationUpload, "", "", "", "", "", ProtocolSFTP, "", xid.New().String(),
  151. 123, 0, nil)
  152. err := actionHandler.Handle(a)
  153. assert.Error(t, err, "action with bad command must fail")
  154. a.Action = operationDelete
  155. err = actionHandler.Handle(a)
  156. assert.EqualError(t, err, errUnconfiguredAction.Error())
  157. Config.Actions.Hook = "http://foo\x7f.com/"
  158. a.Action = operationUpload
  159. err = actionHandler.Handle(a)
  160. assert.Error(t, err, "action with bad url must fail")
  161. Config.Actions.Hook = ""
  162. err = actionHandler.Handle(a)
  163. if assert.Error(t, err) {
  164. assert.EqualError(t, err, errNoHook.Error())
  165. }
  166. Config.Actions.Hook = "relative path"
  167. err = actionHandler.Handle(a)
  168. if assert.Error(t, err) {
  169. assert.EqualError(t, err, fmt.Sprintf("invalid notification command %#v", Config.Actions.Hook))
  170. }
  171. Config.Actions = actionsCopy
  172. }
  173. func TestPreDeleteAction(t *testing.T) {
  174. if runtime.GOOS == osWindows {
  175. t.Skip("this test is not available on Windows")
  176. }
  177. actionsCopy := Config.Actions
  178. hookCmd, err := exec.LookPath("true")
  179. assert.NoError(t, err)
  180. Config.Actions = ProtocolActions{
  181. ExecuteOn: []string{operationPreDelete},
  182. Hook: hookCmd,
  183. }
  184. homeDir := filepath.Join(os.TempDir(), "test_user")
  185. err = os.MkdirAll(homeDir, os.ModePerm)
  186. assert.NoError(t, err)
  187. user := dataprovider.User{
  188. BaseUser: sdk.BaseUser{
  189. Username: "username",
  190. HomeDir: homeDir,
  191. },
  192. }
  193. user.Permissions = make(map[string][]string)
  194. user.Permissions["/"] = []string{dataprovider.PermAny}
  195. fs := vfs.NewOsFs("id", homeDir, "")
  196. c := NewBaseConnection("id", ProtocolSFTP, "", "", user)
  197. testfile := filepath.Join(user.HomeDir, "testfile")
  198. err = os.WriteFile(testfile, []byte("test"), os.ModePerm)
  199. assert.NoError(t, err)
  200. info, err := os.Stat(testfile)
  201. assert.NoError(t, err)
  202. err = c.RemoveFile(fs, testfile, "testfile", info)
  203. assert.NoError(t, err)
  204. assert.FileExists(t, testfile)
  205. os.RemoveAll(homeDir)
  206. Config.Actions = actionsCopy
  207. }
  208. func TestUnconfiguredHook(t *testing.T) {
  209. actionsCopy := Config.Actions
  210. Config.Actions = ProtocolActions{
  211. ExecuteOn: []string{operationDownload},
  212. Hook: "",
  213. }
  214. pluginsConfig := []plugin.Config{
  215. {
  216. Type: "notifier",
  217. },
  218. }
  219. err := plugin.Initialize(pluginsConfig, true)
  220. assert.Error(t, err)
  221. assert.True(t, plugin.Handler.HasNotifiers())
  222. c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{})
  223. err = ExecutePreAction(c, OperationPreDownload, "", "", 0, 0)
  224. assert.NoError(t, err)
  225. err = ExecutePreAction(c, operationPreDelete, "", "", 0, 0)
  226. assert.ErrorIs(t, err, errUnconfiguredAction)
  227. ExecuteActionNotification(c, operationDownload, "", "", "", "", "", 0, nil)
  228. err = plugin.Initialize(nil, true)
  229. assert.NoError(t, err)
  230. assert.False(t, plugin.Handler.HasNotifiers())
  231. Config.Actions = actionsCopy
  232. }
  233. type actionHandlerStub struct {
  234. called bool
  235. }
  236. func (h *actionHandlerStub) Handle(event *notifier.FsEvent) error {
  237. h.called = true
  238. return nil
  239. }
  240. func TestInitializeActionHandler(t *testing.T) {
  241. handler := &actionHandlerStub{}
  242. InitializeActionHandler(handler)
  243. t.Cleanup(func() {
  244. InitializeActionHandler(&defaultActionHandler{})
  245. })
  246. err := actionHandler.Handle(&notifier.FsEvent{})
  247. assert.NoError(t, err)
  248. assert.True(t, handler.called)
  249. }