1
0

defenderdb_test.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. package common
  2. import (
  3. "encoding/hex"
  4. "encoding/json"
  5. "os"
  6. "path/filepath"
  7. "testing"
  8. "time"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/drakkan/sftpgo/v2/dataprovider"
  11. "github.com/drakkan/sftpgo/v2/util"
  12. )
  13. func TestBasicDbDefender(t *testing.T) {
  14. if !isDbDefenderSupported() {
  15. t.Skip("this test is not supported with the current database provider")
  16. }
  17. config := &DefenderConfig{
  18. Enabled: true,
  19. BanTime: 10,
  20. BanTimeIncrement: 2,
  21. Threshold: 5,
  22. ScoreInvalid: 2,
  23. ScoreValid: 1,
  24. ScoreLimitExceeded: 3,
  25. ObservationTime: 15,
  26. EntriesSoftLimit: 1,
  27. EntriesHardLimit: 10,
  28. SafeListFile: "slFile",
  29. BlockListFile: "blFile",
  30. }
  31. _, err := newDBDefender(config)
  32. assert.Error(t, err)
  33. bl := HostListFile{
  34. IPAddresses: []string{"172.16.1.1", "172.16.1.2"},
  35. CIDRNetworks: []string{"10.8.0.0/24"},
  36. }
  37. sl := HostListFile{
  38. IPAddresses: []string{"172.16.1.3", "172.16.1.4"},
  39. CIDRNetworks: []string{"192.168.8.0/24"},
  40. }
  41. blFile := filepath.Join(os.TempDir(), "bl.json")
  42. slFile := filepath.Join(os.TempDir(), "sl.json")
  43. data, err := json.Marshal(bl)
  44. assert.NoError(t, err)
  45. err = os.WriteFile(blFile, data, os.ModePerm)
  46. assert.NoError(t, err)
  47. data, err = json.Marshal(sl)
  48. assert.NoError(t, err)
  49. err = os.WriteFile(slFile, data, os.ModePerm)
  50. assert.NoError(t, err)
  51. config.BlockListFile = blFile
  52. _, err = newDBDefender(config)
  53. assert.Error(t, err)
  54. config.SafeListFile = slFile
  55. d, err := newDBDefender(config)
  56. assert.NoError(t, err)
  57. defender := d.(*dbDefender)
  58. assert.True(t, defender.IsBanned("172.16.1.1"))
  59. assert.False(t, defender.IsBanned("172.16.1.10"))
  60. assert.False(t, defender.IsBanned("10.8.1.3"))
  61. assert.True(t, defender.IsBanned("10.8.0.4"))
  62. assert.False(t, defender.IsBanned("invalid ip"))
  63. hosts, err := defender.GetHosts()
  64. assert.NoError(t, err)
  65. assert.Len(t, hosts, 0)
  66. _, err = defender.GetHost("10.8.0.3")
  67. assert.Error(t, err)
  68. defender.AddEvent("172.16.1.4", HostEventLoginFailed)
  69. defender.AddEvent("192.168.8.4", HostEventUserNotFound)
  70. defender.AddEvent("172.16.1.3", HostEventLimitExceeded)
  71. hosts, err = defender.GetHosts()
  72. assert.NoError(t, err)
  73. assert.Len(t, hosts, 0)
  74. assert.True(t, defender.getLastCleanup().IsZero())
  75. testIP := "123.45.67.89"
  76. defender.AddEvent(testIP, HostEventLoginFailed)
  77. lastCleanup := defender.getLastCleanup()
  78. assert.False(t, lastCleanup.IsZero())
  79. score, err := defender.GetScore(testIP)
  80. assert.NoError(t, err)
  81. assert.Equal(t, 1, score)
  82. hosts, err = defender.GetHosts()
  83. assert.NoError(t, err)
  84. if assert.Len(t, hosts, 1) {
  85. assert.Equal(t, 1, hosts[0].Score)
  86. assert.True(t, hosts[0].BanTime.IsZero())
  87. assert.Empty(t, hosts[0].GetBanTime())
  88. }
  89. host, err := defender.GetHost(testIP)
  90. assert.NoError(t, err)
  91. assert.Equal(t, 1, host.Score)
  92. assert.Empty(t, host.GetBanTime())
  93. banTime, err := defender.GetBanTime(testIP)
  94. assert.NoError(t, err)
  95. assert.Nil(t, banTime)
  96. defender.AddEvent(testIP, HostEventLimitExceeded)
  97. score, err = defender.GetScore(testIP)
  98. assert.NoError(t, err)
  99. assert.Equal(t, 4, score)
  100. hosts, err = defender.GetHosts()
  101. assert.NoError(t, err)
  102. if assert.Len(t, hosts, 1) {
  103. assert.Equal(t, 4, hosts[0].Score)
  104. assert.True(t, hosts[0].BanTime.IsZero())
  105. assert.Empty(t, hosts[0].GetBanTime())
  106. }
  107. defender.AddEvent(testIP, HostEventNoLoginTried)
  108. defender.AddEvent(testIP, HostEventNoLoginTried)
  109. score, err = defender.GetScore(testIP)
  110. assert.NoError(t, err)
  111. assert.Equal(t, 0, score)
  112. banTime, err = defender.GetBanTime(testIP)
  113. assert.NoError(t, err)
  114. assert.NotNil(t, banTime)
  115. hosts, err = defender.GetHosts()
  116. assert.NoError(t, err)
  117. if assert.Len(t, hosts, 1) {
  118. assert.Equal(t, 0, hosts[0].Score)
  119. assert.False(t, hosts[0].BanTime.IsZero())
  120. assert.NotEmpty(t, hosts[0].GetBanTime())
  121. assert.Equal(t, hex.EncodeToString([]byte(testIP)), hosts[0].GetID())
  122. }
  123. host, err = defender.GetHost(testIP)
  124. assert.NoError(t, err)
  125. assert.Equal(t, 0, host.Score)
  126. assert.NotEmpty(t, host.GetBanTime())
  127. // ban time should increase
  128. assert.True(t, defender.IsBanned(testIP))
  129. newBanTime, err := defender.GetBanTime(testIP)
  130. assert.NoError(t, err)
  131. assert.True(t, newBanTime.After(*banTime))
  132. assert.True(t, defender.DeleteHost(testIP))
  133. assert.False(t, defender.DeleteHost(testIP))
  134. // test cleanup
  135. testIP1 := "123.45.67.90"
  136. testIP2 := "123.45.67.91"
  137. testIP3 := "123.45.67.92"
  138. for i := 0; i < 3; i++ {
  139. defender.AddEvent(testIP, HostEventNoLoginTried)
  140. defender.AddEvent(testIP1, HostEventNoLoginTried)
  141. defender.AddEvent(testIP2, HostEventNoLoginTried)
  142. }
  143. hosts, err = defender.GetHosts()
  144. assert.NoError(t, err)
  145. assert.Len(t, hosts, 3)
  146. for _, host := range hosts {
  147. assert.Equal(t, 0, host.Score)
  148. assert.False(t, host.BanTime.IsZero())
  149. assert.NotEmpty(t, host.GetBanTime())
  150. }
  151. defender.AddEvent(testIP3, HostEventLoginFailed)
  152. hosts, err = defender.GetHosts()
  153. assert.NoError(t, err)
  154. assert.Len(t, hosts, 4)
  155. // now set a ban time in the past, so the host will be cleanead up
  156. for _, ip := range []string{testIP1, testIP2} {
  157. err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
  158. assert.NoError(t, err)
  159. }
  160. hosts, err = defender.GetHosts()
  161. assert.NoError(t, err)
  162. assert.Len(t, hosts, 4)
  163. for _, host := range hosts {
  164. switch host.IP {
  165. case testIP:
  166. assert.Equal(t, 0, host.Score)
  167. assert.False(t, host.BanTime.IsZero())
  168. assert.NotEmpty(t, host.GetBanTime())
  169. case testIP3:
  170. assert.Equal(t, 1, host.Score)
  171. assert.True(t, host.BanTime.IsZero())
  172. assert.Empty(t, host.GetBanTime())
  173. default:
  174. assert.Equal(t, 6, host.Score)
  175. assert.True(t, host.BanTime.IsZero())
  176. assert.Empty(t, host.GetBanTime())
  177. }
  178. }
  179. host, err = defender.GetHost(testIP)
  180. assert.NoError(t, err)
  181. assert.Equal(t, 0, host.Score)
  182. assert.False(t, host.BanTime.IsZero())
  183. assert.NotEmpty(t, host.GetBanTime())
  184. host, err = defender.GetHost(testIP3)
  185. assert.NoError(t, err)
  186. assert.Equal(t, 1, host.Score)
  187. assert.True(t, host.BanTime.IsZero())
  188. assert.Empty(t, host.GetBanTime())
  189. // set a negative observation time so the from field in the queries will be in the future
  190. // we still should get the banned hosts
  191. defender.config.ObservationTime = -2
  192. assert.Greater(t, defender.getStartObservationTime(), time.Now().UnixMilli())
  193. hosts, err = defender.GetHosts()
  194. assert.NoError(t, err)
  195. if assert.Len(t, hosts, 1) {
  196. assert.Equal(t, testIP, hosts[0].IP)
  197. assert.Equal(t, 0, hosts[0].Score)
  198. assert.False(t, hosts[0].BanTime.IsZero())
  199. assert.NotEmpty(t, hosts[0].GetBanTime())
  200. }
  201. _, err = defender.GetHost(testIP)
  202. assert.NoError(t, err)
  203. // cleanup db
  204. err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
  205. assert.NoError(t, err)
  206. // the banned host must still be there
  207. hosts, err = defender.GetHosts()
  208. assert.NoError(t, err)
  209. if assert.Len(t, hosts, 1) {
  210. assert.Equal(t, testIP, hosts[0].IP)
  211. assert.Equal(t, 0, hosts[0].Score)
  212. assert.False(t, hosts[0].BanTime.IsZero())
  213. assert.NotEmpty(t, hosts[0].GetBanTime())
  214. }
  215. _, err = defender.GetHost(testIP)
  216. assert.NoError(t, err)
  217. err = dataprovider.SetDefenderBanTime(testIP, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
  218. assert.NoError(t, err)
  219. err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
  220. assert.NoError(t, err)
  221. hosts, err = defender.GetHosts()
  222. assert.NoError(t, err)
  223. assert.Len(t, hosts, 0)
  224. err = os.Remove(slFile)
  225. assert.NoError(t, err)
  226. err = os.Remove(blFile)
  227. assert.NoError(t, err)
  228. }
  229. func TestDbDefenderCleanup(t *testing.T) {
  230. if !isDbDefenderSupported() {
  231. t.Skip("this test is not supported with the current database provider")
  232. }
  233. config := &DefenderConfig{
  234. Enabled: true,
  235. BanTime: 10,
  236. BanTimeIncrement: 2,
  237. Threshold: 5,
  238. ScoreInvalid: 2,
  239. ScoreValid: 1,
  240. ScoreLimitExceeded: 3,
  241. ObservationTime: 15,
  242. EntriesSoftLimit: 1,
  243. EntriesHardLimit: 10,
  244. }
  245. d, err := newDBDefender(config)
  246. assert.NoError(t, err)
  247. defender := d.(*dbDefender)
  248. lastCleanup := defender.getLastCleanup()
  249. assert.True(t, lastCleanup.IsZero())
  250. defender.cleanup()
  251. lastCleanup = defender.getLastCleanup()
  252. assert.False(t, lastCleanup.IsZero())
  253. defender.cleanup()
  254. assert.Equal(t, lastCleanup, defender.getLastCleanup())
  255. defender.setLastCleanup(time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4))
  256. time.Sleep(20 * time.Millisecond)
  257. defender.cleanup()
  258. assert.True(t, lastCleanup.Before(defender.getLastCleanup()))
  259. providerConf := dataprovider.GetProviderConfig()
  260. err = dataprovider.Close()
  261. assert.NoError(t, err)
  262. lastCleanup = time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4)
  263. defender.setLastCleanup(lastCleanup)
  264. defender.cleanup()
  265. // cleanup will fail and so last cleanup should be reset to the previous value
  266. assert.Equal(t, lastCleanup, defender.getLastCleanup())
  267. err = dataprovider.Initialize(providerConf, configDir, true)
  268. assert.NoError(t, err)
  269. }
  270. func isDbDefenderSupported() bool {
  271. // SQLite shares the implementation with other SQL-based provider but it makes no sense
  272. // to use it outside test cases
  273. switch dataprovider.GetProviderStatus().Driver {
  274. case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
  275. dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
  276. return true
  277. default:
  278. return false
  279. }
  280. }