defender.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. package common
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net"
  6. "os"
  7. "sort"
  8. "sync"
  9. "time"
  10. "github.com/yl2chen/cidranger"
  11. "github.com/drakkan/sftpgo/logger"
  12. "github.com/drakkan/sftpgo/utils"
  13. )
  14. // HostEvent is the enumerable for the support host event
  15. type HostEvent int
  16. // Supported host events
  17. const (
  18. HostEventLoginFailed HostEvent = iota
  19. HostEventUserNotFound
  20. HostEventNoLoginTried
  21. )
  22. // Defender defines the interface that a defender must implements
  23. type Defender interface {
  24. AddEvent(ip string, event HostEvent)
  25. IsBanned(ip string) bool
  26. GetBanTime(ip string) *time.Time
  27. GetScore(ip string) int
  28. Unban(ip string) bool
  29. Reload() error
  30. }
  31. // DefenderConfig defines the "defender" configuration
  32. type DefenderConfig struct {
  33. // Set to true to enable the defender
  34. Enabled bool `json:"enabled" mapstructure:"enabled"`
  35. // BanTime is the number of minutes that a host is banned
  36. BanTime int `json:"ban_time" mapstructure:"ban_time"`
  37. // Percentage increase of the ban time if a banned host tries to connect again
  38. BanTimeIncrement int `json:"ban_time_increment" mapstructure:"ban_time_increment"`
  39. // Threshold value for banning a client
  40. Threshold int `json:"threshold" mapstructure:"threshold"`
  41. // Score for invalid login attempts, eg. non-existent user accounts or
  42. // client disconnected for inactivity without authentication attempts
  43. ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
  44. // Score for valid login attempts, eg. user accounts that exist
  45. ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
  46. // Defines the time window, in minutes, for tracking client errors.
  47. // A host is banned if it has exceeded the defined threshold during
  48. // the last observation time minutes
  49. ObservationTime int `json:"observation_time" mapstructure:"observation_time"`
  50. // The number of banned IPs and host scores kept in memory will vary between the
  51. // soft and hard limit
  52. EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
  53. EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
  54. // Path to a file containing a list of ip addresses and/or networks to never ban
  55. SafeListFile string `json:"safelist_file" mapstructure:"safelist_file"`
  56. // Path to a file containing a list of ip addresses and/or networks to always ban
  57. BlockListFile string `json:"blocklist_file" mapstructure:"blocklist_file"`
  58. }
  59. type memoryDefender struct {
  60. config *DefenderConfig
  61. sync.RWMutex
  62. // IP addresses of the clients trying to connected are stored inside hosts,
  63. // they are added to banned once the thresold is reached.
  64. // A violation from a banned host will increase the ban time
  65. // based on the configured BanTimeIncrement
  66. hosts map[string]hostScore // the key is the host IP
  67. banned map[string]time.Time // the key is the host IP
  68. safeList *HostList
  69. blockList *HostList
  70. }
  71. // HostListFile defines the structure expected for safe/block list files
  72. type HostListFile struct {
  73. IPAddresses []string `json:"addresses"`
  74. CIDRNetworks []string `json:"networks"`
  75. }
  76. // HostList defines the structure used to keep the HostListFile in memory
  77. type HostList struct {
  78. IPAddresses map[string]bool
  79. Ranges cidranger.Ranger
  80. }
  81. func (h *HostList) isListed(ip string) bool {
  82. if _, ok := h.IPAddresses[ip]; ok {
  83. return true
  84. }
  85. ok, err := h.Ranges.Contains(net.ParseIP(ip))
  86. if err != nil {
  87. return false
  88. }
  89. return ok
  90. }
  91. type hostEvent struct {
  92. dateTime time.Time
  93. score int
  94. }
  95. type hostScore struct {
  96. TotalScore int
  97. Events []hostEvent
  98. }
  99. // validate returns an error if the configuration is invalid
  100. func (c *DefenderConfig) validate() error {
  101. if !c.Enabled {
  102. return nil
  103. }
  104. if c.ScoreInvalid >= c.Threshold {
  105. return fmt.Errorf("score_invalid %v cannot be greater than threshold %v", c.ScoreInvalid, c.Threshold)
  106. }
  107. if c.ScoreValid >= c.Threshold {
  108. return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold)
  109. }
  110. if c.BanTime <= 0 {
  111. return fmt.Errorf("invalid ban_time %v", c.BanTime)
  112. }
  113. if c.BanTimeIncrement <= 0 {
  114. return fmt.Errorf("invalid ban_time_increment %v", c.BanTimeIncrement)
  115. }
  116. if c.ObservationTime <= 0 {
  117. return fmt.Errorf("invalid observation_time %v", c.ObservationTime)
  118. }
  119. if c.EntriesSoftLimit <= 0 {
  120. return fmt.Errorf("invalid entries_soft_limit %v", c.EntriesSoftLimit)
  121. }
  122. if c.EntriesHardLimit <= c.EntriesSoftLimit {
  123. return fmt.Errorf("invalid entries_hard_limit %v must be > %v", c.EntriesHardLimit, c.EntriesSoftLimit)
  124. }
  125. return nil
  126. }
  127. func newInMemoryDefender(config *DefenderConfig) (Defender, error) {
  128. err := config.validate()
  129. if err != nil {
  130. return nil, err
  131. }
  132. defender := &memoryDefender{
  133. config: config,
  134. hosts: make(map[string]hostScore),
  135. banned: make(map[string]time.Time),
  136. }
  137. if err := defender.Reload(); err != nil {
  138. return nil, err
  139. }
  140. return defender, nil
  141. }
  142. // Reload reloads block and safe lists
  143. func (d *memoryDefender) Reload() error {
  144. blockList, err := loadHostListFromFile(d.config.BlockListFile)
  145. if err != nil {
  146. return err
  147. }
  148. d.Lock()
  149. d.blockList = blockList
  150. d.Unlock()
  151. safeList, err := loadHostListFromFile(d.config.SafeListFile)
  152. if err != nil {
  153. return err
  154. }
  155. d.Lock()
  156. d.safeList = safeList
  157. d.Unlock()
  158. return nil
  159. }
  160. // IsBanned returns true if the specified IP is banned
  161. // and increase ban time if the IP is found.
  162. // This method must be called as soon as the client connects
  163. func (d *memoryDefender) IsBanned(ip string) bool {
  164. d.RLock()
  165. if banTime, ok := d.banned[ip]; ok {
  166. if banTime.After(time.Now()) {
  167. increment := d.config.BanTime * d.config.BanTimeIncrement / 100
  168. if increment == 0 {
  169. increment++
  170. }
  171. d.RUnlock()
  172. // we can save an earlier ban time if there are contemporary updates
  173. // but this should not make much difference. I prefer to hold a read lock
  174. // until possible for performance reasons, this method is called each
  175. // time a new client connects and it must be as fast as possible
  176. d.Lock()
  177. d.banned[ip] = banTime.Add(time.Duration(increment) * time.Minute)
  178. d.Unlock()
  179. return true
  180. }
  181. }
  182. defer d.RUnlock()
  183. if d.blockList != nil && d.blockList.isListed(ip) {
  184. // permanent ban
  185. return true
  186. }
  187. return false
  188. }
  189. // Unban removes the specified IP address from the banned ones
  190. func (d *memoryDefender) Unban(ip string) bool {
  191. d.Lock()
  192. defer d.Unlock()
  193. if _, ok := d.banned[ip]; ok {
  194. delete(d.banned, ip)
  195. return true
  196. }
  197. return false
  198. }
  199. // AddEvent adds an event for the given IP.
  200. // This method must be called for clients not yet banned
  201. func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
  202. d.Lock()
  203. defer d.Unlock()
  204. if d.safeList != nil && d.safeList.isListed(ip) {
  205. return
  206. }
  207. var score int
  208. switch event {
  209. case HostEventLoginFailed:
  210. score = d.config.ScoreValid
  211. case HostEventUserNotFound, HostEventNoLoginTried:
  212. score = d.config.ScoreInvalid
  213. }
  214. ev := hostEvent{
  215. dateTime: time.Now(),
  216. score: score,
  217. }
  218. if hs, ok := d.hosts[ip]; ok {
  219. hs.Events = append(hs.Events, ev)
  220. hs.TotalScore = 0
  221. idx := 0
  222. for _, event := range hs.Events {
  223. if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
  224. hs.Events[idx] = event
  225. hs.TotalScore += event.score
  226. idx++
  227. }
  228. }
  229. hs.Events = hs.Events[:idx]
  230. if hs.TotalScore >= d.config.Threshold {
  231. d.banned[ip] = time.Now().Add(time.Duration(d.config.BanTime) * time.Minute)
  232. delete(d.hosts, ip)
  233. d.cleanupBanned()
  234. } else {
  235. d.hosts[ip] = hs
  236. }
  237. } else {
  238. d.hosts[ip] = hostScore{
  239. TotalScore: ev.score,
  240. Events: []hostEvent{ev},
  241. }
  242. d.cleanupHosts()
  243. }
  244. }
  245. func (d *memoryDefender) countBanned() int {
  246. d.RLock()
  247. defer d.RUnlock()
  248. return len(d.banned)
  249. }
  250. func (d *memoryDefender) countHosts() int {
  251. d.RLock()
  252. defer d.RUnlock()
  253. return len(d.hosts)
  254. }
  255. // GetBanTime returns the ban time for the given IP or nil if the IP is not banned
  256. func (d *memoryDefender) GetBanTime(ip string) *time.Time {
  257. d.RLock()
  258. defer d.RUnlock()
  259. if banTime, ok := d.banned[ip]; ok {
  260. return &banTime
  261. }
  262. return nil
  263. }
  264. // GetScore returns the score for the given IP
  265. func (d *memoryDefender) GetScore(ip string) int {
  266. d.RLock()
  267. defer d.RUnlock()
  268. score := 0
  269. if hs, ok := d.hosts[ip]; ok {
  270. for _, event := range hs.Events {
  271. if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
  272. score += event.score
  273. }
  274. }
  275. }
  276. return score
  277. }
  278. func (d *memoryDefender) cleanupBanned() {
  279. if len(d.banned) > d.config.EntriesHardLimit {
  280. kvList := make(kvList, 0, len(d.banned))
  281. for k, v := range d.banned {
  282. if v.Before(time.Now()) {
  283. delete(d.banned, k)
  284. }
  285. kvList = append(kvList, kv{
  286. Key: k,
  287. Value: v.UnixNano(),
  288. })
  289. }
  290. // we removed expired ip addresses, if any, above, this could be enough
  291. numToRemove := len(d.banned) - d.config.EntriesSoftLimit
  292. if numToRemove <= 0 {
  293. return
  294. }
  295. sort.Sort(kvList)
  296. for idx, kv := range kvList {
  297. if idx >= numToRemove {
  298. break
  299. }
  300. delete(d.banned, kv.Key)
  301. }
  302. }
  303. }
  304. func (d *memoryDefender) cleanupHosts() {
  305. if len(d.hosts) > d.config.EntriesHardLimit {
  306. kvList := make(kvList, 0, len(d.hosts))
  307. for k, v := range d.hosts {
  308. value := int64(0)
  309. if len(v.Events) > 0 {
  310. value = v.Events[len(v.Events)-1].dateTime.UnixNano()
  311. }
  312. kvList = append(kvList, kv{
  313. Key: k,
  314. Value: value,
  315. })
  316. }
  317. sort.Sort(kvList)
  318. numToRemove := len(d.hosts) - d.config.EntriesSoftLimit
  319. for idx, kv := range kvList {
  320. if idx >= numToRemove {
  321. break
  322. }
  323. delete(d.hosts, kv.Key)
  324. }
  325. }
  326. }
  327. func loadHostListFromFile(name string) (*HostList, error) {
  328. if name == "" {
  329. return nil, nil
  330. }
  331. if !utils.IsFileInputValid(name) {
  332. return nil, fmt.Errorf("invalid host list file name %#v", name)
  333. }
  334. info, err := os.Stat(name)
  335. if err != nil {
  336. return nil, err
  337. }
  338. // opinionated max size, you should avoid big host lists
  339. if info.Size() > 1048576*5 { // 5MB
  340. return nil, fmt.Errorf("host list file %#v is too big: %v bytes", name, info.Size())
  341. }
  342. content, err := os.ReadFile(name)
  343. if err != nil {
  344. return nil, fmt.Errorf("unable to read input file %#v: %v", name, err)
  345. }
  346. var hostList HostListFile
  347. err = json.Unmarshal(content, &hostList)
  348. if err != nil {
  349. return nil, err
  350. }
  351. if len(hostList.CIDRNetworks) > 0 || len(hostList.IPAddresses) > 0 {
  352. result := &HostList{
  353. IPAddresses: make(map[string]bool),
  354. Ranges: cidranger.NewPCTrieRanger(),
  355. }
  356. ipCount := 0
  357. cdrCount := 0
  358. for _, ip := range hostList.IPAddresses {
  359. if net.ParseIP(ip) == nil {
  360. logger.Warn(logSender, "", "unable to parse IP %#v", ip)
  361. continue
  362. }
  363. result.IPAddresses[ip] = true
  364. ipCount++
  365. }
  366. for _, cidrNet := range hostList.CIDRNetworks {
  367. _, network, err := net.ParseCIDR(cidrNet)
  368. if err != nil {
  369. logger.Warn(logSender, "", "unable to parse CIDR network %#v", cidrNet)
  370. continue
  371. }
  372. err = result.Ranges.Insert(cidranger.NewBasicRangerEntry(*network))
  373. if err == nil {
  374. cdrCount++
  375. }
  376. }
  377. logger.Info(logSender, "", "list %#v loaded, ip addresses loaded: %v/%v networks loaded: %v/%v",
  378. name, ipCount, len(hostList.IPAddresses), cdrCount, len(hostList.CIDRNetworks))
  379. return result, nil
  380. }
  381. return nil, nil
  382. }
  383. type kv struct {
  384. Key string
  385. Value int64
  386. }
  387. type kvList []kv
  388. func (p kvList) Len() int { return len(p) }
  389. func (p kvList) Less(i, j int) bool { return p[i].Value < p[j].Value }
  390. func (p kvList) Swap(i, j int) { p[i], p[j] = p[j], p[i] }