| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311 | 
							- // Copyright (C) 2019-2022  Nicola Murino
 
- //
 
- // This program is free software: you can redistribute it and/or modify
 
- // it under the terms of the GNU Affero General Public License as published
 
- // by the Free Software Foundation, version 3.
 
- //
 
- // This program is distributed in the hope that it will be useful,
 
- // but WITHOUT ANY WARRANTY; without even the implied warranty of
 
- // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 
- // GNU Affero General Public License for more details.
 
- //
 
- // You should have received a copy of the GNU Affero General Public License
 
- // along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
- package common
 
- import (
 
- 	"encoding/hex"
 
- 	"encoding/json"
 
- 	"os"
 
- 	"path/filepath"
 
- 	"testing"
 
- 	"time"
 
- 	"github.com/stretchr/testify/assert"
 
- 	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
 
- 	"github.com/drakkan/sftpgo/v2/internal/util"
 
- )
 
- func TestBasicDbDefender(t *testing.T) {
 
- 	if !isDbDefenderSupported() {
 
- 		t.Skip("this test is not supported with the current database provider")
 
- 	}
 
- 	config := &DefenderConfig{
 
- 		Enabled:            true,
 
- 		BanTime:            10,
 
- 		BanTimeIncrement:   2,
 
- 		Threshold:          5,
 
- 		ScoreInvalid:       2,
 
- 		ScoreValid:         1,
 
- 		ScoreLimitExceeded: 3,
 
- 		ObservationTime:    15,
 
- 		EntriesSoftLimit:   1,
 
- 		EntriesHardLimit:   10,
 
- 		SafeListFile:       "slFile",
 
- 		BlockListFile:      "blFile",
 
- 	}
 
- 	_, err := newDBDefender(config)
 
- 	assert.Error(t, err)
 
- 	bl := HostListFile{
 
- 		IPAddresses:  []string{"172.16.1.1", "172.16.1.2"},
 
- 		CIDRNetworks: []string{"10.8.0.0/24"},
 
- 	}
 
- 	sl := HostListFile{
 
- 		IPAddresses:  []string{"172.16.1.3", "172.16.1.4"},
 
- 		CIDRNetworks: []string{"192.168.8.0/24"},
 
- 	}
 
- 	blFile := filepath.Join(os.TempDir(), "bl.json")
 
- 	slFile := filepath.Join(os.TempDir(), "sl.json")
 
- 	data, err := json.Marshal(bl)
 
- 	assert.NoError(t, err)
 
- 	err = os.WriteFile(blFile, data, os.ModePerm)
 
- 	assert.NoError(t, err)
 
- 	data, err = json.Marshal(sl)
 
- 	assert.NoError(t, err)
 
- 	err = os.WriteFile(slFile, data, os.ModePerm)
 
- 	assert.NoError(t, err)
 
- 	config.BlockListFile = blFile
 
- 	_, err = newDBDefender(config)
 
- 	assert.Error(t, err)
 
- 	config.SafeListFile = slFile
 
- 	d, err := newDBDefender(config)
 
- 	assert.NoError(t, err)
 
- 	defender := d.(*dbDefender)
 
- 	assert.True(t, defender.IsBanned("172.16.1.1"))
 
- 	assert.False(t, defender.IsBanned("172.16.1.10"))
 
- 	assert.False(t, defender.IsBanned("10.8.1.3"))
 
- 	assert.True(t, defender.IsBanned("10.8.0.4"))
 
- 	assert.False(t, defender.IsBanned("invalid ip"))
 
- 	hosts, err := defender.GetHosts()
 
- 	assert.NoError(t, err)
 
- 	assert.Len(t, hosts, 0)
 
- 	_, err = defender.GetHost("10.8.0.3")
 
- 	assert.Error(t, err)
 
- 	defender.AddEvent("172.16.1.4", HostEventLoginFailed)
 
- 	defender.AddEvent("192.168.8.4", HostEventUserNotFound)
 
- 	defender.AddEvent("172.16.1.3", HostEventLimitExceeded)
 
- 	hosts, err = defender.GetHosts()
 
- 	assert.NoError(t, err)
 
- 	assert.Len(t, hosts, 0)
 
- 	assert.True(t, defender.getLastCleanup().IsZero())
 
- 	testIP := "123.45.67.89"
 
- 	defender.AddEvent(testIP, HostEventLoginFailed)
 
- 	lastCleanup := defender.getLastCleanup()
 
- 	assert.False(t, lastCleanup.IsZero())
 
- 	score, err := defender.GetScore(testIP)
 
- 	assert.NoError(t, err)
 
- 	assert.Equal(t, 1, score)
 
- 	hosts, err = defender.GetHosts()
 
- 	assert.NoError(t, err)
 
- 	if assert.Len(t, hosts, 1) {
 
- 		assert.Equal(t, 1, hosts[0].Score)
 
- 		assert.True(t, hosts[0].BanTime.IsZero())
 
- 		assert.Empty(t, hosts[0].GetBanTime())
 
- 	}
 
- 	host, err := defender.GetHost(testIP)
 
- 	assert.NoError(t, err)
 
- 	assert.Equal(t, 1, host.Score)
 
- 	assert.Empty(t, host.GetBanTime())
 
- 	banTime, err := defender.GetBanTime(testIP)
 
- 	assert.NoError(t, err)
 
- 	assert.Nil(t, banTime)
 
- 	defender.AddEvent(testIP, HostEventLimitExceeded)
 
- 	score, err = defender.GetScore(testIP)
 
- 	assert.NoError(t, err)
 
- 	assert.Equal(t, 4, score)
 
- 	hosts, err = defender.GetHosts()
 
- 	assert.NoError(t, err)
 
- 	if assert.Len(t, hosts, 1) {
 
- 		assert.Equal(t, 4, hosts[0].Score)
 
- 		assert.True(t, hosts[0].BanTime.IsZero())
 
- 		assert.Empty(t, hosts[0].GetBanTime())
 
- 	}
 
- 	defender.AddEvent(testIP, HostEventNoLoginTried)
 
- 	defender.AddEvent(testIP, HostEventNoLoginTried)
 
- 	score, err = defender.GetScore(testIP)
 
- 	assert.NoError(t, err)
 
- 	assert.Equal(t, 0, score)
 
- 	banTime, err = defender.GetBanTime(testIP)
 
- 	assert.NoError(t, err)
 
- 	assert.NotNil(t, banTime)
 
- 	hosts, err = defender.GetHosts()
 
- 	assert.NoError(t, err)
 
- 	if assert.Len(t, hosts, 1) {
 
- 		assert.Equal(t, 0, hosts[0].Score)
 
- 		assert.False(t, hosts[0].BanTime.IsZero())
 
- 		assert.NotEmpty(t, hosts[0].GetBanTime())
 
- 		assert.Equal(t, hex.EncodeToString([]byte(testIP)), hosts[0].GetID())
 
- 	}
 
- 	host, err = defender.GetHost(testIP)
 
- 	assert.NoError(t, err)
 
- 	assert.Equal(t, 0, host.Score)
 
- 	assert.NotEmpty(t, host.GetBanTime())
 
- 	// ban time should increase
 
- 	assert.True(t, defender.IsBanned(testIP))
 
- 	newBanTime, err := defender.GetBanTime(testIP)
 
- 	assert.NoError(t, err)
 
- 	assert.True(t, newBanTime.After(*banTime))
 
- 	assert.True(t, defender.DeleteHost(testIP))
 
- 	assert.False(t, defender.DeleteHost(testIP))
 
- 	// test cleanup
 
- 	testIP1 := "123.45.67.90"
 
- 	testIP2 := "123.45.67.91"
 
- 	testIP3 := "123.45.67.92"
 
- 	for i := 0; i < 3; i++ {
 
- 		defender.AddEvent(testIP, HostEventNoLoginTried)
 
- 		defender.AddEvent(testIP1, HostEventNoLoginTried)
 
- 		defender.AddEvent(testIP2, HostEventNoLoginTried)
 
- 	}
 
- 	hosts, err = defender.GetHosts()
 
- 	assert.NoError(t, err)
 
- 	assert.Len(t, hosts, 3)
 
- 	for _, host := range hosts {
 
- 		assert.Equal(t, 0, host.Score)
 
- 		assert.False(t, host.BanTime.IsZero())
 
- 		assert.NotEmpty(t, host.GetBanTime())
 
- 	}
 
- 	defender.AddEvent(testIP3, HostEventLoginFailed)
 
- 	hosts, err = defender.GetHosts()
 
- 	assert.NoError(t, err)
 
- 	assert.Len(t, hosts, 4)
 
- 	// now set a ban time in the past, so the host will be cleanead up
 
- 	for _, ip := range []string{testIP1, testIP2} {
 
- 		err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
 
- 		assert.NoError(t, err)
 
- 	}
 
- 	hosts, err = defender.GetHosts()
 
- 	assert.NoError(t, err)
 
- 	assert.Len(t, hosts, 4)
 
- 	for _, host := range hosts {
 
- 		switch host.IP {
 
- 		case testIP:
 
- 			assert.Equal(t, 0, host.Score)
 
- 			assert.False(t, host.BanTime.IsZero())
 
- 			assert.NotEmpty(t, host.GetBanTime())
 
- 		case testIP3:
 
- 			assert.Equal(t, 1, host.Score)
 
- 			assert.True(t, host.BanTime.IsZero())
 
- 			assert.Empty(t, host.GetBanTime())
 
- 		default:
 
- 			assert.Equal(t, 6, host.Score)
 
- 			assert.True(t, host.BanTime.IsZero())
 
- 			assert.Empty(t, host.GetBanTime())
 
- 		}
 
- 	}
 
- 	host, err = defender.GetHost(testIP)
 
- 	assert.NoError(t, err)
 
- 	assert.Equal(t, 0, host.Score)
 
- 	assert.False(t, host.BanTime.IsZero())
 
- 	assert.NotEmpty(t, host.GetBanTime())
 
- 	host, err = defender.GetHost(testIP3)
 
- 	assert.NoError(t, err)
 
- 	assert.Equal(t, 1, host.Score)
 
- 	assert.True(t, host.BanTime.IsZero())
 
- 	assert.Empty(t, host.GetBanTime())
 
- 	// set a negative observation time so the from field in the queries will be in the future
 
- 	// we still should get the banned hosts
 
- 	defender.config.ObservationTime = -2
 
- 	assert.Greater(t, defender.getStartObservationTime(), time.Now().UnixMilli())
 
- 	hosts, err = defender.GetHosts()
 
- 	assert.NoError(t, err)
 
- 	if assert.Len(t, hosts, 1) {
 
- 		assert.Equal(t, testIP, hosts[0].IP)
 
- 		assert.Equal(t, 0, hosts[0].Score)
 
- 		assert.False(t, hosts[0].BanTime.IsZero())
 
- 		assert.NotEmpty(t, hosts[0].GetBanTime())
 
- 	}
 
- 	_, err = defender.GetHost(testIP)
 
- 	assert.NoError(t, err)
 
- 	// cleanup db
 
- 	err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
 
- 	assert.NoError(t, err)
 
- 	// the banned host must still be there
 
- 	hosts, err = defender.GetHosts()
 
- 	assert.NoError(t, err)
 
- 	if assert.Len(t, hosts, 1) {
 
- 		assert.Equal(t, testIP, hosts[0].IP)
 
- 		assert.Equal(t, 0, hosts[0].Score)
 
- 		assert.False(t, hosts[0].BanTime.IsZero())
 
- 		assert.NotEmpty(t, hosts[0].GetBanTime())
 
- 	}
 
- 	_, err = defender.GetHost(testIP)
 
- 	assert.NoError(t, err)
 
- 	err = dataprovider.SetDefenderBanTime(testIP, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
 
- 	assert.NoError(t, err)
 
- 	err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
 
- 	assert.NoError(t, err)
 
- 	hosts, err = defender.GetHosts()
 
- 	assert.NoError(t, err)
 
- 	assert.Len(t, hosts, 0)
 
- 	err = os.Remove(slFile)
 
- 	assert.NoError(t, err)
 
- 	err = os.Remove(blFile)
 
- 	assert.NoError(t, err)
 
- }
 
- func TestDbDefenderCleanup(t *testing.T) {
 
- 	if !isDbDefenderSupported() {
 
- 		t.Skip("this test is not supported with the current database provider")
 
- 	}
 
- 	config := &DefenderConfig{
 
- 		Enabled:            true,
 
- 		BanTime:            10,
 
- 		BanTimeIncrement:   2,
 
- 		Threshold:          5,
 
- 		ScoreInvalid:       2,
 
- 		ScoreValid:         1,
 
- 		ScoreLimitExceeded: 3,
 
- 		ObservationTime:    15,
 
- 		EntriesSoftLimit:   1,
 
- 		EntriesHardLimit:   10,
 
- 	}
 
- 	d, err := newDBDefender(config)
 
- 	assert.NoError(t, err)
 
- 	defender := d.(*dbDefender)
 
- 	lastCleanup := defender.getLastCleanup()
 
- 	assert.True(t, lastCleanup.IsZero())
 
- 	defender.cleanup()
 
- 	lastCleanup = defender.getLastCleanup()
 
- 	assert.False(t, lastCleanup.IsZero())
 
- 	defender.cleanup()
 
- 	assert.Equal(t, lastCleanup, defender.getLastCleanup())
 
- 	defender.setLastCleanup(time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4))
 
- 	time.Sleep(20 * time.Millisecond)
 
- 	defender.cleanup()
 
- 	assert.True(t, lastCleanup.Before(defender.getLastCleanup()))
 
- 	providerConf := dataprovider.GetProviderConfig()
 
- 	err = dataprovider.Close()
 
- 	assert.NoError(t, err)
 
- 	lastCleanup = time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4)
 
- 	defender.setLastCleanup(lastCleanup)
 
- 	defender.cleanup()
 
- 	// cleanup will fail and so last cleanup should be reset to the previous value
 
- 	assert.Equal(t, lastCleanup, defender.getLastCleanup())
 
- 	err = dataprovider.Initialize(providerConf, configDir, true)
 
- 	assert.NoError(t, err)
 
- }
 
- func isDbDefenderSupported() bool {
 
- 	// SQLite shares the implementation with other SQL-based provider but it makes no sense
 
- 	// to use it outside test cases
 
- 	switch dataprovider.GetProviderStatus().Driver {
 
- 	case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
 
- 		dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
 
- 		return true
 
- 	default:
 
- 		return false
 
- 	}
 
- }
 
 
  |