defenderdb_test.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. package common
  2. import (
  3. "encoding/hex"
  4. "encoding/json"
  5. "os"
  6. "path/filepath"
  7. "testing"
  8. "time"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/drakkan/sftpgo/v2/dataprovider"
  11. "github.com/drakkan/sftpgo/v2/util"
  12. )
  13. func TestBasicDbDefender(t *testing.T) {
  14. if !isDbDefenderSupported() {
  15. t.Skip("this test is not supported with the current database provider")
  16. }
  17. config := &DefenderConfig{
  18. Enabled: true,
  19. BanTime: 10,
  20. BanTimeIncrement: 2,
  21. Threshold: 5,
  22. ScoreInvalid: 2,
  23. ScoreValid: 1,
  24. ScoreLimitExceeded: 3,
  25. ObservationTime: 15,
  26. EntriesSoftLimit: 1,
  27. EntriesHardLimit: 10,
  28. SafeListFile: "slFile",
  29. BlockListFile: "blFile",
  30. }
  31. _, err := newDBDefender(config)
  32. assert.Error(t, err)
  33. bl := HostListFile{
  34. IPAddresses: []string{"172.16.1.1", "172.16.1.2"},
  35. CIDRNetworks: []string{"10.8.0.0/24"},
  36. }
  37. sl := HostListFile{
  38. IPAddresses: []string{"172.16.1.3", "172.16.1.4"},
  39. CIDRNetworks: []string{"192.168.8.0/24"},
  40. }
  41. blFile := filepath.Join(os.TempDir(), "bl.json")
  42. slFile := filepath.Join(os.TempDir(), "sl.json")
  43. data, err := json.Marshal(bl)
  44. assert.NoError(t, err)
  45. err = os.WriteFile(blFile, data, os.ModePerm)
  46. assert.NoError(t, err)
  47. data, err = json.Marshal(sl)
  48. assert.NoError(t, err)
  49. err = os.WriteFile(slFile, data, os.ModePerm)
  50. assert.NoError(t, err)
  51. config.BlockListFile = blFile
  52. _, err = newDBDefender(config)
  53. assert.Error(t, err)
  54. config.SafeListFile = slFile
  55. d, err := newDBDefender(config)
  56. assert.NoError(t, err)
  57. defender := d.(*dbDefender)
  58. assert.True(t, defender.IsBanned("172.16.1.1"))
  59. assert.False(t, defender.IsBanned("172.16.1.10"))
  60. assert.False(t, defender.IsBanned("10.8.1.3"))
  61. assert.True(t, defender.IsBanned("10.8.0.4"))
  62. assert.False(t, defender.IsBanned("invalid ip"))
  63. hosts, err := defender.GetHosts()
  64. assert.NoError(t, err)
  65. assert.Len(t, hosts, 0)
  66. _, err = defender.GetHost("10.8.0.3")
  67. assert.Error(t, err)
  68. defender.AddEvent("172.16.1.4", HostEventLoginFailed)
  69. defender.AddEvent("192.168.8.4", HostEventUserNotFound)
  70. defender.AddEvent("172.16.1.3", HostEventLimitExceeded)
  71. hosts, err = defender.GetHosts()
  72. assert.NoError(t, err)
  73. assert.Len(t, hosts, 0)
  74. assert.True(t, defender.getLastCleanup().IsZero())
  75. testIP := "123.45.67.89"
  76. defender.AddEvent(testIP, HostEventLoginFailed)
  77. lastCleanup := defender.getLastCleanup()
  78. assert.False(t, lastCleanup.IsZero())
  79. score, err := defender.GetScore(testIP)
  80. assert.NoError(t, err)
  81. assert.Equal(t, 1, score)
  82. hosts, err = defender.GetHosts()
  83. assert.NoError(t, err)
  84. if assert.Len(t, hosts, 1) {
  85. assert.Equal(t, 1, hosts[0].Score)
  86. assert.True(t, hosts[0].BanTime.IsZero())
  87. assert.Empty(t, hosts[0].GetBanTime())
  88. }
  89. host, err := defender.GetHost(testIP)
  90. assert.NoError(t, err)
  91. assert.Equal(t, 1, host.Score)
  92. assert.Empty(t, host.GetBanTime())
  93. banTime, err := defender.GetBanTime(testIP)
  94. assert.NoError(t, err)
  95. assert.Nil(t, banTime)
  96. defender.AddEvent(testIP, HostEventLimitExceeded)
  97. score, err = defender.GetScore(testIP)
  98. assert.NoError(t, err)
  99. assert.Equal(t, 4, score)
  100. hosts, err = defender.GetHosts()
  101. assert.NoError(t, err)
  102. if assert.Len(t, hosts, 1) {
  103. assert.Equal(t, 4, hosts[0].Score)
  104. assert.True(t, hosts[0].BanTime.IsZero())
  105. assert.Empty(t, hosts[0].GetBanTime())
  106. }
  107. defender.AddEvent(testIP, HostEventNoLoginTried)
  108. defender.AddEvent(testIP, HostEventNoLoginTried)
  109. score, err = defender.GetScore(testIP)
  110. assert.NoError(t, err)
  111. assert.Equal(t, 0, score)
  112. banTime, err = defender.GetBanTime(testIP)
  113. assert.NoError(t, err)
  114. assert.NotNil(t, banTime)
  115. hosts, err = defender.GetHosts()
  116. assert.NoError(t, err)
  117. if assert.Len(t, hosts, 1) {
  118. assert.Equal(t, 0, hosts[0].Score)
  119. assert.False(t, hosts[0].BanTime.IsZero())
  120. assert.NotEmpty(t, hosts[0].GetBanTime())
  121. assert.Equal(t, hex.EncodeToString([]byte(testIP)), hosts[0].GetID())
  122. }
  123. host, err = defender.GetHost(testIP)
  124. assert.NoError(t, err)
  125. assert.Equal(t, 0, host.Score)
  126. assert.NotEmpty(t, host.GetBanTime())
  127. // ban time should increase
  128. assert.True(t, defender.IsBanned(testIP))
  129. newBanTime, err := defender.GetBanTime(testIP)
  130. assert.NoError(t, err)
  131. assert.True(t, newBanTime.After(*banTime))
  132. assert.True(t, defender.DeleteHost(testIP))
  133. assert.False(t, defender.DeleteHost(testIP))
  134. // test cleanup
  135. testIP1 := "123.45.67.90"
  136. testIP2 := "123.45.67.91"
  137. testIP3 := "123.45.67.92"
  138. for i := 0; i < 3; i++ {
  139. defender.AddEvent(testIP, HostEventNoLoginTried)
  140. defender.AddEvent(testIP1, HostEventNoLoginTried)
  141. defender.AddEvent(testIP2, HostEventNoLoginTried)
  142. }
  143. hosts, err = defender.GetHosts()
  144. assert.NoError(t, err)
  145. assert.Len(t, hosts, 3)
  146. for _, host := range hosts {
  147. assert.Equal(t, 0, host.Score)
  148. assert.False(t, host.BanTime.IsZero())
  149. assert.NotEmpty(t, host.GetBanTime())
  150. }
  151. defender.AddEvent(testIP3, HostEventLoginFailed)
  152. hosts, err = defender.GetHosts()
  153. assert.NoError(t, err)
  154. assert.Len(t, hosts, 4)
  155. // now set a ban time in the past, so the host will be cleanead up
  156. for _, ip := range []string{testIP1, testIP2} {
  157. err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
  158. assert.NoError(t, err)
  159. }
  160. hosts, err = defender.GetHosts()
  161. assert.NoError(t, err)
  162. assert.Len(t, hosts, 4)
  163. for _, host := range hosts {
  164. switch host.IP {
  165. case testIP:
  166. assert.Equal(t, 0, host.Score)
  167. assert.False(t, host.BanTime.IsZero())
  168. assert.NotEmpty(t, host.GetBanTime())
  169. case testIP3:
  170. assert.Equal(t, 1, host.Score)
  171. assert.True(t, host.BanTime.IsZero())
  172. assert.Empty(t, host.GetBanTime())
  173. default:
  174. assert.Equal(t, 6, host.Score)
  175. assert.True(t, host.BanTime.IsZero())
  176. assert.Empty(t, host.GetBanTime())
  177. }
  178. }
  179. host, err = defender.GetHost(testIP)
  180. assert.NoError(t, err)
  181. assert.Equal(t, 0, host.Score)
  182. assert.False(t, host.BanTime.IsZero())
  183. assert.NotEmpty(t, host.GetBanTime())
  184. host, err = defender.GetHost(testIP3)
  185. assert.NoError(t, err)
  186. assert.Equal(t, 1, host.Score)
  187. assert.True(t, host.BanTime.IsZero())
  188. assert.Empty(t, host.GetBanTime())
  189. // cleanup db
  190. err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
  191. assert.NoError(t, err)
  192. hosts, err = defender.GetHosts()
  193. assert.NoError(t, err)
  194. if assert.Len(t, hosts, 1) {
  195. assert.Equal(t, testIP, hosts[0].IP)
  196. assert.Equal(t, 0, hosts[0].Score)
  197. assert.False(t, hosts[0].BanTime.IsZero())
  198. assert.NotEmpty(t, hosts[0].GetBanTime())
  199. }
  200. err = dataprovider.SetDefenderBanTime(testIP, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
  201. assert.NoError(t, err)
  202. err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
  203. assert.NoError(t, err)
  204. hosts, err = defender.GetHosts()
  205. assert.NoError(t, err)
  206. assert.Len(t, hosts, 0)
  207. err = os.Remove(slFile)
  208. assert.NoError(t, err)
  209. err = os.Remove(blFile)
  210. assert.NoError(t, err)
  211. }
  212. func TestDbDefenderCleanup(t *testing.T) {
  213. if !isDbDefenderSupported() {
  214. t.Skip("this test is not supported with the current database provider")
  215. }
  216. config := &DefenderConfig{
  217. Enabled: true,
  218. BanTime: 10,
  219. BanTimeIncrement: 2,
  220. Threshold: 5,
  221. ScoreInvalid: 2,
  222. ScoreValid: 1,
  223. ScoreLimitExceeded: 3,
  224. ObservationTime: 15,
  225. EntriesSoftLimit: 1,
  226. EntriesHardLimit: 10,
  227. }
  228. d, err := newDBDefender(config)
  229. assert.NoError(t, err)
  230. defender := d.(*dbDefender)
  231. lastCleanup := defender.getLastCleanup()
  232. assert.True(t, lastCleanup.IsZero())
  233. defender.cleanup()
  234. lastCleanup = defender.getLastCleanup()
  235. assert.False(t, lastCleanup.IsZero())
  236. defender.cleanup()
  237. assert.Equal(t, lastCleanup, defender.getLastCleanup())
  238. defender.setLastCleanup(time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4))
  239. time.Sleep(20 * time.Millisecond)
  240. defender.cleanup()
  241. assert.True(t, lastCleanup.Before(defender.getLastCleanup()))
  242. providerConf := dataprovider.GetProviderConfig()
  243. err = dataprovider.Close()
  244. assert.NoError(t, err)
  245. lastCleanup = time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4)
  246. defender.setLastCleanup(lastCleanup)
  247. defender.cleanup()
  248. // cleanup will fail and so last cleanup should be reset to the previous value
  249. assert.Equal(t, lastCleanup, defender.getLastCleanup())
  250. err = dataprovider.Initialize(providerConf, configDir, true)
  251. assert.NoError(t, err)
  252. }
  253. func isDbDefenderSupported() bool {
  254. // SQLite shares the implementation with other SQL-based provider but it makes no sense
  255. // to use it outside test cases
  256. switch dataprovider.GetProviderStatus().Driver {
  257. case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
  258. dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
  259. return true
  260. default:
  261. return false
  262. }
  263. }