defender.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. package common
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net"
  6. "os"
  7. "sort"
  8. "sync"
  9. "time"
  10. "github.com/yl2chen/cidranger"
  11. "github.com/drakkan/sftpgo/logger"
  12. "github.com/drakkan/sftpgo/utils"
  13. )
// HostEvent is the enumerable for the supported host events.
// Each event carries a score (configured in DefenderConfig) that is added
// to the score of the host that generated it.
type HostEvent int

// Supported host events
const (
	// HostEventLoginFailed is a failed login attempt; the in-memory defender
	// scores it with ScoreValid (the score for attempts against existing accounts).
	HostEventLoginFailed HostEvent = iota
	// HostEventUserNotFound is scored with ScoreInvalid
	// (login attempts for non-existent user accounts).
	HostEventUserNotFound
	// HostEventNoLoginTried is scored with ScoreInvalid (client disconnected
	// for inactivity without authentication attempts).
	HostEventNoLoginTried
	// HostEventRateExceeded is scored with ScoreRateExceeded
	// (events generated from the rate limiters).
	HostEventRateExceeded
)
// Defender defines the interface that a defender must implement.
// A defender tracks events per client IP address and decides whether a
// client is banned. Every method identifies the client by its IP string.
type Defender interface {
	// AddEvent records an event of the given type for ip.
	AddEvent(ip string, event HostEvent)
	// IsBanned reports whether ip is currently banned.
	IsBanned(ip string) bool
	// GetBanTime returns the ban expiration for ip, or nil if ip has no ban.
	GetBanTime(ip string) *time.Time
	// GetScore returns the current score accumulated by ip.
	GetScore(ip string) int
	// Unban removes the ban for ip and reports whether ip was banned.
	Unban(ip string) bool
	// Reload reloads the block and safe lists.
	Reload() error
}
// DefenderConfig defines the "defender" configuration.
// When Enabled is true, validate enforces: every Score* value lower than
// Threshold; BanTime, BanTimeIncrement, ObservationTime and EntriesSoftLimit
// greater than zero; EntriesHardLimit greater than EntriesSoftLimit.
type DefenderConfig struct {
	// Enabled set to true enables the defender
	Enabled bool `json:"enabled" mapstructure:"enabled"`
	// BanTime is the number of minutes that a host is banned
	BanTime int `json:"ban_time" mapstructure:"ban_time"`
	// BanTimeIncrement is the percentage of BanTime added to the ban of a
	// host each time it tries to connect again while still banned. The
	// resulting increment is truncated to whole minutes, with a minimum of
	// one minute.
	BanTimeIncrement int `json:"ban_time_increment" mapstructure:"ban_time_increment"`
	// Threshold is the score that triggers a ban: a host is banned once the
	// scores of its events inside the observation window add up to at least
	// this value
	Threshold int `json:"threshold" mapstructure:"threshold"`
	// ScoreInvalid is the score for invalid login attempts, eg. non-existent
	// user accounts or client disconnected for inactivity without
	// authentication attempts
	ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
	// ScoreValid is the score for valid login attempts, eg. user accounts that exist
	ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
	// ScoreRateExceeded is the score for rate exceeded events, generated from the rate limiters
	ScoreRateExceeded int `json:"score_rate_exceeded" mapstructure:"score_rate_exceeded"`
	// ObservationTime defines the time window, in minutes, for tracking client errors.
	// A host is banned if it has reached the defined threshold during
	// the last ObservationTime minutes
	ObservationTime int `json:"observation_time" mapstructure:"observation_time"`
	// EntriesSoftLimit and EntriesHardLimit bound the number of banned IPs and
	// host scores kept in memory: when either collection grows beyond the hard
	// limit it is pruned back to at most the soft limit
	EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
	EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
	// SafeListFile is the path to a JSON file containing a list of ip addresses
	// and/or networks to never ban
	SafeListFile string `json:"safelist_file" mapstructure:"safelist_file"`
	// BlockListFile is the path to a JSON file containing a list of ip addresses
	// and/or networks to always ban
	BlockListFile string `json:"blocklist_file" mapstructure:"blocklist_file"`
}
// memoryDefender is a Defender that keeps host scores and bans in memory.
// The maps and the safe/block lists are accessed under the embedded RWMutex.
type memoryDefender struct {
	config *DefenderConfig
	sync.RWMutex
	// IP addresses of the clients trying to connect are stored inside hosts;
	// they are moved to banned once the threshold is reached.
	// A connection attempt from a banned host increases the ban time
	// based on the configured BanTimeIncrement
	hosts  map[string]hostScore // the key is the host IP
	banned map[string]time.Time // the key is the host IP, the value is the ban expiration
	// safeList holds hosts for which events are ignored; nil if no list is loaded
	safeList *HostList
	// blockList holds permanently banned hosts; nil if no list is loaded
	blockList *HostList
}
// HostListFile defines the structure expected for safe/block list files.
// The files are JSON documents, for example:
//
//	{"addresses": ["192.0.2.1"], "networks": ["198.51.100.0/24"]}
type HostListFile struct {
	// IPAddresses lists single IP addresses
	IPAddresses []string `json:"addresses"`
	// CIDRNetworks lists networks in CIDR notation
	CIDRNetworks []string `json:"networks"`
}

// HostList defines the structure used to keep the HostListFile in memory
type HostList struct {
	// IPAddresses is a set keyed by the IP address string exactly as it
	// appears in the file
	IPAddresses map[string]bool
	// Ranges holds the parsed CIDR networks
	Ranges cidranger.Ranger
}
  84. func (h *HostList) isListed(ip string) bool {
  85. if _, ok := h.IPAddresses[ip]; ok {
  86. return true
  87. }
  88. ok, err := h.Ranges.Contains(net.ParseIP(ip))
  89. if err != nil {
  90. return false
  91. }
  92. return ok
  93. }
// hostEvent is a single scored event recorded for a host
type hostEvent struct {
	dateTime time.Time // when the event was recorded
	score    int       // score assigned to the event
}

// hostScore tracks the events recorded for a host. Events are appended in
// chronological order and, whenever a new event is added, the ones older
// than the observation window are dropped; TotalScore is the sum of the
// scores of the events kept at that time.
type hostScore struct {
	TotalScore int
	Events     []hostEvent
}
  102. // validate returns an error if the configuration is invalid
  103. func (c *DefenderConfig) validate() error {
  104. if !c.Enabled {
  105. return nil
  106. }
  107. if c.ScoreInvalid >= c.Threshold {
  108. return fmt.Errorf("score_invalid %v cannot be greater than threshold %v", c.ScoreInvalid, c.Threshold)
  109. }
  110. if c.ScoreValid >= c.Threshold {
  111. return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold)
  112. }
  113. if c.ScoreRateExceeded >= c.Threshold {
  114. return fmt.Errorf("score_rate_exceeded %v cannot be greater than threshold %v", c.ScoreRateExceeded, c.Threshold)
  115. }
  116. if c.BanTime <= 0 {
  117. return fmt.Errorf("invalid ban_time %v", c.BanTime)
  118. }
  119. if c.BanTimeIncrement <= 0 {
  120. return fmt.Errorf("invalid ban_time_increment %v", c.BanTimeIncrement)
  121. }
  122. if c.ObservationTime <= 0 {
  123. return fmt.Errorf("invalid observation_time %v", c.ObservationTime)
  124. }
  125. if c.EntriesSoftLimit <= 0 {
  126. return fmt.Errorf("invalid entries_soft_limit %v", c.EntriesSoftLimit)
  127. }
  128. if c.EntriesHardLimit <= c.EntriesSoftLimit {
  129. return fmt.Errorf("invalid entries_hard_limit %v must be > %v", c.EntriesHardLimit, c.EntriesSoftLimit)
  130. }
  131. return nil
  132. }
  133. func newInMemoryDefender(config *DefenderConfig) (Defender, error) {
  134. err := config.validate()
  135. if err != nil {
  136. return nil, err
  137. }
  138. defender := &memoryDefender{
  139. config: config,
  140. hosts: make(map[string]hostScore),
  141. banned: make(map[string]time.Time),
  142. }
  143. if err := defender.Reload(); err != nil {
  144. return nil, err
  145. }
  146. return defender, nil
  147. }
  148. // Reload reloads block and safe lists
  149. func (d *memoryDefender) Reload() error {
  150. blockList, err := loadHostListFromFile(d.config.BlockListFile)
  151. if err != nil {
  152. return err
  153. }
  154. d.Lock()
  155. d.blockList = blockList
  156. d.Unlock()
  157. safeList, err := loadHostListFromFile(d.config.SafeListFile)
  158. if err != nil {
  159. return err
  160. }
  161. d.Lock()
  162. d.safeList = safeList
  163. d.Unlock()
  164. return nil
  165. }
  166. // IsBanned returns true if the specified IP is banned
  167. // and increase ban time if the IP is found.
  168. // This method must be called as soon as the client connects
  169. func (d *memoryDefender) IsBanned(ip string) bool {
  170. d.RLock()
  171. if banTime, ok := d.banned[ip]; ok {
  172. if banTime.After(time.Now()) {
  173. increment := d.config.BanTime * d.config.BanTimeIncrement / 100
  174. if increment == 0 {
  175. increment++
  176. }
  177. d.RUnlock()
  178. // we can save an earlier ban time if there are contemporary updates
  179. // but this should not make much difference. I prefer to hold a read lock
  180. // until possible for performance reasons, this method is called each
  181. // time a new client connects and it must be as fast as possible
  182. d.Lock()
  183. d.banned[ip] = banTime.Add(time.Duration(increment) * time.Minute)
  184. d.Unlock()
  185. return true
  186. }
  187. }
  188. defer d.RUnlock()
  189. if d.blockList != nil && d.blockList.isListed(ip) {
  190. // permanent ban
  191. return true
  192. }
  193. return false
  194. }
  195. // Unban removes the specified IP address from the banned ones
  196. func (d *memoryDefender) Unban(ip string) bool {
  197. d.Lock()
  198. defer d.Unlock()
  199. if _, ok := d.banned[ip]; ok {
  200. delete(d.banned, ip)
  201. return true
  202. }
  203. return false
  204. }
  205. // AddEvent adds an event for the given IP.
  206. // This method must be called for clients not yet banned
  207. func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
  208. d.Lock()
  209. defer d.Unlock()
  210. if d.safeList != nil && d.safeList.isListed(ip) {
  211. return
  212. }
  213. var score int
  214. switch event {
  215. case HostEventLoginFailed:
  216. score = d.config.ScoreValid
  217. case HostEventRateExceeded:
  218. score = d.config.ScoreRateExceeded
  219. case HostEventUserNotFound, HostEventNoLoginTried:
  220. score = d.config.ScoreInvalid
  221. }
  222. ev := hostEvent{
  223. dateTime: time.Now(),
  224. score: score,
  225. }
  226. if hs, ok := d.hosts[ip]; ok {
  227. hs.Events = append(hs.Events, ev)
  228. hs.TotalScore = 0
  229. idx := 0
  230. for _, event := range hs.Events {
  231. if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
  232. hs.Events[idx] = event
  233. hs.TotalScore += event.score
  234. idx++
  235. }
  236. }
  237. hs.Events = hs.Events[:idx]
  238. if hs.TotalScore >= d.config.Threshold {
  239. d.banned[ip] = time.Now().Add(time.Duration(d.config.BanTime) * time.Minute)
  240. delete(d.hosts, ip)
  241. d.cleanupBanned()
  242. } else {
  243. d.hosts[ip] = hs
  244. }
  245. } else {
  246. d.hosts[ip] = hostScore{
  247. TotalScore: ev.score,
  248. Events: []hostEvent{ev},
  249. }
  250. d.cleanupHosts()
  251. }
  252. }
  253. func (d *memoryDefender) countBanned() int {
  254. d.RLock()
  255. defer d.RUnlock()
  256. return len(d.banned)
  257. }
  258. func (d *memoryDefender) countHosts() int {
  259. d.RLock()
  260. defer d.RUnlock()
  261. return len(d.hosts)
  262. }
  263. // GetBanTime returns the ban time for the given IP or nil if the IP is not banned
  264. func (d *memoryDefender) GetBanTime(ip string) *time.Time {
  265. d.RLock()
  266. defer d.RUnlock()
  267. if banTime, ok := d.banned[ip]; ok {
  268. return &banTime
  269. }
  270. return nil
  271. }
  272. // GetScore returns the score for the given IP
  273. func (d *memoryDefender) GetScore(ip string) int {
  274. d.RLock()
  275. defer d.RUnlock()
  276. score := 0
  277. if hs, ok := d.hosts[ip]; ok {
  278. for _, event := range hs.Events {
  279. if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
  280. score += event.score
  281. }
  282. }
  283. }
  284. return score
  285. }
  286. func (d *memoryDefender) cleanupBanned() {
  287. if len(d.banned) > d.config.EntriesHardLimit {
  288. kvList := make(kvList, 0, len(d.banned))
  289. for k, v := range d.banned {
  290. if v.Before(time.Now()) {
  291. delete(d.banned, k)
  292. }
  293. kvList = append(kvList, kv{
  294. Key: k,
  295. Value: v.UnixNano(),
  296. })
  297. }
  298. // we removed expired ip addresses, if any, above, this could be enough
  299. numToRemove := len(d.banned) - d.config.EntriesSoftLimit
  300. if numToRemove <= 0 {
  301. return
  302. }
  303. sort.Sort(kvList)
  304. for idx, kv := range kvList {
  305. if idx >= numToRemove {
  306. break
  307. }
  308. delete(d.banned, kv.Key)
  309. }
  310. }
  311. }
  312. func (d *memoryDefender) cleanupHosts() {
  313. if len(d.hosts) > d.config.EntriesHardLimit {
  314. kvList := make(kvList, 0, len(d.hosts))
  315. for k, v := range d.hosts {
  316. value := int64(0)
  317. if len(v.Events) > 0 {
  318. value = v.Events[len(v.Events)-1].dateTime.UnixNano()
  319. }
  320. kvList = append(kvList, kv{
  321. Key: k,
  322. Value: value,
  323. })
  324. }
  325. sort.Sort(kvList)
  326. numToRemove := len(d.hosts) - d.config.EntriesSoftLimit
  327. for idx, kv := range kvList {
  328. if idx >= numToRemove {
  329. break
  330. }
  331. delete(d.hosts, kv.Key)
  332. }
  333. }
  334. }
  335. func loadHostListFromFile(name string) (*HostList, error) {
  336. if name == "" {
  337. return nil, nil
  338. }
  339. if !utils.IsFileInputValid(name) {
  340. return nil, fmt.Errorf("invalid host list file name %#v", name)
  341. }
  342. info, err := os.Stat(name)
  343. if err != nil {
  344. return nil, err
  345. }
  346. // opinionated max size, you should avoid big host lists
  347. if info.Size() > 1048576*5 { // 5MB
  348. return nil, fmt.Errorf("host list file %#v is too big: %v bytes", name, info.Size())
  349. }
  350. content, err := os.ReadFile(name)
  351. if err != nil {
  352. return nil, fmt.Errorf("unable to read input file %#v: %v", name, err)
  353. }
  354. var hostList HostListFile
  355. err = json.Unmarshal(content, &hostList)
  356. if err != nil {
  357. return nil, err
  358. }
  359. if len(hostList.CIDRNetworks) > 0 || len(hostList.IPAddresses) > 0 {
  360. result := &HostList{
  361. IPAddresses: make(map[string]bool),
  362. Ranges: cidranger.NewPCTrieRanger(),
  363. }
  364. ipCount := 0
  365. cdrCount := 0
  366. for _, ip := range hostList.IPAddresses {
  367. if net.ParseIP(ip) == nil {
  368. logger.Warn(logSender, "", "unable to parse IP %#v", ip)
  369. continue
  370. }
  371. result.IPAddresses[ip] = true
  372. ipCount++
  373. }
  374. for _, cidrNet := range hostList.CIDRNetworks {
  375. _, network, err := net.ParseCIDR(cidrNet)
  376. if err != nil {
  377. logger.Warn(logSender, "", "unable to parse CIDR network %#v", cidrNet)
  378. continue
  379. }
  380. err = result.Ranges.Insert(cidranger.NewBasicRangerEntry(*network))
  381. if err == nil {
  382. cdrCount++
  383. }
  384. }
  385. logger.Info(logSender, "", "list %#v loaded, ip addresses loaded: %v/%v networks loaded: %v/%v",
  386. name, ipCount, len(hostList.IPAddresses), cdrCount, len(hostList.CIDRNetworks))
  387. return result, nil
  388. }
  389. return nil, nil
  390. }
  391. type kv struct {
  392. Key string
  393. Value int64
  394. }
  395. type kvList []kv
  396. func (p kvList) Len() int { return len(p) }
  397. func (p kvList) Less(i, j int) bool { return p[i].Value < p[j].Value }
  398. func (p kvList) Swap(i, j int) { p[i], p[j] = p[j], p[i] }