| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479 |
- package common
- import (
- "encoding/json"
- "fmt"
- "net"
- "os"
- "sort"
- "sync"
- "time"
- "github.com/yl2chen/cidranger"
- "github.com/drakkan/sftpgo/logger"
- "github.com/drakkan/sftpgo/utils"
- )
// HostEvent enumerates the host events a Defender can score.
// Each event is converted into a score (configured in DefenderConfig)
// when it is added to the defender.
type HostEvent int
// Supported host events.
//
// The in-memory defender scores HostEventLoginFailed with ScoreValid,
// HostEventUserNotFound and HostEventNoLoginTried with ScoreInvalid and
// HostEventRateExceeded with ScoreRateExceeded.
const (
	// HostEventLoginFailed is a failed login attempt; it is scored as a
	// "valid" attempt (ScoreValid, e.g. the user account exists).
	HostEventLoginFailed HostEvent = iota
	// HostEventUserNotFound is a login attempt for a non-existent user
	// account; it is scored with ScoreInvalid.
	HostEventUserNotFound
	// HostEventNoLoginTried is a client that went away without trying to
	// authenticate; it is scored with ScoreInvalid.
	HostEventNoLoginTried
	// HostEventRateExceeded is generated by the rate limiters; it is scored
	// with ScoreRateExceeded.
	HostEventRateExceeded
)
// Defender defines the interface that a defender must implement.
type Defender interface {
	// AddEvent records an event for the given client IP. In the in-memory
	// implementation this may ban the IP once its score reaches the threshold.
	AddEvent(ip string, event HostEvent)
	// IsBanned reports whether the given client IP is currently banned.
	IsBanned(ip string) bool
	// GetBanTime returns the ban expiration time for the given IP, or nil if
	// the IP has no ban entry.
	GetBanTime(ip string) *time.Time
	// GetScore returns the score currently accumulated by the given IP.
	GetScore(ip string) int
	// Unban removes the ban for the given IP and reports whether it was banned.
	Unban(ip string) bool
	// Reload reloads the defender's external data (the in-memory
	// implementation reloads its block and safe lists).
	Reload() error
}
// DefenderConfig defines the "defender" configuration
type DefenderConfig struct {
	// Set to true to enable the defender
	Enabled bool `json:"enabled" mapstructure:"enabled"`
	// BanTime is the number of minutes that a host is banned
	BanTime int `json:"ban_time" mapstructure:"ban_time"`
	// Percentage increase of the ban time if a banned host tries to connect again.
	// The increment is a percentage of BanTime (BanTime * BanTimeIncrement / 100
	// minutes, at least one minute) and it is added on every connection attempt
	// made while the host is still banned
	BanTimeIncrement int `json:"ban_time_increment" mapstructure:"ban_time_increment"`
	// Threshold value for banning a client: a host is banned once the sum of
	// its event scores within the observation window reaches this value.
	// Every single score must be lower than the threshold
	Threshold int `json:"threshold" mapstructure:"threshold"`
	// Score for invalid login attempts, eg. non-existent user accounts or
	// client disconnected for inactivity without authentication attempts
	ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
	// Score for valid login attempts, eg. user accounts that exist
	ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
	// Score for rate exceeded events, generated from the rate limiters
	ScoreRateExceeded int `json:"score_rate_exceeded" mapstructure:"score_rate_exceeded"`
	// Defines the time window, in minutes, for tracking client errors.
	// A host is banned if it has exceeded the defined threshold during
	// the last observation time minutes
	ObservationTime int `json:"observation_time" mapstructure:"observation_time"`
	// The number of banned IPs and host scores kept in memory will vary between the
	// soft and hard limit: once an in-memory map grows beyond EntriesHardLimit,
	// entries are evicted to bring it back towards EntriesSoftLimit.
	// EntriesHardLimit must be greater than EntriesSoftLimit
	EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
	EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
	// Path to a file containing a list of ip addresses and/or networks to never ban
	SafeListFile string `json:"safelist_file" mapstructure:"safelist_file"`
	// Path to a file containing a list of ip addresses and/or networks to always ban
	BlockListFile string `json:"blocklist_file" mapstructure:"blocklist_file"`
}
// memoryDefender is a Defender that keeps all of its state in memory.
type memoryDefender struct {
	// config is read without holding the lock
	config *DefenderConfig
	// RWMutex guards hosts, banned, safeList and blockList
	sync.RWMutex
	// IP addresses of the clients trying to connect are stored inside hosts,
	// they are moved to banned once the threshold is reached.
	// A connection attempt from a banned host will increase the ban time
	// based on the configured BanTimeIncrement
	hosts  map[string]hostScore // the key is the host IP
	banned map[string]time.Time // the key is the host IP, the value is the ban expiration
	// safeList holds hosts whose events are ignored; nil when no safe list
	// file is configured or the file lists nothing
	safeList *HostList
	// blockList holds hosts that are always reported as banned; nil when no
	// block list file is configured or the file lists nothing
	blockList *HostList
}
// HostListFile defines the structure expected for safe/block list files,
// for example:
//
//	{"addresses": ["192.0.2.1"], "networks": ["198.51.100.0/24"]}
type HostListFile struct {
	// IPAddresses lists single IP addresses; entries that do not parse as an
	// IP address are skipped with a warning when the file is loaded
	IPAddresses []string `json:"addresses"`
	// CIDRNetworks lists networks in CIDR notation; entries that do not parse
	// are skipped with a warning when the file is loaded
	CIDRNetworks []string `json:"networks"`
}
// HostList defines the structure used to keep the HostListFile in memory
type HostList struct {
	// IPAddresses is the set of listed addresses, matched by exact string
	// comparison against the client IP
	IPAddresses map[string]bool
	// Ranges holds the listed networks, matched by containment
	Ranges cidranger.Ranger
}
- func (h *HostList) isListed(ip string) bool {
- if _, ok := h.IPAddresses[ip]; ok {
- return true
- }
- ok, err := h.Ranges.Contains(net.ParseIP(ip))
- if err != nil {
- return false
- }
- return ok
- }
// hostEvent is a single scored event recorded for a host.
type hostEvent struct {
	dateTime time.Time // when the event was recorded
	score    int       // score assigned to the event from the DefenderConfig
}

// hostScore holds the scoring state of a host that is tracked but not banned.
type hostScore struct {
	// TotalScore is the sum of the scores of Events, recomputed when a new
	// event is added for an already tracked host
	TotalScore int
	// Events are the recorded events; events older than the observation
	// window are pruned when a new event is added for the host
	Events []hostEvent
}
- // validate returns an error if the configuration is invalid
- func (c *DefenderConfig) validate() error {
- if !c.Enabled {
- return nil
- }
- if c.ScoreInvalid >= c.Threshold {
- return fmt.Errorf("score_invalid %v cannot be greater than threshold %v", c.ScoreInvalid, c.Threshold)
- }
- if c.ScoreValid >= c.Threshold {
- return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold)
- }
- if c.ScoreRateExceeded >= c.Threshold {
- return fmt.Errorf("score_rate_exceeded %v cannot be greater than threshold %v", c.ScoreRateExceeded, c.Threshold)
- }
- if c.BanTime <= 0 {
- return fmt.Errorf("invalid ban_time %v", c.BanTime)
- }
- if c.BanTimeIncrement <= 0 {
- return fmt.Errorf("invalid ban_time_increment %v", c.BanTimeIncrement)
- }
- if c.ObservationTime <= 0 {
- return fmt.Errorf("invalid observation_time %v", c.ObservationTime)
- }
- if c.EntriesSoftLimit <= 0 {
- return fmt.Errorf("invalid entries_soft_limit %v", c.EntriesSoftLimit)
- }
- if c.EntriesHardLimit <= c.EntriesSoftLimit {
- return fmt.Errorf("invalid entries_hard_limit %v must be > %v", c.EntriesHardLimit, c.EntriesSoftLimit)
- }
- return nil
- }
- func newInMemoryDefender(config *DefenderConfig) (Defender, error) {
- err := config.validate()
- if err != nil {
- return nil, err
- }
- defender := &memoryDefender{
- config: config,
- hosts: make(map[string]hostScore),
- banned: make(map[string]time.Time),
- }
- if err := defender.Reload(); err != nil {
- return nil, err
- }
- return defender, nil
- }
- // Reload reloads block and safe lists
- func (d *memoryDefender) Reload() error {
- blockList, err := loadHostListFromFile(d.config.BlockListFile)
- if err != nil {
- return err
- }
- d.Lock()
- d.blockList = blockList
- d.Unlock()
- safeList, err := loadHostListFromFile(d.config.SafeListFile)
- if err != nil {
- return err
- }
- d.Lock()
- d.safeList = safeList
- d.Unlock()
- return nil
- }
- // IsBanned returns true if the specified IP is banned
- // and increase ban time if the IP is found.
- // This method must be called as soon as the client connects
- func (d *memoryDefender) IsBanned(ip string) bool {
- d.RLock()
- if banTime, ok := d.banned[ip]; ok {
- if banTime.After(time.Now()) {
- increment := d.config.BanTime * d.config.BanTimeIncrement / 100
- if increment == 0 {
- increment++
- }
- d.RUnlock()
- // we can save an earlier ban time if there are contemporary updates
- // but this should not make much difference. I prefer to hold a read lock
- // until possible for performance reasons, this method is called each
- // time a new client connects and it must be as fast as possible
- d.Lock()
- d.banned[ip] = banTime.Add(time.Duration(increment) * time.Minute)
- d.Unlock()
- return true
- }
- }
- defer d.RUnlock()
- if d.blockList != nil && d.blockList.isListed(ip) {
- // permanent ban
- return true
- }
- return false
- }
- // Unban removes the specified IP address from the banned ones
- func (d *memoryDefender) Unban(ip string) bool {
- d.Lock()
- defer d.Unlock()
- if _, ok := d.banned[ip]; ok {
- delete(d.banned, ip)
- return true
- }
- return false
- }
- // AddEvent adds an event for the given IP.
- // This method must be called for clients not yet banned
- func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
- d.Lock()
- defer d.Unlock()
- if d.safeList != nil && d.safeList.isListed(ip) {
- return
- }
- var score int
- switch event {
- case HostEventLoginFailed:
- score = d.config.ScoreValid
- case HostEventRateExceeded:
- score = d.config.ScoreRateExceeded
- case HostEventUserNotFound, HostEventNoLoginTried:
- score = d.config.ScoreInvalid
- }
- ev := hostEvent{
- dateTime: time.Now(),
- score: score,
- }
- if hs, ok := d.hosts[ip]; ok {
- hs.Events = append(hs.Events, ev)
- hs.TotalScore = 0
- idx := 0
- for _, event := range hs.Events {
- if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
- hs.Events[idx] = event
- hs.TotalScore += event.score
- idx++
- }
- }
- hs.Events = hs.Events[:idx]
- if hs.TotalScore >= d.config.Threshold {
- d.banned[ip] = time.Now().Add(time.Duration(d.config.BanTime) * time.Minute)
- delete(d.hosts, ip)
- d.cleanupBanned()
- } else {
- d.hosts[ip] = hs
- }
- } else {
- d.hosts[ip] = hostScore{
- TotalScore: ev.score,
- Events: []hostEvent{ev},
- }
- d.cleanupHosts()
- }
- }
- func (d *memoryDefender) countBanned() int {
- d.RLock()
- defer d.RUnlock()
- return len(d.banned)
- }
- func (d *memoryDefender) countHosts() int {
- d.RLock()
- defer d.RUnlock()
- return len(d.hosts)
- }
- // GetBanTime returns the ban time for the given IP or nil if the IP is not banned
- func (d *memoryDefender) GetBanTime(ip string) *time.Time {
- d.RLock()
- defer d.RUnlock()
- if banTime, ok := d.banned[ip]; ok {
- return &banTime
- }
- return nil
- }
- // GetScore returns the score for the given IP
- func (d *memoryDefender) GetScore(ip string) int {
- d.RLock()
- defer d.RUnlock()
- score := 0
- if hs, ok := d.hosts[ip]; ok {
- for _, event := range hs.Events {
- if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
- score += event.score
- }
- }
- }
- return score
- }
- func (d *memoryDefender) cleanupBanned() {
- if len(d.banned) > d.config.EntriesHardLimit {
- kvList := make(kvList, 0, len(d.banned))
- for k, v := range d.banned {
- if v.Before(time.Now()) {
- delete(d.banned, k)
- }
- kvList = append(kvList, kv{
- Key: k,
- Value: v.UnixNano(),
- })
- }
- // we removed expired ip addresses, if any, above, this could be enough
- numToRemove := len(d.banned) - d.config.EntriesSoftLimit
- if numToRemove <= 0 {
- return
- }
- sort.Sort(kvList)
- for idx, kv := range kvList {
- if idx >= numToRemove {
- break
- }
- delete(d.banned, kv.Key)
- }
- }
- }
- func (d *memoryDefender) cleanupHosts() {
- if len(d.hosts) > d.config.EntriesHardLimit {
- kvList := make(kvList, 0, len(d.hosts))
- for k, v := range d.hosts {
- value := int64(0)
- if len(v.Events) > 0 {
- value = v.Events[len(v.Events)-1].dateTime.UnixNano()
- }
- kvList = append(kvList, kv{
- Key: k,
- Value: value,
- })
- }
- sort.Sort(kvList)
- numToRemove := len(d.hosts) - d.config.EntriesSoftLimit
- for idx, kv := range kvList {
- if idx >= numToRemove {
- break
- }
- delete(d.hosts, kv.Key)
- }
- }
- }
- func loadHostListFromFile(name string) (*HostList, error) {
- if name == "" {
- return nil, nil
- }
- if !utils.IsFileInputValid(name) {
- return nil, fmt.Errorf("invalid host list file name %#v", name)
- }
- info, err := os.Stat(name)
- if err != nil {
- return nil, err
- }
- // opinionated max size, you should avoid big host lists
- if info.Size() > 1048576*5 { // 5MB
- return nil, fmt.Errorf("host list file %#v is too big: %v bytes", name, info.Size())
- }
- content, err := os.ReadFile(name)
- if err != nil {
- return nil, fmt.Errorf("unable to read input file %#v: %v", name, err)
- }
- var hostList HostListFile
- err = json.Unmarshal(content, &hostList)
- if err != nil {
- return nil, err
- }
- if len(hostList.CIDRNetworks) > 0 || len(hostList.IPAddresses) > 0 {
- result := &HostList{
- IPAddresses: make(map[string]bool),
- Ranges: cidranger.NewPCTrieRanger(),
- }
- ipCount := 0
- cdrCount := 0
- for _, ip := range hostList.IPAddresses {
- if net.ParseIP(ip) == nil {
- logger.Warn(logSender, "", "unable to parse IP %#v", ip)
- continue
- }
- result.IPAddresses[ip] = true
- ipCount++
- }
- for _, cidrNet := range hostList.CIDRNetworks {
- _, network, err := net.ParseCIDR(cidrNet)
- if err != nil {
- logger.Warn(logSender, "", "unable to parse CIDR network %#v", cidrNet)
- continue
- }
- err = result.Ranges.Insert(cidranger.NewBasicRangerEntry(*network))
- if err == nil {
- cdrCount++
- }
- }
- logger.Info(logSender, "", "list %#v loaded, ip addresses loaded: %v/%v networks loaded: %v/%v",
- name, ipCount, len(hostList.IPAddresses), cdrCount, len(hostList.CIDRNetworks))
- return result, nil
- }
- return nil, nil
- }
// kv pairs a map key with an int64 sort value; the cleanup routines use a
// UnixNano timestamp as the value.
type kv struct {
	Key   string
	Value int64
}

// kvList implements sort.Interface, ordering entries by ascending Value.
type kvList []kv

func (l kvList) Len() int { return len(l) }

func (l kvList) Less(i, j int) bool { return l[j].Value > l[i].Value }

func (l kvList) Swap(i, j int) {
	tmp := l[i]
	l[i] = l[j]
	l[j] = tmp
}
|