defender_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. package common
  2. import (
  3. "crypto/rand"
  4. "encoding/json"
  5. "fmt"
  6. "net"
  7. "os"
  8. "path/filepath"
  9. "runtime"
  10. "testing"
  11. "time"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/stretchr/testify/require"
  14. "github.com/yl2chen/cidranger"
  15. )
  16. func TestBasicDefender(t *testing.T) {
  17. bl := HostListFile{
  18. IPAddresses: []string{"172.16.1.1", "172.16.1.2"},
  19. CIDRNetworks: []string{"10.8.0.0/24"},
  20. }
  21. sl := HostListFile{
  22. IPAddresses: []string{"172.16.1.3", "172.16.1.4"},
  23. CIDRNetworks: []string{"192.168.8.0/24"},
  24. }
  25. blFile := filepath.Join(os.TempDir(), "bl.json")
  26. slFile := filepath.Join(os.TempDir(), "sl.json")
  27. data, err := json.Marshal(bl)
  28. assert.NoError(t, err)
  29. err = os.WriteFile(blFile, data, os.ModePerm)
  30. assert.NoError(t, err)
  31. data, err = json.Marshal(sl)
  32. assert.NoError(t, err)
  33. err = os.WriteFile(slFile, data, os.ModePerm)
  34. assert.NoError(t, err)
  35. config := &DefenderConfig{
  36. Enabled: true,
  37. BanTime: 10,
  38. BanTimeIncrement: 2,
  39. Threshold: 5,
  40. ScoreInvalid: 2,
  41. ScoreValid: 1,
  42. ScoreRateExceeded: 3,
  43. ObservationTime: 15,
  44. EntriesSoftLimit: 1,
  45. EntriesHardLimit: 2,
  46. SafeListFile: "slFile",
  47. BlockListFile: "blFile",
  48. }
  49. _, err = newInMemoryDefender(config)
  50. assert.Error(t, err)
  51. config.BlockListFile = blFile
  52. _, err = newInMemoryDefender(config)
  53. assert.Error(t, err)
  54. config.SafeListFile = slFile
  55. d, err := newInMemoryDefender(config)
  56. assert.NoError(t, err)
  57. defender := d.(*memoryDefender)
  58. assert.True(t, defender.IsBanned("172.16.1.1"))
  59. assert.False(t, defender.IsBanned("172.16.1.10"))
  60. assert.False(t, defender.IsBanned("10.8.2.3"))
  61. assert.True(t, defender.IsBanned("10.8.0.3"))
  62. assert.False(t, defender.IsBanned("invalid ip"))
  63. assert.Equal(t, 0, defender.countBanned())
  64. assert.Equal(t, 0, defender.countHosts())
  65. defender.AddEvent("172.16.1.4", HostEventLoginFailed)
  66. defender.AddEvent("192.168.8.4", HostEventUserNotFound)
  67. defender.AddEvent("172.16.1.3", HostEventRateExceeded)
  68. assert.Equal(t, 0, defender.countHosts())
  69. testIP := "12.34.56.78"
  70. defender.AddEvent(testIP, HostEventLoginFailed)
  71. assert.Equal(t, 1, defender.countHosts())
  72. assert.Equal(t, 0, defender.countBanned())
  73. assert.Equal(t, 1, defender.GetScore(testIP))
  74. assert.Nil(t, defender.GetBanTime(testIP))
  75. defender.AddEvent(testIP, HostEventRateExceeded)
  76. assert.Equal(t, 1, defender.countHosts())
  77. assert.Equal(t, 0, defender.countBanned())
  78. assert.Equal(t, 4, defender.GetScore(testIP))
  79. defender.AddEvent(testIP, HostEventNoLoginTried)
  80. assert.Equal(t, 0, defender.countHosts())
  81. assert.Equal(t, 1, defender.countBanned())
  82. assert.Equal(t, 0, defender.GetScore(testIP))
  83. assert.NotNil(t, defender.GetBanTime(testIP))
  84. // now test cleanup, testIP is already banned
  85. testIP1 := "12.34.56.79"
  86. testIP2 := "12.34.56.80"
  87. testIP3 := "12.34.56.81"
  88. defender.AddEvent(testIP1, HostEventNoLoginTried)
  89. defender.AddEvent(testIP2, HostEventNoLoginTried)
  90. assert.Equal(t, 2, defender.countHosts())
  91. time.Sleep(20 * time.Millisecond)
  92. defender.AddEvent(testIP3, HostEventNoLoginTried)
  93. assert.Equal(t, defender.config.EntriesSoftLimit, defender.countHosts())
  94. // testIP1 and testIP2 should be removed
  95. assert.Equal(t, defender.config.EntriesSoftLimit, defender.countHosts())
  96. assert.Equal(t, 0, defender.GetScore(testIP1))
  97. assert.Equal(t, 0, defender.GetScore(testIP2))
  98. assert.Equal(t, 2, defender.GetScore(testIP3))
  99. defender.AddEvent(testIP3, HostEventNoLoginTried)
  100. defender.AddEvent(testIP3, HostEventNoLoginTried)
  101. // IP3 is now banned
  102. assert.NotNil(t, defender.GetBanTime(testIP3))
  103. assert.Equal(t, 0, defender.countHosts())
  104. time.Sleep(20 * time.Millisecond)
  105. for i := 0; i < 3; i++ {
  106. defender.AddEvent(testIP1, HostEventNoLoginTried)
  107. }
  108. assert.Equal(t, 0, defender.countHosts())
  109. assert.Equal(t, config.EntriesSoftLimit, defender.countBanned())
  110. assert.Nil(t, defender.GetBanTime(testIP))
  111. assert.Nil(t, defender.GetBanTime(testIP3))
  112. assert.NotNil(t, defender.GetBanTime(testIP1))
  113. for i := 0; i < 3; i++ {
  114. defender.AddEvent(testIP, HostEventNoLoginTried)
  115. time.Sleep(10 * time.Millisecond)
  116. defender.AddEvent(testIP3, HostEventNoLoginTried)
  117. }
  118. assert.Equal(t, 0, defender.countHosts())
  119. assert.Equal(t, defender.config.EntriesSoftLimit, defender.countBanned())
  120. banTime := defender.GetBanTime(testIP3)
  121. if assert.NotNil(t, banTime) {
  122. assert.True(t, defender.IsBanned(testIP3))
  123. // ban time should increase
  124. newBanTime := defender.GetBanTime(testIP3)
  125. assert.True(t, newBanTime.After(*banTime))
  126. }
  127. assert.True(t, defender.Unban(testIP3))
  128. assert.False(t, defender.Unban(testIP3))
  129. err = os.Remove(slFile)
  130. assert.NoError(t, err)
  131. err = os.Remove(blFile)
  132. assert.NoError(t, err)
  133. }
  134. func TestLoadHostListFromFile(t *testing.T) {
  135. _, err := loadHostListFromFile(".")
  136. assert.Error(t, err)
  137. hostsFilePath := filepath.Join(os.TempDir(), "hostfile")
  138. content := make([]byte, 1048576*6)
  139. _, err = rand.Read(content)
  140. assert.NoError(t, err)
  141. err = os.WriteFile(hostsFilePath, content, os.ModePerm)
  142. assert.NoError(t, err)
  143. _, err = loadHostListFromFile(hostsFilePath)
  144. assert.Error(t, err)
  145. hl := HostListFile{
  146. IPAddresses: []string{},
  147. CIDRNetworks: []string{},
  148. }
  149. asJSON, err := json.Marshal(hl)
  150. assert.NoError(t, err)
  151. err = os.WriteFile(hostsFilePath, asJSON, os.ModePerm)
  152. assert.NoError(t, err)
  153. hostList, err := loadHostListFromFile(hostsFilePath)
  154. assert.NoError(t, err)
  155. assert.Nil(t, hostList)
  156. hl.IPAddresses = append(hl.IPAddresses, "invalidip")
  157. asJSON, err = json.Marshal(hl)
  158. assert.NoError(t, err)
  159. err = os.WriteFile(hostsFilePath, asJSON, os.ModePerm)
  160. assert.NoError(t, err)
  161. hostList, err = loadHostListFromFile(hostsFilePath)
  162. assert.NoError(t, err)
  163. assert.Len(t, hostList.IPAddresses, 0)
  164. hl.IPAddresses = nil
  165. hl.CIDRNetworks = append(hl.CIDRNetworks, "invalid net")
  166. asJSON, err = json.Marshal(hl)
  167. assert.NoError(t, err)
  168. err = os.WriteFile(hostsFilePath, asJSON, os.ModePerm)
  169. assert.NoError(t, err)
  170. hostList, err = loadHostListFromFile(hostsFilePath)
  171. assert.NoError(t, err)
  172. assert.NotNil(t, hostList)
  173. assert.Len(t, hostList.IPAddresses, 0)
  174. assert.Equal(t, 0, hostList.Ranges.Len())
  175. if runtime.GOOS != "windows" {
  176. err = os.Chmod(hostsFilePath, 0111)
  177. assert.NoError(t, err)
  178. _, err = loadHostListFromFile(hostsFilePath)
  179. assert.Error(t, err)
  180. err = os.Chmod(hostsFilePath, 0644)
  181. assert.NoError(t, err)
  182. }
  183. err = os.WriteFile(hostsFilePath, []byte("non json content"), os.ModePerm)
  184. assert.NoError(t, err)
  185. _, err = loadHostListFromFile(hostsFilePath)
  186. assert.Error(t, err)
  187. err = os.Remove(hostsFilePath)
  188. assert.NoError(t, err)
  189. }
  190. func TestDefenderCleanup(t *testing.T) {
  191. d := memoryDefender{
  192. banned: make(map[string]time.Time),
  193. hosts: make(map[string]hostScore),
  194. config: &DefenderConfig{
  195. ObservationTime: 1,
  196. EntriesSoftLimit: 2,
  197. EntriesHardLimit: 3,
  198. },
  199. }
  200. d.banned["1.1.1.1"] = time.Now().Add(-24 * time.Hour)
  201. d.banned["1.1.1.2"] = time.Now().Add(-24 * time.Hour)
  202. d.banned["1.1.1.3"] = time.Now().Add(-24 * time.Hour)
  203. d.banned["1.1.1.4"] = time.Now().Add(-24 * time.Hour)
  204. d.cleanupBanned()
  205. assert.Equal(t, 0, d.countBanned())
  206. d.banned["2.2.2.2"] = time.Now().Add(2 * time.Minute)
  207. d.banned["2.2.2.3"] = time.Now().Add(1 * time.Minute)
  208. d.banned["2.2.2.4"] = time.Now().Add(3 * time.Minute)
  209. d.banned["2.2.2.5"] = time.Now().Add(4 * time.Minute)
  210. d.cleanupBanned()
  211. assert.Equal(t, d.config.EntriesSoftLimit, d.countBanned())
  212. assert.Nil(t, d.GetBanTime("2.2.2.3"))
  213. d.hosts["3.3.3.3"] = hostScore{
  214. TotalScore: 0,
  215. Events: []hostEvent{
  216. {
  217. dateTime: time.Now().Add(-5 * time.Minute),
  218. score: 1,
  219. },
  220. {
  221. dateTime: time.Now().Add(-3 * time.Minute),
  222. score: 1,
  223. },
  224. {
  225. dateTime: time.Now(),
  226. score: 1,
  227. },
  228. },
  229. }
  230. d.hosts["3.3.3.4"] = hostScore{
  231. TotalScore: 1,
  232. Events: []hostEvent{
  233. {
  234. dateTime: time.Now().Add(-3 * time.Minute),
  235. score: 1,
  236. },
  237. },
  238. }
  239. d.hosts["3.3.3.5"] = hostScore{
  240. TotalScore: 1,
  241. Events: []hostEvent{
  242. {
  243. dateTime: time.Now().Add(-2 * time.Minute),
  244. score: 1,
  245. },
  246. },
  247. }
  248. d.hosts["3.3.3.6"] = hostScore{
  249. TotalScore: 1,
  250. Events: []hostEvent{
  251. {
  252. dateTime: time.Now().Add(-1 * time.Minute),
  253. score: 1,
  254. },
  255. },
  256. }
  257. assert.Equal(t, 1, d.GetScore("3.3.3.3"))
  258. d.cleanupHosts()
  259. assert.Equal(t, d.config.EntriesSoftLimit, d.countHosts())
  260. assert.Equal(t, 0, d.GetScore("3.3.3.4"))
  261. }
  262. func TestDefenderConfig(t *testing.T) {
  263. c := DefenderConfig{}
  264. err := c.validate()
  265. require.NoError(t, err)
  266. c.Enabled = true
  267. c.Threshold = 10
  268. c.ScoreInvalid = 10
  269. err = c.validate()
  270. require.Error(t, err)
  271. c.ScoreInvalid = 2
  272. c.ScoreRateExceeded = 10
  273. err = c.validate()
  274. require.Error(t, err)
  275. c.ScoreRateExceeded = 2
  276. c.ScoreValid = 10
  277. err = c.validate()
  278. require.Error(t, err)
  279. c.ScoreValid = 1
  280. c.BanTime = 0
  281. err = c.validate()
  282. require.Error(t, err)
  283. c.BanTime = 30
  284. c.BanTimeIncrement = 0
  285. err = c.validate()
  286. require.Error(t, err)
  287. c.BanTimeIncrement = 50
  288. c.ObservationTime = 0
  289. err = c.validate()
  290. require.Error(t, err)
  291. c.ObservationTime = 30
  292. err = c.validate()
  293. require.Error(t, err)
  294. c.EntriesSoftLimit = 10
  295. err = c.validate()
  296. require.Error(t, err)
  297. c.EntriesHardLimit = 10
  298. err = c.validate()
  299. require.Error(t, err)
  300. c.EntriesHardLimit = 20
  301. err = c.validate()
  302. require.NoError(t, err)
  303. }
  304. func BenchmarkDefenderBannedSearch(b *testing.B) {
  305. d := getDefenderForBench()
  306. ip, ipnet, err := net.ParseCIDR("10.8.0.0/12") // 1048574 ip addresses
  307. if err != nil {
  308. panic(err)
  309. }
  310. for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) {
  311. d.banned[ip.String()] = time.Now().Add(10 * time.Minute)
  312. }
  313. b.ResetTimer()
  314. for i := 0; i < b.N; i++ {
  315. d.IsBanned("192.168.1.1")
  316. }
  317. }
  318. func BenchmarkCleanup(b *testing.B) {
  319. d := getDefenderForBench()
  320. ip, ipnet, err := net.ParseCIDR("192.168.4.0/24")
  321. if err != nil {
  322. panic(err)
  323. }
  324. b.ResetTimer()
  325. for i := 0; i < b.N; i++ {
  326. for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) {
  327. d.AddEvent(ip.String(), HostEventLoginFailed)
  328. if d.countHosts() > d.config.EntriesHardLimit {
  329. panic("too many hosts")
  330. }
  331. if d.countBanned() > d.config.EntriesSoftLimit {
  332. panic("too many ip banned")
  333. }
  334. }
  335. }
  336. }
  337. func BenchmarkDefenderBannedSearchWithBlockList(b *testing.B) {
  338. d := getDefenderForBench()
  339. d.blockList = &HostList{
  340. IPAddresses: make(map[string]bool),
  341. Ranges: cidranger.NewPCTrieRanger(),
  342. }
  343. ip, ipnet, err := net.ParseCIDR("129.8.0.0/12") // 1048574 ip addresses
  344. if err != nil {
  345. panic(err)
  346. }
  347. for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) {
  348. d.banned[ip.String()] = time.Now().Add(10 * time.Minute)
  349. d.blockList.IPAddresses[ip.String()] = true
  350. }
  351. for i := 0; i < 255; i++ {
  352. cidr := fmt.Sprintf("10.8.%v.1/24", i)
  353. _, network, _ := net.ParseCIDR(cidr)
  354. if err := d.blockList.Ranges.Insert(cidranger.NewBasicRangerEntry(*network)); err != nil {
  355. panic(err)
  356. }
  357. }
  358. b.ResetTimer()
  359. for i := 0; i < b.N; i++ {
  360. d.IsBanned("192.168.1.1")
  361. }
  362. }
  363. func BenchmarkHostListSearch(b *testing.B) {
  364. hostlist := &HostList{
  365. IPAddresses: make(map[string]bool),
  366. Ranges: cidranger.NewPCTrieRanger(),
  367. }
  368. ip, ipnet, _ := net.ParseCIDR("172.16.0.0/16")
  369. for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) {
  370. hostlist.IPAddresses[ip.String()] = true
  371. }
  372. for i := 0; i < 255; i++ {
  373. cidr := fmt.Sprintf("10.8.%v.1/24", i)
  374. _, network, _ := net.ParseCIDR(cidr)
  375. if err := hostlist.Ranges.Insert(cidranger.NewBasicRangerEntry(*network)); err != nil {
  376. panic(err)
  377. }
  378. }
  379. b.ResetTimer()
  380. for i := 0; i < b.N; i++ {
  381. if hostlist.isListed("192.167.1.2") {
  382. panic("should not be listed")
  383. }
  384. }
  385. }
  386. func BenchmarkCIDRanger(b *testing.B) {
  387. ranger := cidranger.NewPCTrieRanger()
  388. for i := 0; i < 255; i++ {
  389. cidr := fmt.Sprintf("192.168.%v.1/24", i)
  390. _, network, _ := net.ParseCIDR(cidr)
  391. if err := ranger.Insert(cidranger.NewBasicRangerEntry(*network)); err != nil {
  392. panic(err)
  393. }
  394. }
  395. ipToMatch := net.ParseIP("192.167.1.2")
  396. b.ResetTimer()
  397. for i := 0; i < b.N; i++ {
  398. if _, err := ranger.Contains(ipToMatch); err != nil {
  399. panic(err)
  400. }
  401. }
  402. }
  403. func BenchmarkNetContains(b *testing.B) {
  404. var nets []*net.IPNet
  405. for i := 0; i < 255; i++ {
  406. cidr := fmt.Sprintf("192.168.%v.1/24", i)
  407. _, network, _ := net.ParseCIDR(cidr)
  408. nets = append(nets, network)
  409. }
  410. ipToMatch := net.ParseIP("192.167.1.1")
  411. b.ResetTimer()
  412. for i := 0; i < b.N; i++ {
  413. for _, n := range nets {
  414. n.Contains(ipToMatch)
  415. }
  416. }
  417. }
  418. func getDefenderForBench() *memoryDefender {
  419. config := &DefenderConfig{
  420. Enabled: true,
  421. BanTime: 30,
  422. BanTimeIncrement: 50,
  423. Threshold: 10,
  424. ScoreInvalid: 2,
  425. ScoreValid: 2,
  426. ObservationTime: 30,
  427. EntriesSoftLimit: 50,
  428. EntriesHardLimit: 100,
  429. }
  430. return &memoryDefender{
  431. config: config,
  432. hosts: make(map[string]hostScore),
  433. banned: make(map[string]time.Time),
  434. }
  435. }
  436. func inc(ip net.IP) {
  437. for j := len(ip) - 1; j >= 0; j-- {
  438. ip[j]++
  439. if ip[j] > 0 {
  440. break
  441. }
  442. }
  443. }