iplist.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. // Copyright (C) 2019 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package dataprovider
  15. import (
  16. "encoding/json"
  17. "fmt"
  18. "net"
  19. "net/netip"
  20. "slices"
  21. "strings"
  22. "sync"
  23. "sync/atomic"
  24. "github.com/yl2chen/cidranger"
  25. "github.com/drakkan/sftpgo/v2/internal/logger"
  26. "github.com/drakkan/sftpgo/v2/internal/util"
  27. )
  28. const (
  29. // maximum number of entries to match in memory
  30. // if the list contains more elements than this limit a
  31. // database query will be executed
  32. ipListMemoryLimit = 15000
  33. )
  34. var (
  35. inMemoryLists map[IPListType]*IPList
  36. )
  37. func init() {
  38. inMemoryLists = map[IPListType]*IPList{}
  39. }
  40. // IPListType is the enumerable for the supported IP list types
  41. type IPListType int
  42. // AsString returns the string representation for the list type
  43. func (t IPListType) AsString() string {
  44. switch t {
  45. case IPListTypeAllowList:
  46. return "Allow list"
  47. case IPListTypeDefender:
  48. return "Defender"
  49. case IPListTypeRateLimiterSafeList:
  50. return "Rate limiters safe list"
  51. default:
  52. return ""
  53. }
  54. }
  55. // Supported IP list types
  56. const (
  57. IPListTypeAllowList IPListType = iota + 1
  58. IPListTypeDefender
  59. IPListTypeRateLimiterSafeList
  60. )
  61. // Supported IP list modes
  62. const (
  63. ListModeAllow = iota + 1
  64. ListModeDeny
  65. )
  66. const (
  67. ipTypeV4 = iota + 1
  68. ipTypeV6
  69. )
  70. var (
  71. supportedIPListType = []IPListType{IPListTypeAllowList, IPListTypeDefender, IPListTypeRateLimiterSafeList}
  72. )
  73. // CheckIPListType returns an error if the provided IP list type is not valid
  74. func CheckIPListType(t IPListType) error {
  75. if !slices.Contains(supportedIPListType, t) {
  76. return util.NewValidationError(fmt.Sprintf("invalid list type %d", t))
  77. }
  78. return nil
  79. }
// IPListEntry defines an entry for the IP addresses list.
//
// NOTE: encoding/json emits fields in declaration order, so reordering the
// fields below changes the rendered JSON.
type IPListEntry struct {
	// IPOrNet is an IP address or a network in CIDR notation. validate turns a
	// bare address into a single-host network (/32 or /128) and rewrites
	// IPv4-mapped IPv6 networks in IPv4 form.
	IPOrNet     string     `json:"ipornet"`
	Description string     `json:"description,omitempty"`
	Type        IPListType `json:"type"`
	// Mode is ListModeAllow or ListModeDeny; validate accepts ListModeDeny only
	// for IPListTypeDefender entries.
	Mode int `json:"mode"`
	// Defines the protocols the entry applies to
	// - 0 all the supported protocols
	// - 1 SSH
	// - 2 FTP
	// - 4 WebDAV
	// - 8 HTTP
	// Protocols can be combined
	Protocols int `json:"protocols"`
	// First and Last are the first and last addresses of the network as raw
	// bytes (4 for IPv4, 16 for IPv6). They are computed by validate and
	// cleared by PrepareForRendering.
	First []byte `json:"first,omitempty"`
	Last  []byte `json:"last,omitempty"`
	// IPType is ipTypeV4 or ipTypeV6, set by validate.
	IPType int `json:"ip_type,omitempty"`
	// Creation time as unix timestamp in milliseconds
	CreatedAt int64 `json:"created_at"`
	// last update time as unix timestamp in milliseconds
	UpdatedAt int64 `json:"updated_at"`
	// in multi node setups we mark the rule as deleted to be able to update the cache.
	// Never serialized to JSON (tag "-").
	DeletedAt int64 `json:"-"`
}
  104. // PrepareForRendering prepares an IP list entry for rendering.
  105. // It hides internal fields
  106. func (e *IPListEntry) PrepareForRendering() {
  107. e.First = nil
  108. e.Last = nil
  109. e.IPType = 0
  110. }
  111. // HasProtocol returns true if the specified protocol is defined
  112. func (e *IPListEntry) HasProtocol(proto string) bool {
  113. switch proto {
  114. case protocolSSH:
  115. return e.Protocols&1 != 0
  116. case protocolFTP:
  117. return e.Protocols&2 != 0
  118. case protocolWebDAV:
  119. return e.Protocols&4 != 0
  120. case protocolHTTP:
  121. return e.Protocols&8 != 0
  122. default:
  123. return false
  124. }
  125. }
  126. // RenderAsJSON implements the renderer interface used within plugins
  127. func (e *IPListEntry) RenderAsJSON(reload bool) ([]byte, error) {
  128. if reload {
  129. entry, err := provider.ipListEntryExists(e.IPOrNet, e.Type)
  130. if err != nil {
  131. providerLog(logger.LevelError, "unable to reload IP list entry before rendering as json: %v", err)
  132. return nil, err
  133. }
  134. entry.PrepareForRendering()
  135. return json.Marshal(entry)
  136. }
  137. e.PrepareForRendering()
  138. return json.Marshal(e)
  139. }
  140. func (e *IPListEntry) getKey() string {
  141. return fmt.Sprintf("%d_%s", e.Type, e.IPOrNet)
  142. }
  143. func (e *IPListEntry) getName() string {
  144. return e.Type.AsString() + "-" + e.IPOrNet
  145. }
  146. func (e *IPListEntry) getFirst() netip.Addr {
  147. if e.IPType == ipTypeV4 {
  148. var a4 [4]byte
  149. copy(a4[:], e.First)
  150. return netip.AddrFrom4(a4)
  151. }
  152. var a16 [16]byte
  153. copy(a16[:], e.First)
  154. return netip.AddrFrom16(a16)
  155. }
  156. func (e *IPListEntry) getLast() netip.Addr {
  157. if e.IPType == ipTypeV4 {
  158. var a4 [4]byte
  159. copy(a4[:], e.Last)
  160. return netip.AddrFrom4(a4)
  161. }
  162. var a16 [16]byte
  163. copy(a16[:], e.Last)
  164. return netip.AddrFrom16(a16)
  165. }
  166. func (e *IPListEntry) checkProtocols() {
  167. for _, proto := range ValidProtocols {
  168. if !e.HasProtocol(proto) {
  169. return
  170. }
  171. }
  172. e.Protocols = 0
  173. }
  174. func (e *IPListEntry) validate() error {
  175. if err := CheckIPListType(e.Type); err != nil {
  176. return err
  177. }
  178. e.checkProtocols()
  179. switch e.Type {
  180. case IPListTypeDefender:
  181. if e.Mode < ListModeAllow || e.Mode > ListModeDeny {
  182. return util.NewValidationError(fmt.Sprintf("invalid list mode: %d", e.Mode))
  183. }
  184. default:
  185. if e.Mode != ListModeAllow {
  186. return util.NewValidationError("invalid list mode")
  187. }
  188. }
  189. e.PrepareForRendering()
  190. if !strings.Contains(e.IPOrNet, "/") {
  191. // parse as IP
  192. parsed, err := netip.ParseAddr(e.IPOrNet)
  193. if err != nil {
  194. return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid IP %q", e.IPOrNet)), util.I18nErrorIPInvalid)
  195. }
  196. if parsed.Is4() {
  197. e.IPOrNet += "/32"
  198. } else if parsed.Is4In6() {
  199. e.IPOrNet = netip.AddrFrom4(parsed.As4()).String() + "/32"
  200. } else {
  201. e.IPOrNet += "/128"
  202. }
  203. }
  204. prefix, err := netip.ParsePrefix(e.IPOrNet)
  205. if err != nil {
  206. return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid network %q: %v", e.IPOrNet, err)), util.I18nErrorNetInvalid)
  207. }
  208. prefix = prefix.Masked()
  209. if prefix.Addr().Is4In6() {
  210. e.IPOrNet = fmt.Sprintf("%s/%d", netip.AddrFrom4(prefix.Addr().As4()).String(), prefix.Bits()-96)
  211. }
  212. // TODO: to remove when the in memory ranger switch to netip
  213. _, _, err = net.ParseCIDR(e.IPOrNet)
  214. if err != nil {
  215. return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid network: %v", err)), util.I18nErrorNetInvalid)
  216. }
  217. if prefix.Addr().Is4() || prefix.Addr().Is4In6() {
  218. e.IPType = ipTypeV4
  219. first := prefix.Addr().As4()
  220. last := util.GetLastIPForPrefix(prefix).As4()
  221. e.First = first[:]
  222. e.Last = last[:]
  223. } else {
  224. e.IPType = ipTypeV6
  225. first := prefix.Addr().As16()
  226. last := util.GetLastIPForPrefix(prefix).As16()
  227. e.First = first[:]
  228. e.Last = last[:]
  229. }
  230. return nil
  231. }
  232. func (e *IPListEntry) getACopy() IPListEntry {
  233. first := make([]byte, len(e.First))
  234. copy(first, e.First)
  235. last := make([]byte, len(e.Last))
  236. copy(last, e.Last)
  237. return IPListEntry{
  238. IPOrNet: e.IPOrNet,
  239. Description: e.Description,
  240. Type: e.Type,
  241. Mode: e.Mode,
  242. First: first,
  243. Last: last,
  244. IPType: e.IPType,
  245. Protocols: e.Protocols,
  246. CreatedAt: e.CreatedAt,
  247. UpdatedAt: e.UpdatedAt,
  248. DeletedAt: e.DeletedAt,
  249. }
  250. }
  251. // getAsRangerEntry returns the entry as cidranger.RangerEntry
  252. func (e *IPListEntry) getAsRangerEntry() (cidranger.RangerEntry, error) {
  253. _, network, err := net.ParseCIDR(e.IPOrNet)
  254. if err != nil {
  255. return nil, err
  256. }
  257. entry := e.getACopy()
  258. return &rangerEntry{
  259. entry: &entry,
  260. network: *network,
  261. }, nil
  262. }
  263. func (e IPListEntry) satisfySearchConstraints(filter, from, order string) bool {
  264. if filter != "" && !strings.HasPrefix(e.IPOrNet, filter) {
  265. return false
  266. }
  267. if from != "" {
  268. if order == OrderASC {
  269. return e.IPOrNet > from
  270. }
  271. return e.IPOrNet < from
  272. }
  273. return true
  274. }
  275. type rangerEntry struct {
  276. entry *IPListEntry
  277. network net.IPNet
  278. }
  279. func (e *rangerEntry) Network() net.IPNet {
  280. return e.network
  281. }
// IPList defines an IP list.
//
// While isInMemory is true, IsListed matches against Ranges; otherwise it
// queries the data provider. mu guards Ranges: addEntry, removeEntry and
// updateEntry take the write lock, and IsListed takes the read lock.
type IPList struct {
	// isInMemory is an atomic flag so it can be read and cleared without
	// holding mu (addEntry, for example, clears it before locking).
	isInMemory atomic.Bool
	// listType is the only entry type accepted by addEntry, removeEntry and
	// updateEntry; entries of other types are ignored.
	listType IPListType
	mu       sync.RWMutex
	// Ranges holds the entries for in-memory matching. NewIPList leaves it nil
	// when the list is too large for memory mode.
	Ranges cidranger.Ranger
}
  289. func (l *IPList) addEntry(e *IPListEntry) {
  290. if l.listType != e.Type {
  291. return
  292. }
  293. if !l.isInMemory.Load() {
  294. return
  295. }
  296. entry, err := e.getAsRangerEntry()
  297. if err != nil {
  298. providerLog(logger.LevelError, "unable to get entry to add %q for list type %d, disabling memory mode, err: %v",
  299. e.IPOrNet, l.listType, err)
  300. l.isInMemory.Store(false)
  301. return
  302. }
  303. l.mu.Lock()
  304. defer l.mu.Unlock()
  305. if err := l.Ranges.Insert(entry); err != nil {
  306. providerLog(logger.LevelError, "unable to add entry %q for list type %d, disabling memory mode, err: %v",
  307. e.IPOrNet, l.listType, err)
  308. l.isInMemory.Store(false)
  309. return
  310. }
  311. if l.Ranges.Len() >= ipListMemoryLimit {
  312. providerLog(logger.LevelError, "memory limit exceeded for list type %d, disabling memory mode", l.listType)
  313. l.isInMemory.Store(false)
  314. }
  315. }
  316. func (l *IPList) removeEntry(e *IPListEntry) {
  317. if l.listType != e.Type {
  318. return
  319. }
  320. if !l.isInMemory.Load() {
  321. return
  322. }
  323. entry, err := e.getAsRangerEntry()
  324. if err != nil {
  325. providerLog(logger.LevelError, "unable to get entry to remove %q for list type %d, disabling memory mode, err: %v",
  326. e.IPOrNet, l.listType, err)
  327. l.isInMemory.Store(false)
  328. return
  329. }
  330. l.mu.Lock()
  331. defer l.mu.Unlock()
  332. if _, err := l.Ranges.Remove(entry.Network()); err != nil {
  333. providerLog(logger.LevelError, "unable to remove entry %q for list type %d, disabling memory mode, err: %v",
  334. e.IPOrNet, l.listType, err)
  335. l.isInMemory.Store(false)
  336. }
  337. }
  338. func (l *IPList) updateEntry(e *IPListEntry) {
  339. if l.listType != e.Type {
  340. return
  341. }
  342. if !l.isInMemory.Load() {
  343. return
  344. }
  345. entry, err := e.getAsRangerEntry()
  346. if err != nil {
  347. providerLog(logger.LevelError, "unable to get entry to update %q for list type %d, disabling memory mode, err: %v",
  348. e.IPOrNet, l.listType, err)
  349. l.isInMemory.Store(false)
  350. return
  351. }
  352. l.mu.Lock()
  353. defer l.mu.Unlock()
  354. if _, err := l.Ranges.Remove(entry.Network()); err != nil {
  355. providerLog(logger.LevelError, "unable to remove entry to update %q for list type %d, disabling memory mode, err: %v",
  356. e.IPOrNet, l.listType, err)
  357. l.isInMemory.Store(false)
  358. return
  359. }
  360. if err := l.Ranges.Insert(entry); err != nil {
  361. providerLog(logger.LevelError, "unable to add entry to update %q for list type %d, disabling memory mode, err: %v",
  362. e.IPOrNet, l.listType, err)
  363. l.isInMemory.Store(false)
  364. }
  365. if l.Ranges.Len() >= ipListMemoryLimit {
  366. providerLog(logger.LevelError, "memory limit exceeded for list type %d, disabling memory mode", l.listType)
  367. l.isInMemory.Store(false)
  368. }
  369. }
  370. // DisableMemoryMode disables memory mode forcing database queries
  371. func (l *IPList) DisableMemoryMode() {
  372. l.isInMemory.Store(false)
  373. }
  374. // IsListed checks if there is a match for the specified IP and protocol.
  375. // If there are multiple matches, the first one is returned, in no particular order,
  376. // so the behavior is undefined
  377. func (l *IPList) IsListed(ip, protocol string) (bool, int, error) {
  378. if l.isInMemory.Load() {
  379. l.mu.RLock()
  380. defer l.mu.RUnlock()
  381. parsedIP := net.ParseIP(ip)
  382. if parsedIP == nil {
  383. return false, 0, fmt.Errorf("invalid IP %s", ip)
  384. }
  385. entries, err := l.Ranges.ContainingNetworks(parsedIP)
  386. if err != nil {
  387. return false, 0, fmt.Errorf("unable to find containing networks for ip %q: %w", ip, err)
  388. }
  389. for _, e := range entries {
  390. entry, ok := e.(*rangerEntry)
  391. if ok {
  392. if entry.entry.Protocols == 0 || entry.entry.HasProtocol(protocol) {
  393. return true, entry.entry.Mode, nil
  394. }
  395. }
  396. }
  397. return false, 0, nil
  398. }
  399. entries, err := provider.getListEntriesForIP(ip, l.listType)
  400. if err != nil {
  401. return false, 0, err
  402. }
  403. for _, e := range entries {
  404. if e.Protocols == 0 || e.HasProtocol(protocol) {
  405. return true, e.Mode, nil
  406. }
  407. }
  408. return false, 0, nil
  409. }
  410. // NewIPList returns a new IP list for the specified type
  411. func NewIPList(listType IPListType) (*IPList, error) {
  412. delete(inMemoryLists, listType)
  413. count, err := provider.countIPListEntries(listType)
  414. if err != nil {
  415. return nil, err
  416. }
  417. if count < ipListMemoryLimit {
  418. providerLog(logger.LevelInfo, "using in-memory matching for list type %d, num entries: %d", listType, count)
  419. entries, err := provider.getIPListEntries(listType, "", "", OrderASC, 0)
  420. if err != nil {
  421. return nil, err
  422. }
  423. ipList := &IPList{
  424. listType: listType,
  425. Ranges: cidranger.NewPCTrieRanger(),
  426. }
  427. for idx := range entries {
  428. e := entries[idx]
  429. entry, err := e.getAsRangerEntry()
  430. if err != nil {
  431. return nil, fmt.Errorf("unable to get ranger for entry %q: %w", e.IPOrNet, err)
  432. }
  433. if err := ipList.Ranges.Insert(entry); err != nil {
  434. return nil, fmt.Errorf("unable to add ranger for entry %q: %w", e.IPOrNet, err)
  435. }
  436. }
  437. ipList.isInMemory.Store(true)
  438. inMemoryLists[listType] = ipList
  439. return ipList, nil
  440. }
  441. providerLog(logger.LevelInfo, "list type %d has %d entries, in-memory matching disabled", listType, count)
  442. ipList := &IPList{
  443. listType: listType,
  444. Ranges: nil,
  445. }
  446. ipList.isInMemory.Store(false)
  447. return ipList, nil
  448. }