actions_test.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. package common
  2. import (
  3. "errors"
  4. "fmt"
  5. "os"
  6. "os/exec"
  7. "path/filepath"
  8. "runtime"
  9. "testing"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/drakkan/sftpgo/dataprovider"
  12. "github.com/drakkan/sftpgo/vfs"
  13. )
  14. func TestNewActionNotification(t *testing.T) {
  15. user := &dataprovider.User{
  16. Username: "username",
  17. }
  18. user.FsConfig.Provider = vfs.LocalFilesystemProvider
  19. user.FsConfig.S3Config = vfs.S3FsConfig{
  20. Bucket: "s3bucket",
  21. Endpoint: "endpoint",
  22. }
  23. user.FsConfig.GCSConfig = vfs.GCSFsConfig{
  24. Bucket: "gcsbucket",
  25. }
  26. user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{
  27. Container: "azcontainer",
  28. SASURL: "azsasurl",
  29. Endpoint: "azendpoint",
  30. }
  31. a := newActionNotification(user, operationDownload, "path", "target", "", ProtocolSFTP, 123, errors.New("fake error"))
  32. assert.Equal(t, user.Username, a.Username)
  33. assert.Equal(t, 0, len(a.Bucket))
  34. assert.Equal(t, 0, len(a.Endpoint))
  35. assert.Equal(t, 0, a.Status)
  36. user.FsConfig.Provider = vfs.S3FilesystemProvider
  37. a = newActionNotification(user, operationDownload, "path", "target", "", ProtocolSSH, 123, nil)
  38. assert.Equal(t, "s3bucket", a.Bucket)
  39. assert.Equal(t, "endpoint", a.Endpoint)
  40. assert.Equal(t, 1, a.Status)
  41. user.FsConfig.Provider = vfs.GCSFilesystemProvider
  42. a = newActionNotification(user, operationDownload, "path", "target", "", ProtocolSCP, 123, ErrQuotaExceeded)
  43. assert.Equal(t, "gcsbucket", a.Bucket)
  44. assert.Equal(t, 0, len(a.Endpoint))
  45. assert.Equal(t, 2, a.Status)
  46. user.FsConfig.Provider = vfs.AzureBlobFilesystemProvider
  47. a = newActionNotification(user, operationDownload, "path", "target", "", ProtocolSCP, 123, nil)
  48. assert.Equal(t, "azcontainer", a.Bucket)
  49. assert.Equal(t, "azsasurl", a.Endpoint)
  50. assert.Equal(t, 1, a.Status)
  51. user.FsConfig.AzBlobConfig.SASURL = ""
  52. a = newActionNotification(user, operationDownload, "path", "target", "", ProtocolSCP, 123, nil)
  53. assert.Equal(t, "azcontainer", a.Bucket)
  54. assert.Equal(t, "azendpoint", a.Endpoint)
  55. assert.Equal(t, 1, a.Status)
  56. }
  57. func TestActionHTTP(t *testing.T) {
  58. actionsCopy := Config.Actions
  59. Config.Actions = ProtocolActions{
  60. ExecuteOn: []string{operationDownload},
  61. Hook: fmt.Sprintf("http://%v", httpAddr),
  62. }
  63. user := &dataprovider.User{
  64. Username: "username",
  65. }
  66. a := newActionNotification(user, operationDownload, "path", "target", "", ProtocolSFTP, 123, nil)
  67. err := actionHandler.Handle(a)
  68. assert.NoError(t, err)
  69. Config.Actions.Hook = "http://invalid:1234"
  70. err = actionHandler.Handle(a)
  71. assert.Error(t, err)
  72. Config.Actions.Hook = fmt.Sprintf("http://%v/404", httpAddr)
  73. err = actionHandler.Handle(a)
  74. if assert.Error(t, err) {
  75. assert.EqualError(t, err, errUnexpectedHTTResponse.Error())
  76. }
  77. Config.Actions = actionsCopy
  78. }
  79. func TestActionCMD(t *testing.T) {
  80. if runtime.GOOS == osWindows {
  81. t.Skip("this test is not available on Windows")
  82. }
  83. actionsCopy := Config.Actions
  84. hookCmd, err := exec.LookPath("true")
  85. assert.NoError(t, err)
  86. Config.Actions = ProtocolActions{
  87. ExecuteOn: []string{operationDownload},
  88. Hook: hookCmd,
  89. }
  90. user := &dataprovider.User{
  91. Username: "username",
  92. }
  93. a := newActionNotification(user, operationDownload, "path", "target", "", ProtocolSFTP, 123, nil)
  94. err = actionHandler.Handle(a)
  95. assert.NoError(t, err)
  96. SSHCommandActionNotification(user, "path", "target", "sha1sum", nil)
  97. Config.Actions = actionsCopy
  98. }
  99. func TestWrongActions(t *testing.T) {
  100. actionsCopy := Config.Actions
  101. badCommand := "/bad/command"
  102. if runtime.GOOS == osWindows {
  103. badCommand = "C:\\bad\\command"
  104. }
  105. Config.Actions = ProtocolActions{
  106. ExecuteOn: []string{operationUpload},
  107. Hook: badCommand,
  108. }
  109. user := &dataprovider.User{
  110. Username: "username",
  111. }
  112. a := newActionNotification(user, operationUpload, "", "", "", ProtocolSFTP, 123, nil)
  113. err := actionHandler.Handle(a)
  114. assert.Error(t, err, "action with bad command must fail")
  115. a.Action = operationDelete
  116. err = actionHandler.Handle(a)
  117. assert.EqualError(t, err, errUnconfiguredAction.Error())
  118. Config.Actions.Hook = "http://foo\x7f.com/"
  119. a.Action = operationUpload
  120. err = actionHandler.Handle(a)
  121. assert.Error(t, err, "action with bad url must fail")
  122. Config.Actions.Hook = ""
  123. err = actionHandler.Handle(a)
  124. if assert.Error(t, err) {
  125. assert.EqualError(t, err, errNoHook.Error())
  126. }
  127. Config.Actions.Hook = "relative path"
  128. err = actionHandler.Handle(a)
  129. if assert.Error(t, err) {
  130. assert.EqualError(t, err, fmt.Sprintf("invalid notification command %#v", Config.Actions.Hook))
  131. }
  132. Config.Actions = actionsCopy
  133. }
  134. func TestPreDeleteAction(t *testing.T) {
  135. if runtime.GOOS == osWindows {
  136. t.Skip("this test is not available on Windows")
  137. }
  138. actionsCopy := Config.Actions
  139. hookCmd, err := exec.LookPath("true")
  140. assert.NoError(t, err)
  141. Config.Actions = ProtocolActions{
  142. ExecuteOn: []string{operationPreDelete},
  143. Hook: hookCmd,
  144. }
  145. homeDir := filepath.Join(os.TempDir(), "test_user")
  146. err = os.MkdirAll(homeDir, os.ModePerm)
  147. assert.NoError(t, err)
  148. user := dataprovider.User{
  149. Username: "username",
  150. HomeDir: homeDir,
  151. }
  152. user.Permissions = make(map[string][]string)
  153. user.Permissions["/"] = []string{dataprovider.PermAny}
  154. fs := vfs.NewOsFs("id", homeDir, "")
  155. c := NewBaseConnection("id", ProtocolSFTP, user)
  156. testfile := filepath.Join(user.HomeDir, "testfile")
  157. err = os.WriteFile(testfile, []byte("test"), os.ModePerm)
  158. assert.NoError(t, err)
  159. info, err := os.Stat(testfile)
  160. assert.NoError(t, err)
  161. err = c.RemoveFile(fs, testfile, "testfile", info)
  162. assert.NoError(t, err)
  163. assert.FileExists(t, testfile)
  164. os.RemoveAll(homeDir)
  165. Config.Actions = actionsCopy
  166. }
  167. type actionHandlerStub struct {
  168. called bool
  169. }
  170. func (h *actionHandlerStub) Handle(notification *ActionNotification) error {
  171. h.called = true
  172. return nil
  173. }
  174. func TestInitializeActionHandler(t *testing.T) {
  175. handler := &actionHandlerStub{}
  176. InitializeActionHandler(handler)
  177. t.Cleanup(func() {
  178. InitializeActionHandler(&defaultActionHandler{})
  179. })
  180. err := actionHandler.Handle(&ActionNotification{})
  181. assert.NoError(t, err)
  182. assert.True(t, handler.called)
  183. }