defender.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. package common
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "io/ioutil"
  6. "net"
  7. "os"
  8. "sort"
  9. "sync"
  10. "time"
  11. "github.com/yl2chen/cidranger"
  12. "github.com/drakkan/sftpgo/logger"
  13. "github.com/drakkan/sftpgo/utils"
  14. )
  15. // HostEvent is the enumerable for the support host event
  16. type HostEvent int
  17. // Supported host events
  18. const (
  19. HostEventLoginFailed HostEvent = iota
  20. HostEventUserNotFound
  21. HostEventNoLoginTried
  22. )
  23. // Defender defines the interface that a defender must implements
  24. type Defender interface {
  25. AddEvent(ip string, event HostEvent)
  26. IsBanned(ip string) bool
  27. GetBanTime(ip string) *time.Time
  28. GetScore(ip string) int
  29. Unban(ip string) bool
  30. Reload() error
  31. }
  32. // DefenderConfig defines the "defender" configuration
  33. type DefenderConfig struct {
  34. // Set to true to enable the defender
  35. Enabled bool `json:"enabled" mapstructure:"enabled"`
  36. // BanTime is the number of minutes that a host is banned
  37. BanTime int `json:"ban_time" mapstructure:"ban_time"`
  38. // Percentage increase of the ban time if a banned host tries to connect again
  39. BanTimeIncrement int `json:"ban_time_increment" mapstructure:"ban_time_increment"`
  40. // Threshold value for banning a client
  41. Threshold int `json:"threshold" mapstructure:"threshold"`
  42. // Score for invalid login attempts, eg. non-existent user accounts or
  43. // client disconnected for inactivity without authentication attempts
  44. ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
  45. // Score for valid login attempts, eg. user accounts that exist
  46. ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
  47. // Defines the time window, in minutes, for tracking client errors.
  48. // A host is banned if it has exceeded the defined threshold during
  49. // the last observation time minutes
  50. ObservationTime int `json:"observation_time" mapstructure:"observation_time"`
  51. // The number of banned IPs and host scores kept in memory will vary between the
  52. // soft and hard limit
  53. EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
  54. EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
  55. // Path to a file containing a list of ip addresses and/or networks to never ban
  56. SafeListFile string `json:"safelist_file" mapstructure:"safelist_file"`
  57. // Path to a file containing a list of ip addresses and/or networks to always ban
  58. BlockListFile string `json:"blocklist_file" mapstructure:"blocklist_file"`
  59. }
  60. type memoryDefender struct {
  61. config *DefenderConfig
  62. sync.RWMutex
  63. // IP addresses of the clients trying to connected are stored inside hosts,
  64. // they are added to banned once the thresold is reached.
  65. // A violation from a banned host will increase the ban time
  66. // based on the configured BanTimeIncrement
  67. hosts map[string]hostScore // the key is the host IP
  68. banned map[string]time.Time // the key is the host IP
  69. safeList *HostList
  70. blockList *HostList
  71. }
  72. // HostListFile defines the structure expected for safe/block list files
  73. type HostListFile struct {
  74. IPAddresses []string `json:"addresses"`
  75. CIDRNetworks []string `json:"networks"`
  76. }
  77. // HostList defines the structure used to keep the HostListFile in memory
  78. type HostList struct {
  79. IPAddresses map[string]bool
  80. Ranges cidranger.Ranger
  81. }
  82. func (h *HostList) isListed(ip string) bool {
  83. if _, ok := h.IPAddresses[ip]; ok {
  84. return true
  85. }
  86. ok, err := h.Ranges.Contains(net.ParseIP(ip))
  87. if err != nil {
  88. return false
  89. }
  90. return ok
  91. }
  92. type hostEvent struct {
  93. dateTime time.Time
  94. score int
  95. }
  96. type hostScore struct {
  97. TotalScore int
  98. Events []hostEvent
  99. }
  100. // validate returns an error if the configuration is invalid
  101. func (c *DefenderConfig) validate() error {
  102. if !c.Enabled {
  103. return nil
  104. }
  105. if c.ScoreInvalid >= c.Threshold {
  106. return fmt.Errorf("score_invalid %v cannot be greater than threshold %v", c.ScoreInvalid, c.Threshold)
  107. }
  108. if c.ScoreValid >= c.Threshold {
  109. return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold)
  110. }
  111. if c.BanTime <= 0 {
  112. return fmt.Errorf("invalid ban_time %v", c.BanTime)
  113. }
  114. if c.BanTimeIncrement <= 0 {
  115. return fmt.Errorf("invalid ban_time_increment %v", c.BanTimeIncrement)
  116. }
  117. if c.ObservationTime <= 0 {
  118. return fmt.Errorf("invalid observation_time %v", c.ObservationTime)
  119. }
  120. if c.EntriesSoftLimit <= 0 {
  121. return fmt.Errorf("invalid entries_soft_limit %v", c.EntriesSoftLimit)
  122. }
  123. if c.EntriesHardLimit <= c.EntriesSoftLimit {
  124. return fmt.Errorf("invalid entries_hard_limit %v must be > %v", c.EntriesHardLimit, c.EntriesSoftLimit)
  125. }
  126. return nil
  127. }
  128. func newInMemoryDefender(config *DefenderConfig) (Defender, error) {
  129. err := config.validate()
  130. if err != nil {
  131. return nil, err
  132. }
  133. defender := &memoryDefender{
  134. config: config,
  135. hosts: make(map[string]hostScore),
  136. banned: make(map[string]time.Time),
  137. }
  138. if err := defender.Reload(); err != nil {
  139. return nil, err
  140. }
  141. return defender, nil
  142. }
  143. // Reload reloads block and safe lists
  144. func (d *memoryDefender) Reload() error {
  145. blockList, err := loadHostListFromFile(d.config.BlockListFile)
  146. if err != nil {
  147. return err
  148. }
  149. d.Lock()
  150. d.blockList = blockList
  151. d.Unlock()
  152. safeList, err := loadHostListFromFile(d.config.SafeListFile)
  153. if err != nil {
  154. return err
  155. }
  156. d.Lock()
  157. d.safeList = safeList
  158. d.Unlock()
  159. return nil
  160. }
  161. // IsBanned returns true if the specified IP is banned
  162. // and increase ban time if the IP is found.
  163. // This method must be called as soon as the client connects
  164. func (d *memoryDefender) IsBanned(ip string) bool {
  165. d.RLock()
  166. if banTime, ok := d.banned[ip]; ok {
  167. if banTime.After(time.Now()) {
  168. increment := d.config.BanTime * d.config.BanTimeIncrement / 100
  169. if increment == 0 {
  170. increment++
  171. }
  172. d.RUnlock()
  173. // we can save an earlier ban time if there are contemporary updates
  174. // but this should not make much difference. I prefer to hold a read lock
  175. // until possible for performance reasons, this method is called each
  176. // time a new client connects and it must be as fast as possible
  177. d.Lock()
  178. d.banned[ip] = banTime.Add(time.Duration(increment) * time.Minute)
  179. d.Unlock()
  180. return true
  181. }
  182. }
  183. defer d.RUnlock()
  184. if d.blockList != nil && d.blockList.isListed(ip) {
  185. // permanent ban
  186. return true
  187. }
  188. return false
  189. }
  190. // Unban removes the specified IP address from the banned ones
  191. func (d *memoryDefender) Unban(ip string) bool {
  192. d.Lock()
  193. defer d.Unlock()
  194. if _, ok := d.banned[ip]; ok {
  195. delete(d.banned, ip)
  196. return true
  197. }
  198. return false
  199. }
  200. // AddEvent adds an event for the given IP.
  201. // This method must be called for clients not yet banned
  202. func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
  203. d.Lock()
  204. defer d.Unlock()
  205. if d.safeList != nil && d.safeList.isListed(ip) {
  206. return
  207. }
  208. var score int
  209. switch event {
  210. case HostEventLoginFailed:
  211. score = d.config.ScoreValid
  212. case HostEventUserNotFound, HostEventNoLoginTried:
  213. score = d.config.ScoreInvalid
  214. }
  215. ev := hostEvent{
  216. dateTime: time.Now(),
  217. score: score,
  218. }
  219. if hs, ok := d.hosts[ip]; ok {
  220. hs.Events = append(hs.Events, ev)
  221. hs.TotalScore = 0
  222. idx := 0
  223. for _, event := range hs.Events {
  224. if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
  225. hs.Events[idx] = event
  226. hs.TotalScore += event.score
  227. idx++
  228. }
  229. }
  230. hs.Events = hs.Events[:idx]
  231. if hs.TotalScore >= d.config.Threshold {
  232. d.banned[ip] = time.Now().Add(time.Duration(d.config.BanTime) * time.Minute)
  233. delete(d.hosts, ip)
  234. d.cleanupBanned()
  235. } else {
  236. d.hosts[ip] = hs
  237. }
  238. } else {
  239. d.hosts[ip] = hostScore{
  240. TotalScore: ev.score,
  241. Events: []hostEvent{ev},
  242. }
  243. d.cleanupHosts()
  244. }
  245. }
  246. func (d *memoryDefender) countBanned() int {
  247. d.RLock()
  248. defer d.RUnlock()
  249. return len(d.banned)
  250. }
  251. func (d *memoryDefender) countHosts() int {
  252. d.RLock()
  253. defer d.RUnlock()
  254. return len(d.hosts)
  255. }
  256. // GetBanTime returns the ban time for the given IP or nil if the IP is not banned
  257. func (d *memoryDefender) GetBanTime(ip string) *time.Time {
  258. d.RLock()
  259. defer d.RUnlock()
  260. if banTime, ok := d.banned[ip]; ok {
  261. return &banTime
  262. }
  263. return nil
  264. }
  265. // GetScore returns the score for the given IP
  266. func (d *memoryDefender) GetScore(ip string) int {
  267. d.RLock()
  268. defer d.RUnlock()
  269. score := 0
  270. if hs, ok := d.hosts[ip]; ok {
  271. for _, event := range hs.Events {
  272. if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
  273. score += event.score
  274. }
  275. }
  276. }
  277. return score
  278. }
  279. func (d *memoryDefender) cleanupBanned() {
  280. if len(d.banned) > d.config.EntriesHardLimit {
  281. kvList := make(kvList, 0, len(d.banned))
  282. for k, v := range d.banned {
  283. if v.Before(time.Now()) {
  284. delete(d.banned, k)
  285. }
  286. kvList = append(kvList, kv{
  287. Key: k,
  288. Value: v.UnixNano(),
  289. })
  290. }
  291. // we removed expired ip addresses, if any, above, this could be enough
  292. numToRemove := len(d.banned) - d.config.EntriesSoftLimit
  293. if numToRemove <= 0 {
  294. return
  295. }
  296. sort.Sort(kvList)
  297. for idx, kv := range kvList {
  298. if idx >= numToRemove {
  299. break
  300. }
  301. delete(d.banned, kv.Key)
  302. }
  303. }
  304. }
  305. func (d *memoryDefender) cleanupHosts() {
  306. if len(d.hosts) > d.config.EntriesHardLimit {
  307. kvList := make(kvList, 0, len(d.hosts))
  308. for k, v := range d.hosts {
  309. value := int64(0)
  310. if len(v.Events) > 0 {
  311. value = v.Events[len(v.Events)-1].dateTime.UnixNano()
  312. }
  313. kvList = append(kvList, kv{
  314. Key: k,
  315. Value: value,
  316. })
  317. }
  318. sort.Sort(kvList)
  319. numToRemove := len(d.hosts) - d.config.EntriesSoftLimit
  320. for idx, kv := range kvList {
  321. if idx >= numToRemove {
  322. break
  323. }
  324. delete(d.hosts, kv.Key)
  325. }
  326. }
  327. }
  328. func loadHostListFromFile(name string) (*HostList, error) {
  329. if name == "" {
  330. return nil, nil
  331. }
  332. if !utils.IsFileInputValid(name) {
  333. return nil, fmt.Errorf("invalid host list file name %#v", name)
  334. }
  335. info, err := os.Stat(name)
  336. if err != nil {
  337. return nil, err
  338. }
  339. // opinionated max size, you should avoid big host lists
  340. if info.Size() > 1048576*5 { // 5MB
  341. return nil, fmt.Errorf("host list file %#v is too big: %v bytes", name, info.Size())
  342. }
  343. content, err := ioutil.ReadFile(name)
  344. if err != nil {
  345. return nil, fmt.Errorf("unable to read input file %#v: %v", name, err)
  346. }
  347. var hostList HostListFile
  348. err = json.Unmarshal(content, &hostList)
  349. if err != nil {
  350. return nil, err
  351. }
  352. if len(hostList.CIDRNetworks) > 0 || len(hostList.IPAddresses) > 0 {
  353. result := &HostList{
  354. IPAddresses: make(map[string]bool),
  355. Ranges: cidranger.NewPCTrieRanger(),
  356. }
  357. ipCount := 0
  358. cdrCount := 0
  359. for _, ip := range hostList.IPAddresses {
  360. if net.ParseIP(ip) == nil {
  361. logger.Warn(logSender, "", "unable to parse IP %#v", ip)
  362. continue
  363. }
  364. result.IPAddresses[ip] = true
  365. ipCount++
  366. }
  367. for _, cidrNet := range hostList.CIDRNetworks {
  368. _, network, err := net.ParseCIDR(cidrNet)
  369. if err != nil {
  370. logger.Warn(logSender, "", "unable to parse CIDR network %#v", cidrNet)
  371. continue
  372. }
  373. err = result.Ranges.Insert(cidranger.NewBasicRangerEntry(*network))
  374. if err == nil {
  375. cdrCount++
  376. }
  377. }
  378. logger.Info(logSender, "", "list %#v loaded, ip addresses loaded: %v/%v networks loaded: %v/%v",
  379. name, ipCount, len(hostList.IPAddresses), cdrCount, len(hostList.CIDRNetworks))
  380. return result, nil
  381. }
  382. return nil, nil
  383. }
  384. type kv struct {
  385. Key string
  386. Value int64
  387. }
  388. type kvList []kv
  389. func (p kvList) Len() int { return len(p) }
  390. func (p kvList) Less(i, j int) bool { return p[i].Value < p[j].Value }
  391. func (p kvList) Swap(i, j int) { p[i], p[j] = p[j], p[i] }