defenderdb_test.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  // Copyright (C) 2019-2022 Nicola Murino
  //
  // This program is free software: you can redistribute it and/or modify
  // it under the terms of the GNU Affero General Public License as published
  // by the Free Software Foundation, version 3.
  //
  // This program is distributed in the hope that it will be useful,
  // but WITHOUT ANY WARRANTY; without even the implied warranty of
  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  // GNU Affero General Public License for more details.
  //
  // You should have received a copy of the GNU Affero General Public License
  // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package common
  15. import (
  16. "encoding/hex"
  17. "encoding/json"
  18. "os"
  19. "path/filepath"
  20. "testing"
  21. "time"
  22. "github.com/stretchr/testify/assert"
  23. "github.com/drakkan/sftpgo/v2/internal/dataprovider"
  24. "github.com/drakkan/sftpgo/v2/internal/util"
  25. )
  26. func TestBasicDbDefender(t *testing.T) {
  27. if !isDbDefenderSupported() {
  28. t.Skip("this test is not supported with the current database provider")
  29. }
  30. config := &DefenderConfig{
  31. Enabled: true,
  32. BanTime: 10,
  33. BanTimeIncrement: 2,
  34. Threshold: 5,
  35. ScoreInvalid: 2,
  36. ScoreValid: 1,
  37. ScoreNoAuth: 2,
  38. ScoreLimitExceeded: 3,
  39. ObservationTime: 15,
  40. EntriesSoftLimit: 1,
  41. EntriesHardLimit: 10,
  42. SafeListFile: "slFile",
  43. BlockListFile: "blFile",
  44. }
  45. _, err := newDBDefender(config)
  46. assert.Error(t, err)
  47. bl := HostListFile{
  48. IPAddresses: []string{"172.16.1.1", "172.16.1.2"},
  49. CIDRNetworks: []string{"10.8.0.0/24"},
  50. }
  51. sl := HostListFile{
  52. IPAddresses: []string{"172.16.1.3", "172.16.1.4"},
  53. CIDRNetworks: []string{"192.168.8.0/24"},
  54. }
  55. blFile := filepath.Join(os.TempDir(), "bl.json")
  56. slFile := filepath.Join(os.TempDir(), "sl.json")
  57. data, err := json.Marshal(bl)
  58. assert.NoError(t, err)
  59. err = os.WriteFile(blFile, data, os.ModePerm)
  60. assert.NoError(t, err)
  61. data, err = json.Marshal(sl)
  62. assert.NoError(t, err)
  63. err = os.WriteFile(slFile, data, os.ModePerm)
  64. assert.NoError(t, err)
  65. config.BlockListFile = blFile
  66. _, err = newDBDefender(config)
  67. assert.Error(t, err)
  68. config.SafeListFile = slFile
  69. d, err := newDBDefender(config)
  70. assert.NoError(t, err)
  71. defender := d.(*dbDefender)
  72. assert.True(t, defender.IsBanned("172.16.1.1"))
  73. assert.False(t, defender.IsBanned("172.16.1.10"))
  74. assert.False(t, defender.IsBanned("10.8.1.3"))
  75. assert.True(t, defender.IsBanned("10.8.0.4"))
  76. assert.False(t, defender.IsBanned("invalid ip"))
  77. hosts, err := defender.GetHosts()
  78. assert.NoError(t, err)
  79. assert.Len(t, hosts, 0)
  80. _, err = defender.GetHost("10.8.0.3")
  81. assert.Error(t, err)
  82. defender.AddEvent("172.16.1.4", HostEventLoginFailed)
  83. defender.AddEvent("192.168.8.4", HostEventUserNotFound)
  84. defender.AddEvent("172.16.1.3", HostEventLimitExceeded)
  85. hosts, err = defender.GetHosts()
  86. assert.NoError(t, err)
  87. assert.Len(t, hosts, 0)
  88. assert.True(t, defender.getLastCleanup().IsZero())
  89. testIP := "123.45.67.89"
  90. defender.AddEvent(testIP, HostEventLoginFailed)
  91. lastCleanup := defender.getLastCleanup()
  92. assert.False(t, lastCleanup.IsZero())
  93. score, err := defender.GetScore(testIP)
  94. assert.NoError(t, err)
  95. assert.Equal(t, 1, score)
  96. hosts, err = defender.GetHosts()
  97. assert.NoError(t, err)
  98. if assert.Len(t, hosts, 1) {
  99. assert.Equal(t, 1, hosts[0].Score)
  100. assert.True(t, hosts[0].BanTime.IsZero())
  101. assert.Empty(t, hosts[0].GetBanTime())
  102. }
  103. host, err := defender.GetHost(testIP)
  104. assert.NoError(t, err)
  105. assert.Equal(t, 1, host.Score)
  106. assert.Empty(t, host.GetBanTime())
  107. banTime, err := defender.GetBanTime(testIP)
  108. assert.NoError(t, err)
  109. assert.Nil(t, banTime)
  110. defender.AddEvent(testIP, HostEventLimitExceeded)
  111. score, err = defender.GetScore(testIP)
  112. assert.NoError(t, err)
  113. assert.Equal(t, 4, score)
  114. hosts, err = defender.GetHosts()
  115. assert.NoError(t, err)
  116. if assert.Len(t, hosts, 1) {
  117. assert.Equal(t, 4, hosts[0].Score)
  118. assert.True(t, hosts[0].BanTime.IsZero())
  119. assert.Empty(t, hosts[0].GetBanTime())
  120. }
  121. defender.AddEvent(testIP, HostEventNoLoginTried)
  122. defender.AddEvent(testIP, HostEventNoLoginTried)
  123. score, err = defender.GetScore(testIP)
  124. assert.NoError(t, err)
  125. assert.Equal(t, 0, score)
  126. banTime, err = defender.GetBanTime(testIP)
  127. assert.NoError(t, err)
  128. assert.NotNil(t, banTime)
  129. hosts, err = defender.GetHosts()
  130. assert.NoError(t, err)
  131. if assert.Len(t, hosts, 1) {
  132. assert.Equal(t, 0, hosts[0].Score)
  133. assert.False(t, hosts[0].BanTime.IsZero())
  134. assert.NotEmpty(t, hosts[0].GetBanTime())
  135. assert.Equal(t, hex.EncodeToString([]byte(testIP)), hosts[0].GetID())
  136. }
  137. host, err = defender.GetHost(testIP)
  138. assert.NoError(t, err)
  139. assert.Equal(t, 0, host.Score)
  140. assert.NotEmpty(t, host.GetBanTime())
  141. // ban time should increase
  142. assert.True(t, defender.IsBanned(testIP))
  143. newBanTime, err := defender.GetBanTime(testIP)
  144. assert.NoError(t, err)
  145. assert.True(t, newBanTime.After(*banTime))
  146. assert.True(t, defender.DeleteHost(testIP))
  147. assert.False(t, defender.DeleteHost(testIP))
  148. // test cleanup
  149. testIP1 := "123.45.67.90"
  150. testIP2 := "123.45.67.91"
  151. testIP3 := "123.45.67.92"
  152. for i := 0; i < 3; i++ {
  153. defender.AddEvent(testIP, HostEventUserNotFound)
  154. defender.AddEvent(testIP1, HostEventNoLoginTried)
  155. defender.AddEvent(testIP2, HostEventUserNotFound)
  156. }
  157. hosts, err = defender.GetHosts()
  158. assert.NoError(t, err)
  159. assert.Len(t, hosts, 3)
  160. for _, host := range hosts {
  161. assert.Equal(t, 0, host.Score)
  162. assert.False(t, host.BanTime.IsZero())
  163. assert.NotEmpty(t, host.GetBanTime())
  164. }
  165. defender.AddEvent(testIP3, HostEventLoginFailed)
  166. hosts, err = defender.GetHosts()
  167. assert.NoError(t, err)
  168. assert.Len(t, hosts, 4)
  169. // now set a ban time in the past, so the host will be cleanead up
  170. for _, ip := range []string{testIP1, testIP2} {
  171. err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
  172. assert.NoError(t, err)
  173. }
  174. hosts, err = defender.GetHosts()
  175. assert.NoError(t, err)
  176. assert.Len(t, hosts, 4)
  177. for _, host := range hosts {
  178. switch host.IP {
  179. case testIP:
  180. assert.Equal(t, 0, host.Score)
  181. assert.False(t, host.BanTime.IsZero())
  182. assert.NotEmpty(t, host.GetBanTime())
  183. case testIP3:
  184. assert.Equal(t, 1, host.Score)
  185. assert.True(t, host.BanTime.IsZero())
  186. assert.Empty(t, host.GetBanTime())
  187. default:
  188. assert.Equal(t, 6, host.Score)
  189. assert.True(t, host.BanTime.IsZero())
  190. assert.Empty(t, host.GetBanTime())
  191. }
  192. }
  193. host, err = defender.GetHost(testIP)
  194. assert.NoError(t, err)
  195. assert.Equal(t, 0, host.Score)
  196. assert.False(t, host.BanTime.IsZero())
  197. assert.NotEmpty(t, host.GetBanTime())
  198. host, err = defender.GetHost(testIP3)
  199. assert.NoError(t, err)
  200. assert.Equal(t, 1, host.Score)
  201. assert.True(t, host.BanTime.IsZero())
  202. assert.Empty(t, host.GetBanTime())
  203. // set a negative observation time so the from field in the queries will be in the future
  204. // we still should get the banned hosts
  205. defender.config.ObservationTime = -2
  206. assert.Greater(t, defender.getStartObservationTime(), time.Now().UnixMilli())
  207. hosts, err = defender.GetHosts()
  208. assert.NoError(t, err)
  209. if assert.Len(t, hosts, 1) {
  210. assert.Equal(t, testIP, hosts[0].IP)
  211. assert.Equal(t, 0, hosts[0].Score)
  212. assert.False(t, hosts[0].BanTime.IsZero())
  213. assert.NotEmpty(t, hosts[0].GetBanTime())
  214. }
  215. _, err = defender.GetHost(testIP)
  216. assert.NoError(t, err)
  217. // cleanup db
  218. err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
  219. assert.NoError(t, err)
  220. // the banned host must still be there
  221. hosts, err = defender.GetHosts()
  222. assert.NoError(t, err)
  223. if assert.Len(t, hosts, 1) {
  224. assert.Equal(t, testIP, hosts[0].IP)
  225. assert.Equal(t, 0, hosts[0].Score)
  226. assert.False(t, hosts[0].BanTime.IsZero())
  227. assert.NotEmpty(t, hosts[0].GetBanTime())
  228. }
  229. _, err = defender.GetHost(testIP)
  230. assert.NoError(t, err)
  231. err = dataprovider.SetDefenderBanTime(testIP, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
  232. assert.NoError(t, err)
  233. err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
  234. assert.NoError(t, err)
  235. hosts, err = defender.GetHosts()
  236. assert.NoError(t, err)
  237. assert.Len(t, hosts, 0)
  238. err = os.Remove(slFile)
  239. assert.NoError(t, err)
  240. err = os.Remove(blFile)
  241. assert.NoError(t, err)
  242. }
  243. func TestDbDefenderCleanup(t *testing.T) {
  244. if !isDbDefenderSupported() {
  245. t.Skip("this test is not supported with the current database provider")
  246. }
  247. config := &DefenderConfig{
  248. Enabled: true,
  249. BanTime: 10,
  250. BanTimeIncrement: 2,
  251. Threshold: 5,
  252. ScoreInvalid: 2,
  253. ScoreValid: 1,
  254. ScoreLimitExceeded: 3,
  255. ObservationTime: 15,
  256. EntriesSoftLimit: 1,
  257. EntriesHardLimit: 10,
  258. }
  259. d, err := newDBDefender(config)
  260. assert.NoError(t, err)
  261. defender := d.(*dbDefender)
  262. lastCleanup := defender.getLastCleanup()
  263. assert.True(t, lastCleanup.IsZero())
  264. defender.cleanup()
  265. lastCleanup = defender.getLastCleanup()
  266. assert.False(t, lastCleanup.IsZero())
  267. defender.cleanup()
  268. assert.Equal(t, lastCleanup, defender.getLastCleanup())
  269. defender.setLastCleanup(time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4))
  270. time.Sleep(20 * time.Millisecond)
  271. defender.cleanup()
  272. assert.True(t, lastCleanup.Before(defender.getLastCleanup()))
  273. providerConf := dataprovider.GetProviderConfig()
  274. err = dataprovider.Close()
  275. assert.NoError(t, err)
  276. lastCleanup = time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4)
  277. defender.setLastCleanup(lastCleanup)
  278. defender.cleanup()
  279. // cleanup will fail and so last cleanup should be reset to the previous value
  280. assert.Equal(t, lastCleanup, defender.getLastCleanup())
  281. err = dataprovider.Initialize(providerConf, configDir, true)
  282. assert.NoError(t, err)
  283. }
  284. func isDbDefenderSupported() bool {
  285. // SQLite shares the implementation with other SQL-based provider but it makes no sense
  286. // to use it outside test cases
  287. switch dataprovider.GetProviderStatus().Driver {
  288. case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
  289. dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
  290. return true
  291. default:
  292. return false
  293. }
  294. }