common_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. package common
  2. import (
  3. "fmt"
  4. "net"
  5. "net/http"
  6. "os"
  7. "os/exec"
  8. "runtime"
  9. "strings"
  10. "sync/atomic"
  11. "testing"
  12. "time"
  13. "github.com/rs/zerolog"
  14. "github.com/spf13/viper"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/drakkan/sftpgo/dataprovider"
  17. "github.com/drakkan/sftpgo/httpclient"
  18. "github.com/drakkan/sftpgo/logger"
  19. )
  // Fixtures shared by the tests in this file.
const (
	// Logging, configuration location and platform detection.
	logSenderTest = "common_test"
	configDir     = ".." // configuration directory handed to httpclient.Config.Initialize
	osWindows     = "windows"

	// Local endpoints of the helper HTTP servers started by TestMain.
	httpAddr      = "127.0.0.1:9999"
	httpProxyAddr = "127.0.0.1:7777"

	// Username and password used for test users.
	userTestUsername = "common_test_username"
	userTestPwd      = "common_test_pwd"
)
  29. type providerConf struct {
  30. Config dataprovider.Config `json:"data_provider" mapstructure:"data_provider"`
  31. }
  32. type fakeConnection struct {
  33. *BaseConnection
  34. command string
  35. }
  36. func (c *fakeConnection) AddUser(user dataprovider.User) error {
  37. fs, err := user.GetFilesystem(c.GetID())
  38. if err != nil {
  39. return err
  40. }
  41. c.BaseConnection.User = user
  42. c.BaseConnection.Fs = fs
  43. return nil
  44. }
  45. func (c *fakeConnection) Disconnect() error {
  46. Connections.Remove(c.GetID())
  47. return nil
  48. }
  49. func (c *fakeConnection) GetClientVersion() string {
  50. return ""
  51. }
  52. func (c *fakeConnection) GetCommand() string {
  53. return c.command
  54. }
  55. func (c *fakeConnection) GetRemoteAddress() string {
  56. return ""
  57. }
  58. func (c *fakeConnection) SetConnDeadline() {}
  59. func TestMain(m *testing.M) {
  60. logfilePath := "common_test.log"
  61. logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel)
  62. viper.SetEnvPrefix("sftpgo")
  63. replacer := strings.NewReplacer(".", "__")
  64. viper.SetEnvKeyReplacer(replacer)
  65. viper.SetConfigName("sftpgo")
  66. viper.AutomaticEnv()
  67. viper.AllowEmptyEnv(true)
  68. driver, err := initializeDataprovider(-1)
  69. if err != nil {
  70. logger.WarnToConsole("error initializing data provider: %v", err)
  71. os.Exit(1)
  72. }
  73. logger.InfoToConsole("Starting COMMON tests, provider: %v", driver)
  74. Initialize(Configuration{})
  75. httpConfig := httpclient.Config{
  76. Timeout: 5,
  77. }
  78. httpConfig.Initialize(configDir)
  79. go func() {
  80. // start a test HTTP server to receive action notifications
  81. http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
  82. fmt.Fprintf(w, "OK\n")
  83. })
  84. http.HandleFunc("/404", func(w http.ResponseWriter, r *http.Request) {
  85. w.WriteHeader(http.StatusNotFound)
  86. fmt.Fprintf(w, "Not found\n")
  87. })
  88. if err := http.ListenAndServe(httpAddr, nil); err != nil {
  89. logger.ErrorToConsole("could not start HTTP notification server: %v", err)
  90. os.Exit(1)
  91. }
  92. }()
  93. go func() {
  94. Config.ProxyProtocol = 2
  95. listener, err := net.Listen("tcp", httpProxyAddr)
  96. if err != nil {
  97. logger.ErrorToConsole("error creating listener for proxy protocol server: %v", err)
  98. os.Exit(1)
  99. }
  100. proxyListener, err := Config.GetProxyListener(listener)
  101. if err != nil {
  102. logger.ErrorToConsole("error creating proxy protocol listener: %v", err)
  103. os.Exit(1)
  104. }
  105. Config.ProxyProtocol = 0
  106. s := &http.Server{}
  107. if err := s.Serve(proxyListener); err != nil {
  108. logger.ErrorToConsole("could not start HTTP proxy protocol server: %v", err)
  109. os.Exit(1)
  110. }
  111. }()
  112. waitTCPListening(httpAddr)
  113. waitTCPListening(httpProxyAddr)
  114. exitCode := m.Run()
  115. os.Remove(logfilePath) //nolint:errcheck
  116. os.Exit(exitCode)
  117. }
  118. func waitTCPListening(address string) {
  119. for {
  120. conn, err := net.Dial("tcp", address)
  121. if err != nil {
  122. logger.WarnToConsole("tcp server %v not listening: %v\n", address, err)
  123. time.Sleep(100 * time.Millisecond)
  124. continue
  125. }
  126. logger.InfoToConsole("tcp server %v now listening\n", address)
  127. conn.Close()
  128. break
  129. }
  130. }
  131. func initializeDataprovider(trackQuota int) (string, error) {
  132. configDir := ".."
  133. viper.AddConfigPath(configDir)
  134. if err := viper.ReadInConfig(); err != nil {
  135. return "", err
  136. }
  137. var cfg providerConf
  138. if err := viper.Unmarshal(&cfg); err != nil {
  139. return "", err
  140. }
  141. if trackQuota >= 0 && trackQuota <= 2 {
  142. cfg.Config.TrackQuota = trackQuota
  143. }
  144. return cfg.Config.Driver, dataprovider.Initialize(cfg.Config, configDir)
  145. }
  146. func closeDataprovider() error {
  147. return dataprovider.Close()
  148. }
  149. func TestIdleConnections(t *testing.T) {
  150. configCopy := Config
  151. Config.IdleTimeout = 1
  152. Initialize(Config)
  153. username := "test_user"
  154. user := dataprovider.User{
  155. Username: username,
  156. }
  157. c := NewBaseConnection("id1", ProtocolSFTP, user, nil)
  158. c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
  159. fakeConn := &fakeConnection{
  160. BaseConnection: c,
  161. }
  162. Connections.Add(fakeConn)
  163. assert.Equal(t, Connections.GetActiveSessions(username), 1)
  164. c = NewBaseConnection("id2", ProtocolFTP, dataprovider.User{}, nil)
  165. c.lastActivity = time.Now().UnixNano()
  166. fakeConn = &fakeConnection{
  167. BaseConnection: c,
  168. }
  169. Connections.Add(fakeConn)
  170. assert.Equal(t, Connections.GetActiveSessions(username), 1)
  171. assert.Len(t, Connections.GetStats(), 2)
  172. startIdleTimeoutTicker(100 * time.Millisecond)
  173. assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 0 }, 1*time.Second, 200*time.Millisecond)
  174. stopIdleTimeoutTicker()
  175. assert.Len(t, Connections.GetStats(), 1)
  176. c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
  177. startIdleTimeoutTicker(100 * time.Millisecond)
  178. assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond)
  179. stopIdleTimeoutTicker()
  180. Config = configCopy
  181. }
  182. func TestCloseConnection(t *testing.T) {
  183. c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil)
  184. fakeConn := &fakeConnection{
  185. BaseConnection: c,
  186. }
  187. Connections.Add(fakeConn)
  188. assert.Len(t, Connections.GetStats(), 1)
  189. res := Connections.Close(fakeConn.GetID())
  190. assert.True(t, res)
  191. assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
  192. res = Connections.Close(fakeConn.GetID())
  193. assert.False(t, res)
  194. Connections.Remove(fakeConn.GetID())
  195. }
  196. func TestSwapConnection(t *testing.T) {
  197. c := NewBaseConnection("id", ProtocolFTP, dataprovider.User{}, nil)
  198. fakeConn := &fakeConnection{
  199. BaseConnection: c,
  200. }
  201. Connections.Add(fakeConn)
  202. if assert.Len(t, Connections.GetStats(), 1) {
  203. assert.Equal(t, "", Connections.GetStats()[0].Username)
  204. }
  205. c = NewBaseConnection("id", ProtocolFTP, dataprovider.User{
  206. Username: userTestUsername,
  207. }, nil)
  208. fakeConn = &fakeConnection{
  209. BaseConnection: c,
  210. }
  211. err := Connections.Swap(fakeConn)
  212. assert.NoError(t, err)
  213. if assert.Len(t, Connections.GetStats(), 1) {
  214. assert.Equal(t, userTestUsername, Connections.GetStats()[0].Username)
  215. }
  216. res := Connections.Close(fakeConn.GetID())
  217. assert.True(t, res)
  218. assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
  219. err = Connections.Swap(fakeConn)
  220. assert.Error(t, err)
  221. }
  222. func TestAtomicUpload(t *testing.T) {
  223. configCopy := Config
  224. Config.UploadMode = UploadModeStandard
  225. assert.False(t, Config.IsAtomicUploadEnabled())
  226. Config.UploadMode = UploadModeAtomic
  227. assert.True(t, Config.IsAtomicUploadEnabled())
  228. Config.UploadMode = UploadModeAtomicWithResume
  229. assert.True(t, Config.IsAtomicUploadEnabled())
  230. Config = configCopy
  231. }
  232. func TestConnectionStatus(t *testing.T) {
  233. username := "test_user"
  234. user := dataprovider.User{
  235. Username: username,
  236. }
  237. c1 := NewBaseConnection("id1", ProtocolSFTP, user, nil)
  238. fakeConn1 := &fakeConnection{
  239. BaseConnection: c1,
  240. }
  241. t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/r1", TransferUpload, 0, 0, true)
  242. t1.BytesReceived = 123
  243. t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/r2", TransferDownload, 0, 0, true)
  244. t2.BytesSent = 456
  245. c2 := NewBaseConnection("id2", ProtocolSSH, user, nil)
  246. fakeConn2 := &fakeConnection{
  247. BaseConnection: c2,
  248. command: "md5sum",
  249. }
  250. c3 := NewBaseConnection("id3", ProtocolWebDAV, user, nil)
  251. fakeConn3 := &fakeConnection{
  252. BaseConnection: c3,
  253. command: "PROPFIND",
  254. }
  255. t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/r2", TransferDownload, 0, 0, true)
  256. Connections.Add(fakeConn1)
  257. Connections.Add(fakeConn2)
  258. Connections.Add(fakeConn3)
  259. stats := Connections.GetStats()
  260. assert.Len(t, stats, 3)
  261. for _, stat := range stats {
  262. assert.Equal(t, stat.Username, username)
  263. assert.True(t, strings.HasPrefix(stat.GetConnectionInfo(), stat.Protocol))
  264. assert.True(t, strings.HasPrefix(stat.GetConnectionDuration(), "00:"))
  265. if stat.ConnectionID == "SFTP_id1" {
  266. assert.Len(t, stat.Transfers, 2)
  267. assert.Greater(t, len(stat.GetTransfersAsString()), 0)
  268. for _, tr := range stat.Transfers {
  269. if tr.OperationType == operationDownload {
  270. assert.True(t, strings.HasPrefix(tr.getConnectionTransferAsString(), "DL"))
  271. } else if tr.OperationType == operationUpload {
  272. assert.True(t, strings.HasPrefix(tr.getConnectionTransferAsString(), "UL"))
  273. }
  274. }
  275. } else if stat.ConnectionID == "DAV_id3" {
  276. assert.Len(t, stat.Transfers, 1)
  277. assert.Greater(t, len(stat.GetTransfersAsString()), 0)
  278. } else {
  279. assert.Equal(t, 0, len(stat.GetTransfersAsString()))
  280. }
  281. }
  282. err := t1.Close()
  283. assert.NoError(t, err)
  284. err = t2.Close()
  285. assert.NoError(t, err)
  286. err = fakeConn3.SignalTransfersAbort()
  287. assert.NoError(t, err)
  288. assert.Equal(t, int32(1), atomic.LoadInt32(&t3.AbortTransfer))
  289. err = t3.Close()
  290. assert.NoError(t, err)
  291. err = fakeConn3.SignalTransfersAbort()
  292. assert.Error(t, err)
  293. Connections.Remove(fakeConn1.GetID())
  294. Connections.Remove(fakeConn2.GetID())
  295. Connections.Remove(fakeConn3.GetID())
  296. stats = Connections.GetStats()
  297. assert.Len(t, stats, 0)
  298. }
  299. func TestQuotaScans(t *testing.T) {
  300. username := "username"
  301. assert.True(t, QuotaScans.AddUserQuotaScan(username))
  302. assert.False(t, QuotaScans.AddUserQuotaScan(username))
  303. if assert.Len(t, QuotaScans.GetUsersQuotaScans(), 1) {
  304. assert.Equal(t, QuotaScans.GetUsersQuotaScans()[0].Username, username)
  305. }
  306. assert.True(t, QuotaScans.RemoveUserQuotaScan(username))
  307. assert.False(t, QuotaScans.RemoveUserQuotaScan(username))
  308. assert.Len(t, QuotaScans.GetUsersQuotaScans(), 0)
  309. folderName := "/folder"
  310. assert.True(t, QuotaScans.AddVFolderQuotaScan(folderName))
  311. assert.False(t, QuotaScans.AddVFolderQuotaScan(folderName))
  312. if assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 1) {
  313. assert.Equal(t, QuotaScans.GetVFoldersQuotaScans()[0].MappedPath, folderName)
  314. }
  315. assert.True(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
  316. assert.False(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
  317. assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 0)
  318. }
  319. func TestProxyProtocolVersion(t *testing.T) {
  320. c := Configuration{
  321. ProxyProtocol: 1,
  322. }
  323. proxyListener, err := c.GetProxyListener(nil)
  324. assert.NoError(t, err)
  325. assert.Nil(t, proxyListener.Policy)
  326. c.ProxyProtocol = 2
  327. proxyListener, err = c.GetProxyListener(nil)
  328. assert.NoError(t, err)
  329. assert.NotNil(t, proxyListener.Policy)
  330. c.ProxyProtocol = 1
  331. c.ProxyAllowed = []string{"invalid"}
  332. _, err = c.GetProxyListener(nil)
  333. assert.Error(t, err)
  334. c.ProxyProtocol = 2
  335. _, err = c.GetProxyListener(nil)
  336. assert.Error(t, err)
  337. }
  338. func TestProxyProtocol(t *testing.T) {
  339. httpClient := httpclient.GetHTTPClient()
  340. resp, err := httpClient.Get(fmt.Sprintf("http://%v", httpProxyAddr))
  341. if assert.NoError(t, err) {
  342. defer resp.Body.Close()
  343. assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
  344. }
  345. }
  346. func TestPostConnectHook(t *testing.T) {
  347. Config.PostConnectHook = ""
  348. remoteAddr := &net.IPAddr{
  349. IP: net.ParseIP("127.0.0.1"),
  350. Zone: "",
  351. }
  352. assert.NoError(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolFTP))
  353. Config.PostConnectHook = "http://foo\x7f.com/"
  354. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  355. Config.PostConnectHook = "http://invalid:1234/"
  356. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  357. Config.PostConnectHook = fmt.Sprintf("http://%v/404", httpAddr)
  358. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolFTP))
  359. Config.PostConnectHook = fmt.Sprintf("http://%v", httpAddr)
  360. assert.NoError(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolFTP))
  361. Config.PostConnectHook = "invalid"
  362. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolFTP))
  363. if runtime.GOOS == osWindows {
  364. Config.PostConnectHook = "C:\\bad\\command"
  365. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  366. } else {
  367. Config.PostConnectHook = "/invalid/path"
  368. assert.Error(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  369. hookCmd, err := exec.LookPath("true")
  370. assert.NoError(t, err)
  371. Config.PostConnectHook = hookCmd
  372. assert.NoError(t, Config.ExecutePostConnectHook(remoteAddr.String(), ProtocolSFTP))
  373. }
  374. Config.PostConnectHook = ""
  375. }