// defender_test.go (12 KB)
  1. package common
  2. import (
  3. "crypto/rand"
  4. "encoding/json"
  5. "fmt"
  6. "net"
  7. "os"
  8. "path/filepath"
  9. "runtime"
  10. "testing"
  11. "time"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/stretchr/testify/require"
  14. "github.com/yl2chen/cidranger"
  15. )
  16. func TestBasicDefender(t *testing.T) {
  17. bl := HostListFile{
  18. IPAddresses: []string{"172.16.1.1", "172.16.1.2"},
  19. CIDRNetworks: []string{"10.8.0.0/24"},
  20. }
  21. sl := HostListFile{
  22. IPAddresses: []string{"172.16.1.3", "172.16.1.4"},
  23. CIDRNetworks: []string{"192.168.8.0/24"},
  24. }
  25. blFile := filepath.Join(os.TempDir(), "bl.json")
  26. slFile := filepath.Join(os.TempDir(), "sl.json")
  27. data, err := json.Marshal(bl)
  28. assert.NoError(t, err)
  29. err = os.WriteFile(blFile, data, os.ModePerm)
  30. assert.NoError(t, err)
  31. data, err = json.Marshal(sl)
  32. assert.NoError(t, err)
  33. err = os.WriteFile(slFile, data, os.ModePerm)
  34. assert.NoError(t, err)
  35. config := &DefenderConfig{
  36. Enabled: true,
  37. BanTime: 10,
  38. BanTimeIncrement: 2,
  39. Threshold: 5,
  40. ScoreInvalid: 2,
  41. ScoreValid: 1,
  42. ObservationTime: 15,
  43. EntriesSoftLimit: 1,
  44. EntriesHardLimit: 2,
  45. SafeListFile: "slFile",
  46. BlockListFile: "blFile",
  47. }
  48. _, err = newInMemoryDefender(config)
  49. assert.Error(t, err)
  50. config.BlockListFile = blFile
  51. _, err = newInMemoryDefender(config)
  52. assert.Error(t, err)
  53. config.SafeListFile = slFile
  54. d, err := newInMemoryDefender(config)
  55. assert.NoError(t, err)
  56. defender := d.(*memoryDefender)
  57. assert.True(t, defender.IsBanned("172.16.1.1"))
  58. assert.False(t, defender.IsBanned("172.16.1.10"))
  59. assert.False(t, defender.IsBanned("10.8.2.3"))
  60. assert.True(t, defender.IsBanned("10.8.0.3"))
  61. assert.False(t, defender.IsBanned("invalid ip"))
  62. assert.Equal(t, 0, defender.countBanned())
  63. assert.Equal(t, 0, defender.countHosts())
  64. defender.AddEvent("172.16.1.4", HostEventLoginFailed)
  65. defender.AddEvent("192.168.8.4", HostEventUserNotFound)
  66. assert.Equal(t, 0, defender.countHosts())
  67. testIP := "12.34.56.78"
  68. defender.AddEvent(testIP, HostEventLoginFailed)
  69. assert.Equal(t, 1, defender.countHosts())
  70. assert.Equal(t, 0, defender.countBanned())
  71. assert.Equal(t, 1, defender.GetScore(testIP))
  72. assert.Nil(t, defender.GetBanTime(testIP))
  73. defender.AddEvent(testIP, HostEventNoLoginTried)
  74. assert.Equal(t, 1, defender.countHosts())
  75. assert.Equal(t, 0, defender.countBanned())
  76. assert.Equal(t, 3, defender.GetScore(testIP))
  77. defender.AddEvent(testIP, HostEventNoLoginTried)
  78. assert.Equal(t, 0, defender.countHosts())
  79. assert.Equal(t, 1, defender.countBanned())
  80. assert.Equal(t, 0, defender.GetScore(testIP))
  81. assert.NotNil(t, defender.GetBanTime(testIP))
  82. // now test cleanup, testIP is already banned
  83. testIP1 := "12.34.56.79"
  84. testIP2 := "12.34.56.80"
  85. testIP3 := "12.34.56.81"
  86. defender.AddEvent(testIP1, HostEventNoLoginTried)
  87. defender.AddEvent(testIP2, HostEventNoLoginTried)
  88. assert.Equal(t, 2, defender.countHosts())
  89. time.Sleep(20 * time.Millisecond)
  90. defender.AddEvent(testIP3, HostEventNoLoginTried)
  91. assert.Equal(t, defender.config.EntriesSoftLimit, defender.countHosts())
  92. // testIP1 and testIP2 should be removed
  93. assert.Equal(t, defender.config.EntriesSoftLimit, defender.countHosts())
  94. assert.Equal(t, 0, defender.GetScore(testIP1))
  95. assert.Equal(t, 0, defender.GetScore(testIP2))
  96. assert.Equal(t, 2, defender.GetScore(testIP3))
  97. defender.AddEvent(testIP3, HostEventNoLoginTried)
  98. defender.AddEvent(testIP3, HostEventNoLoginTried)
  99. // IP3 is now banned
  100. assert.NotNil(t, defender.GetBanTime(testIP3))
  101. assert.Equal(t, 0, defender.countHosts())
  102. time.Sleep(20 * time.Millisecond)
  103. for i := 0; i < 3; i++ {
  104. defender.AddEvent(testIP1, HostEventNoLoginTried)
  105. }
  106. assert.Equal(t, 0, defender.countHosts())
  107. assert.Equal(t, config.EntriesSoftLimit, defender.countBanned())
  108. assert.Nil(t, defender.GetBanTime(testIP))
  109. assert.Nil(t, defender.GetBanTime(testIP3))
  110. assert.NotNil(t, defender.GetBanTime(testIP1))
  111. for i := 0; i < 3; i++ {
  112. defender.AddEvent(testIP, HostEventNoLoginTried)
  113. time.Sleep(10 * time.Millisecond)
  114. defender.AddEvent(testIP3, HostEventNoLoginTried)
  115. }
  116. assert.Equal(t, 0, defender.countHosts())
  117. assert.Equal(t, defender.config.EntriesSoftLimit, defender.countBanned())
  118. banTime := defender.GetBanTime(testIP3)
  119. if assert.NotNil(t, banTime) {
  120. assert.True(t, defender.IsBanned(testIP3))
  121. // ban time should increase
  122. newBanTime := defender.GetBanTime(testIP3)
  123. assert.True(t, newBanTime.After(*banTime))
  124. }
  125. assert.True(t, defender.Unban(testIP3))
  126. assert.False(t, defender.Unban(testIP3))
  127. err = os.Remove(slFile)
  128. assert.NoError(t, err)
  129. err = os.Remove(blFile)
  130. assert.NoError(t, err)
  131. }
  132. func TestLoadHostListFromFile(t *testing.T) {
  133. _, err := loadHostListFromFile(".")
  134. assert.Error(t, err)
  135. hostsFilePath := filepath.Join(os.TempDir(), "hostfile")
  136. content := make([]byte, 1048576*6)
  137. _, err = rand.Read(content)
  138. assert.NoError(t, err)
  139. err = os.WriteFile(hostsFilePath, content, os.ModePerm)
  140. assert.NoError(t, err)
  141. _, err = loadHostListFromFile(hostsFilePath)
  142. assert.Error(t, err)
  143. hl := HostListFile{
  144. IPAddresses: []string{},
  145. CIDRNetworks: []string{},
  146. }
  147. asJSON, err := json.Marshal(hl)
  148. assert.NoError(t, err)
  149. err = os.WriteFile(hostsFilePath, asJSON, os.ModePerm)
  150. assert.NoError(t, err)
  151. hostList, err := loadHostListFromFile(hostsFilePath)
  152. assert.NoError(t, err)
  153. assert.Nil(t, hostList)
  154. hl.IPAddresses = append(hl.IPAddresses, "invalidip")
  155. asJSON, err = json.Marshal(hl)
  156. assert.NoError(t, err)
  157. err = os.WriteFile(hostsFilePath, asJSON, os.ModePerm)
  158. assert.NoError(t, err)
  159. hostList, err = loadHostListFromFile(hostsFilePath)
  160. assert.NoError(t, err)
  161. assert.Len(t, hostList.IPAddresses, 0)
  162. hl.IPAddresses = nil
  163. hl.CIDRNetworks = append(hl.CIDRNetworks, "invalid net")
  164. asJSON, err = json.Marshal(hl)
  165. assert.NoError(t, err)
  166. err = os.WriteFile(hostsFilePath, asJSON, os.ModePerm)
  167. assert.NoError(t, err)
  168. hostList, err = loadHostListFromFile(hostsFilePath)
  169. assert.NoError(t, err)
  170. assert.NotNil(t, hostList)
  171. assert.Len(t, hostList.IPAddresses, 0)
  172. assert.Equal(t, 0, hostList.Ranges.Len())
  173. if runtime.GOOS != "windows" {
  174. err = os.Chmod(hostsFilePath, 0111)
  175. assert.NoError(t, err)
  176. _, err = loadHostListFromFile(hostsFilePath)
  177. assert.Error(t, err)
  178. err = os.Chmod(hostsFilePath, 0644)
  179. assert.NoError(t, err)
  180. }
  181. err = os.WriteFile(hostsFilePath, []byte("non json content"), os.ModePerm)
  182. assert.NoError(t, err)
  183. _, err = loadHostListFromFile(hostsFilePath)
  184. assert.Error(t, err)
  185. err = os.Remove(hostsFilePath)
  186. assert.NoError(t, err)
  187. }
  188. func TestDefenderCleanup(t *testing.T) {
  189. d := memoryDefender{
  190. banned: make(map[string]time.Time),
  191. hosts: make(map[string]hostScore),
  192. config: &DefenderConfig{
  193. ObservationTime: 1,
  194. EntriesSoftLimit: 2,
  195. EntriesHardLimit: 3,
  196. },
  197. }
  198. d.banned["1.1.1.1"] = time.Now().Add(-24 * time.Hour)
  199. d.banned["1.1.1.2"] = time.Now().Add(-24 * time.Hour)
  200. d.banned["1.1.1.3"] = time.Now().Add(-24 * time.Hour)
  201. d.banned["1.1.1.4"] = time.Now().Add(-24 * time.Hour)
  202. d.cleanupBanned()
  203. assert.Equal(t, 0, d.countBanned())
  204. d.banned["2.2.2.2"] = time.Now().Add(2 * time.Minute)
  205. d.banned["2.2.2.3"] = time.Now().Add(1 * time.Minute)
  206. d.banned["2.2.2.4"] = time.Now().Add(3 * time.Minute)
  207. d.banned["2.2.2.5"] = time.Now().Add(4 * time.Minute)
  208. d.cleanupBanned()
  209. assert.Equal(t, d.config.EntriesSoftLimit, d.countBanned())
  210. assert.Nil(t, d.GetBanTime("2.2.2.3"))
  211. d.hosts["3.3.3.3"] = hostScore{
  212. TotalScore: 0,
  213. Events: []hostEvent{
  214. {
  215. dateTime: time.Now().Add(-5 * time.Minute),
  216. score: 1,
  217. },
  218. {
  219. dateTime: time.Now().Add(-3 * time.Minute),
  220. score: 1,
  221. },
  222. {
  223. dateTime: time.Now(),
  224. score: 1,
  225. },
  226. },
  227. }
  228. d.hosts["3.3.3.4"] = hostScore{
  229. TotalScore: 1,
  230. Events: []hostEvent{
  231. {
  232. dateTime: time.Now().Add(-3 * time.Minute),
  233. score: 1,
  234. },
  235. },
  236. }
  237. d.hosts["3.3.3.5"] = hostScore{
  238. TotalScore: 1,
  239. Events: []hostEvent{
  240. {
  241. dateTime: time.Now().Add(-2 * time.Minute),
  242. score: 1,
  243. },
  244. },
  245. }
  246. d.hosts["3.3.3.6"] = hostScore{
  247. TotalScore: 1,
  248. Events: []hostEvent{
  249. {
  250. dateTime: time.Now().Add(-1 * time.Minute),
  251. score: 1,
  252. },
  253. },
  254. }
  255. assert.Equal(t, 1, d.GetScore("3.3.3.3"))
  256. d.cleanupHosts()
  257. assert.Equal(t, d.config.EntriesSoftLimit, d.countHosts())
  258. assert.Equal(t, 0, d.GetScore("3.3.3.4"))
  259. }
  260. func TestDefenderConfig(t *testing.T) {
  261. c := DefenderConfig{}
  262. err := c.validate()
  263. require.NoError(t, err)
  264. c.Enabled = true
  265. c.Threshold = 10
  266. c.ScoreInvalid = 10
  267. err = c.validate()
  268. require.Error(t, err)
  269. c.ScoreInvalid = 2
  270. c.ScoreValid = 10
  271. err = c.validate()
  272. require.Error(t, err)
  273. c.ScoreValid = 1
  274. c.BanTime = 0
  275. err = c.validate()
  276. require.Error(t, err)
  277. c.BanTime = 30
  278. c.BanTimeIncrement = 0
  279. err = c.validate()
  280. require.Error(t, err)
  281. c.BanTimeIncrement = 50
  282. c.ObservationTime = 0
  283. err = c.validate()
  284. require.Error(t, err)
  285. c.ObservationTime = 30
  286. err = c.validate()
  287. require.Error(t, err)
  288. c.EntriesSoftLimit = 10
  289. err = c.validate()
  290. require.Error(t, err)
  291. c.EntriesHardLimit = 10
  292. err = c.validate()
  293. require.Error(t, err)
  294. c.EntriesHardLimit = 20
  295. err = c.validate()
  296. require.NoError(t, err)
  297. }
  298. func BenchmarkDefenderBannedSearch(b *testing.B) {
  299. d := getDefenderForBench()
  300. ip, ipnet, err := net.ParseCIDR("10.8.0.0/12") // 1048574 ip addresses
  301. if err != nil {
  302. panic(err)
  303. }
  304. for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) {
  305. d.banned[ip.String()] = time.Now().Add(10 * time.Minute)
  306. }
  307. b.ResetTimer()
  308. for i := 0; i < b.N; i++ {
  309. d.IsBanned("192.168.1.1")
  310. }
  311. }
  312. func BenchmarkCleanup(b *testing.B) {
  313. d := getDefenderForBench()
  314. ip, ipnet, err := net.ParseCIDR("192.168.4.0/24")
  315. if err != nil {
  316. panic(err)
  317. }
  318. b.ResetTimer()
  319. for i := 0; i < b.N; i++ {
  320. for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) {
  321. d.AddEvent(ip.String(), HostEventLoginFailed)
  322. if d.countHosts() > d.config.EntriesHardLimit {
  323. panic("too many hosts")
  324. }
  325. if d.countBanned() > d.config.EntriesSoftLimit {
  326. panic("too many ip banned")
  327. }
  328. }
  329. }
  330. }
  331. func BenchmarkDefenderBannedSearchWithBlockList(b *testing.B) {
  332. d := getDefenderForBench()
  333. d.blockList = &HostList{
  334. IPAddresses: make(map[string]bool),
  335. Ranges: cidranger.NewPCTrieRanger(),
  336. }
  337. ip, ipnet, err := net.ParseCIDR("129.8.0.0/12") // 1048574 ip addresses
  338. if err != nil {
  339. panic(err)
  340. }
  341. for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) {
  342. d.banned[ip.String()] = time.Now().Add(10 * time.Minute)
  343. d.blockList.IPAddresses[ip.String()] = true
  344. }
  345. for i := 0; i < 255; i++ {
  346. cidr := fmt.Sprintf("10.8.%v.1/24", i)
  347. _, network, _ := net.ParseCIDR(cidr)
  348. if err := d.blockList.Ranges.Insert(cidranger.NewBasicRangerEntry(*network)); err != nil {
  349. panic(err)
  350. }
  351. }
  352. b.ResetTimer()
  353. for i := 0; i < b.N; i++ {
  354. d.IsBanned("192.168.1.1")
  355. }
  356. }
  357. func BenchmarkHostListSearch(b *testing.B) {
  358. hostlist := &HostList{
  359. IPAddresses: make(map[string]bool),
  360. Ranges: cidranger.NewPCTrieRanger(),
  361. }
  362. ip, ipnet, _ := net.ParseCIDR("172.16.0.0/16")
  363. for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) {
  364. hostlist.IPAddresses[ip.String()] = true
  365. }
  366. for i := 0; i < 255; i++ {
  367. cidr := fmt.Sprintf("10.8.%v.1/24", i)
  368. _, network, _ := net.ParseCIDR(cidr)
  369. if err := hostlist.Ranges.Insert(cidranger.NewBasicRangerEntry(*network)); err != nil {
  370. panic(err)
  371. }
  372. }
  373. b.ResetTimer()
  374. for i := 0; i < b.N; i++ {
  375. if hostlist.isListed("192.167.1.2") {
  376. panic("should not be listed")
  377. }
  378. }
  379. }
  380. func BenchmarkCIDRanger(b *testing.B) {
  381. ranger := cidranger.NewPCTrieRanger()
  382. for i := 0; i < 255; i++ {
  383. cidr := fmt.Sprintf("192.168.%v.1/24", i)
  384. _, network, _ := net.ParseCIDR(cidr)
  385. if err := ranger.Insert(cidranger.NewBasicRangerEntry(*network)); err != nil {
  386. panic(err)
  387. }
  388. }
  389. ipToMatch := net.ParseIP("192.167.1.2")
  390. b.ResetTimer()
  391. for i := 0; i < b.N; i++ {
  392. if _, err := ranger.Contains(ipToMatch); err != nil {
  393. panic(err)
  394. }
  395. }
  396. }
  397. func BenchmarkNetContains(b *testing.B) {
  398. var nets []*net.IPNet
  399. for i := 0; i < 255; i++ {
  400. cidr := fmt.Sprintf("192.168.%v.1/24", i)
  401. _, network, _ := net.ParseCIDR(cidr)
  402. nets = append(nets, network)
  403. }
  404. ipToMatch := net.ParseIP("192.167.1.1")
  405. b.ResetTimer()
  406. for i := 0; i < b.N; i++ {
  407. for _, n := range nets {
  408. n.Contains(ipToMatch)
  409. }
  410. }
  411. }
  412. func getDefenderForBench() *memoryDefender {
  413. config := &DefenderConfig{
  414. Enabled: true,
  415. BanTime: 30,
  416. BanTimeIncrement: 50,
  417. Threshold: 10,
  418. ScoreInvalid: 2,
  419. ScoreValid: 2,
  420. ObservationTime: 30,
  421. EntriesSoftLimit: 50,
  422. EntriesHardLimit: 100,
  423. }
  424. return &memoryDefender{
  425. config: config,
  426. hosts: make(map[string]hostScore),
  427. banned: make(map[string]time.Time),
  428. }
  429. }
  430. func inc(ip net.IP) {
  431. for j := len(ip) - 1; j >= 0; j-- {
  432. ip[j]++
  433. if ip[j] > 0 {
  434. break
  435. }
  436. }
  437. }