iplist.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. // Copyright (C) 2019-2023 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package dataprovider
  15. import (
  16. "encoding/json"
  17. "fmt"
  18. "net"
  19. "net/netip"
  20. "strings"
  21. "sync"
  22. "sync/atomic"
  23. "github.com/yl2chen/cidranger"
  24. "github.com/drakkan/sftpgo/v2/internal/logger"
  25. "github.com/drakkan/sftpgo/v2/internal/util"
  26. )
const (
	// ipListMemoryLimit is the maximum number of entries to match in memory.
	// NewIPList enables in-memory matching only if the list has fewer
	// entries than this limit, and addEntry/updateEntry disable it once
	// the limit is reached: in these cases a database query will be executed
	ipListMemoryLimit = 15000
)
  33. var (
  34. inMemoryLists map[IPListType]*IPList
  35. )
  36. func init() {
  37. inMemoryLists = map[IPListType]*IPList{}
  38. }
  39. // IPListType is the enumerable for the supported IP list types
  40. type IPListType int
  41. // AsString returns the string representation for the list type
  42. func (t IPListType) AsString() string {
  43. switch t {
  44. case IPListTypeAllowList:
  45. return "Allow list"
  46. case IPListTypeDefender:
  47. return "Defender"
  48. case IPListTypeRateLimiterSafeList:
  49. return "Rate limiters safe list"
  50. default:
  51. return ""
  52. }
  53. }
  54. // Supported IP list types
  55. const (
  56. IPListTypeAllowList IPListType = iota + 1
  57. IPListTypeDefender
  58. IPListTypeRateLimiterSafeList
  59. )
// Supported IP list modes.
// validate accepts both modes for defender lists and
// only ListModeAllow for the other list types
const (
	ListModeAllow = iota + 1
	ListModeDeny
)

// IP address families, validate stores one of these values
// in the IPType field of an IPListEntry
const (
	ipTypeV4 = iota + 1
	ipTypeV6
)
  69. var (
  70. supportedIPListType = []IPListType{IPListTypeAllowList, IPListTypeDefender, IPListTypeRateLimiterSafeList}
  71. )
  72. // CheckIPListType returns an error if the provided IP list type is not valid
  73. func CheckIPListType(t IPListType) error {
  74. if !util.Contains(supportedIPListType, t) {
  75. return util.NewValidationError(fmt.Sprintf("invalid list type %d", t))
  76. }
  77. return nil
  78. }
// IPListEntry defines an entry for the IP addresses list
type IPListEntry struct {
	// IP address or network. validate appends "/32" or "/128"
	// to a bare IP address, so a validated entry is in CIDR notation
	IPOrNet string `json:"ipornet"`
	// Optional description
	Description string `json:"description,omitempty"`
	// List type, one of the supported IPListType values
	Type IPListType `json:"type"`
	// List mode, ListModeAllow or ListModeDeny
	Mode int `json:"mode"`
	// Defines the protocols the entry applies to
	// - 0 all the supported protocols
	// - 1 SSH
	// - 2 FTP
	// - 4 WebDAV
	// - 8 HTTP
	// Protocols can be combined
	Protocols int `json:"protocols"`
	// First address of the network, set by validate:
	// 4 bytes for IPv4 entries, 16 bytes for IPv6 ones
	First []byte `json:"first,omitempty"`
	// Last address of the network, set by validate, same size as First
	Last []byte `json:"last,omitempty"`
	// Address family, ipTypeV4 or ipTypeV6, set by validate
	IPType int `json:"ip_type,omitempty"`
	// Creation time as unix timestamp in milliseconds
	CreatedAt int64 `json:"created_at"`
	// last update time as unix timestamp in milliseconds
	UpdatedAt int64 `json:"updated_at"`
	// in multi node setups we mark the rule as deleted to be able to update the cache
	DeletedAt int64 `json:"-"`
}
  103. // PrepareForRendering prepares an IP list entry for rendering.
  104. // It hides internal fields
  105. func (e *IPListEntry) PrepareForRendering() {
  106. e.First = nil
  107. e.Last = nil
  108. e.IPType = 0
  109. }
  110. // HasProtocol returns true if the specified protocol is defined
  111. func (e *IPListEntry) HasProtocol(proto string) bool {
  112. switch proto {
  113. case protocolSSH:
  114. return e.Protocols&1 != 0
  115. case protocolFTP:
  116. return e.Protocols&2 != 0
  117. case protocolWebDAV:
  118. return e.Protocols&4 != 0
  119. case protocolHTTP:
  120. return e.Protocols&8 != 0
  121. default:
  122. return false
  123. }
  124. }
  125. // RenderAsJSON implements the renderer interface used within plugins
  126. func (e *IPListEntry) RenderAsJSON(reload bool) ([]byte, error) {
  127. if reload {
  128. entry, err := provider.ipListEntryExists(e.IPOrNet, e.Type)
  129. if err != nil {
  130. providerLog(logger.LevelError, "unable to reload IP list entry before rendering as json: %v", err)
  131. return nil, err
  132. }
  133. entry.PrepareForRendering()
  134. return json.Marshal(entry)
  135. }
  136. e.PrepareForRendering()
  137. return json.Marshal(e)
  138. }
  139. func (e *IPListEntry) getKey() string {
  140. return fmt.Sprintf("%d_%s", e.Type, e.IPOrNet)
  141. }
  142. func (e *IPListEntry) getName() string {
  143. return e.Type.AsString() + "-" + e.IPOrNet
  144. }
  145. func (e *IPListEntry) getFirst() netip.Addr {
  146. if e.IPType == ipTypeV4 {
  147. var a4 [4]byte
  148. copy(a4[:], e.First)
  149. return netip.AddrFrom4(a4)
  150. }
  151. var a16 [16]byte
  152. copy(a16[:], e.First)
  153. return netip.AddrFrom16(a16)
  154. }
  155. func (e *IPListEntry) getLast() netip.Addr {
  156. if e.IPType == ipTypeV4 {
  157. var a4 [4]byte
  158. copy(a4[:], e.Last)
  159. return netip.AddrFrom4(a4)
  160. }
  161. var a16 [16]byte
  162. copy(a16[:], e.Last)
  163. return netip.AddrFrom16(a16)
  164. }
  165. func (e *IPListEntry) checkProtocols() {
  166. for _, proto := range ValidProtocols {
  167. if !e.HasProtocol(proto) {
  168. return
  169. }
  170. }
  171. e.Protocols = 0
  172. }
// validate checks the entry and normalizes it in place.
// It verifies the list type and mode, normalizes Protocols, converts a bare
// IP address to CIDR notation, rewrites IPv4-mapped IPv6 values as IPv4 and
// finally sets IPType, First and Last from the parsed network.
// A validation error is returned if the type, the mode or the IP/network
// are not valid
func (e *IPListEntry) validate() error {
	if err := CheckIPListType(e.Type); err != nil {
		return err
	}
	// store 0 if all the supported protocols are set
	e.checkProtocols()
	switch e.Type {
	case IPListTypeDefender:
		// both allow and deny modes are supported for the defender
		if e.Mode < ListModeAllow || e.Mode > ListModeDeny {
			return util.NewValidationError(fmt.Sprintf("invalid list mode: %d", e.Mode))
		}
	default:
		// the other list types only support the allow mode
		if e.Mode != ListModeAllow {
			return util.NewValidationError("invalid list mode")
		}
	}
	// discard any provided First, Last and IPType, they are computed below
	e.PrepareForRendering()
	if !strings.Contains(e.IPOrNet, "/") {
		// no prefix length: parse as IP and convert to a single address network
		parsed, err := netip.ParseAddr(e.IPOrNet)
		if err != nil {
			return util.NewValidationError(fmt.Sprintf("invalid IP %q", e.IPOrNet))
		}
		if parsed.Is4() {
			e.IPOrNet += "/32"
		} else if parsed.Is4In6() {
			// IPv4-mapped IPv6 address: store it in IPv4 form
			e.IPOrNet = netip.AddrFrom4(parsed.As4()).String() + "/32"
		} else {
			e.IPOrNet += "/128"
		}
	}
	prefix, err := netip.ParsePrefix(e.IPOrNet)
	if err != nil {
		return util.NewValidationError(fmt.Sprintf("invalid network %q: %v", e.IPOrNet, err))
	}
	// the masked prefix is used to compute First and Last,
	// IPOrNet is rewritten only for IPv4-mapped IPv6 networks
	prefix = prefix.Masked()
	if prefix.Addr().Is4In6() {
		// the first 96 bits are the IPv4-mapped marker, remove them from the prefix length
		e.IPOrNet = fmt.Sprintf("%s/%d", netip.AddrFrom4(prefix.Addr().As4()).String(), prefix.Bits()-96)
	}
	// make sure the value can also be parsed by the net package since
	// getAsRangerEntry relies on net.ParseCIDR.
	// TODO: to remove when the in memory ranger switch to netip
	_, _, err = net.ParseCIDR(e.IPOrNet)
	if err != nil {
		return util.NewValidationError(fmt.Sprintf("invalid network: %v", err))
	}
	// NOTE(review): util.GetLastIPForPrefix presumably returns the last
	// address included in the prefix, confirm against its implementation
	if prefix.Addr().Is4() || prefix.Addr().Is4In6() {
		e.IPType = ipTypeV4
		first := prefix.Addr().As4()
		last := util.GetLastIPForPrefix(prefix).As4()
		e.First = first[:]
		e.Last = last[:]
	} else {
		e.IPType = ipTypeV6
		first := prefix.Addr().As16()
		last := util.GetLastIPForPrefix(prefix).As16()
		e.First = first[:]
		e.Last = last[:]
	}
	return nil
}
  231. func (e *IPListEntry) getACopy() IPListEntry {
  232. first := make([]byte, len(e.First))
  233. copy(first, e.First)
  234. last := make([]byte, len(e.Last))
  235. copy(last, e.Last)
  236. return IPListEntry{
  237. IPOrNet: e.IPOrNet,
  238. Description: e.Description,
  239. Type: e.Type,
  240. Mode: e.Mode,
  241. First: first,
  242. Last: last,
  243. IPType: e.IPType,
  244. Protocols: e.Protocols,
  245. CreatedAt: e.CreatedAt,
  246. UpdatedAt: e.UpdatedAt,
  247. DeletedAt: e.DeletedAt,
  248. }
  249. }
  250. // getAsRangerEntry returns the entry as cidranger.RangerEntry
  251. func (e *IPListEntry) getAsRangerEntry() (cidranger.RangerEntry, error) {
  252. _, network, err := net.ParseCIDR(e.IPOrNet)
  253. if err != nil {
  254. return nil, err
  255. }
  256. entry := e.getACopy()
  257. return &rangerEntry{
  258. entry: &entry,
  259. network: *network,
  260. }, nil
  261. }
  262. func (e IPListEntry) satisfySearchConstraints(filter, from, order string) bool {
  263. if filter != "" && !strings.HasPrefix(e.IPOrNet, filter) {
  264. return false
  265. }
  266. if from != "" {
  267. if order == OrderASC {
  268. return e.IPOrNet > from
  269. }
  270. return e.IPOrNet < from
  271. }
  272. return true
  273. }
  274. type rangerEntry struct {
  275. entry *IPListEntry
  276. network net.IPNet
  277. }
  278. func (e *rangerEntry) Network() net.IPNet {
  279. return e.network
  280. }
// IPList defines an IP list
type IPList struct {
	// isInMemory is true if the matching is done using Ranges,
	// if false IsListed queries the data provider
	isInMemory atomic.Bool
	// listType is the type of the entries in this list, addEntry,
	// removeEntry and updateEntry ignore entries of a different type
	listType IPListType
	// mu is held while reading or modifying Ranges
	mu sync.RWMutex
	// Ranges contains the entries used for in-memory matching,
	// NewIPList leaves it nil if in-memory matching is disabled
	Ranges cidranger.Ranger
}
  288. func (l *IPList) addEntry(e *IPListEntry) {
  289. if l.listType != e.Type {
  290. return
  291. }
  292. if !l.isInMemory.Load() {
  293. return
  294. }
  295. entry, err := e.getAsRangerEntry()
  296. if err != nil {
  297. providerLog(logger.LevelError, "unable to get entry to add %q for list type %d, disabling memory mode, err: %v",
  298. e.IPOrNet, l.listType, err)
  299. l.isInMemory.Store(false)
  300. return
  301. }
  302. l.mu.Lock()
  303. defer l.mu.Unlock()
  304. if err := l.Ranges.Insert(entry); err != nil {
  305. providerLog(logger.LevelError, "unable to add entry %q for list type %d, disabling memory mode, err: %v",
  306. e.IPOrNet, l.listType, err)
  307. l.isInMemory.Store(false)
  308. return
  309. }
  310. if l.Ranges.Len() >= ipListMemoryLimit {
  311. providerLog(logger.LevelError, "memory limit exceeded for list type %d, disabling memory mode", l.listType)
  312. l.isInMemory.Store(false)
  313. }
  314. }
  315. func (l *IPList) removeEntry(e *IPListEntry) {
  316. if l.listType != e.Type {
  317. return
  318. }
  319. if !l.isInMemory.Load() {
  320. return
  321. }
  322. entry, err := e.getAsRangerEntry()
  323. if err != nil {
  324. providerLog(logger.LevelError, "unable to get entry to remove %q for list type %d, disabling memory mode, err: %v",
  325. e.IPOrNet, l.listType, err)
  326. l.isInMemory.Store(false)
  327. return
  328. }
  329. l.mu.Lock()
  330. defer l.mu.Unlock()
  331. if _, err := l.Ranges.Remove(entry.Network()); err != nil {
  332. providerLog(logger.LevelError, "unable to remove entry %q for list type %d, disabling memory mode, err: %v",
  333. e.IPOrNet, l.listType, err)
  334. l.isInMemory.Store(false)
  335. }
  336. }
  337. func (l *IPList) updateEntry(e *IPListEntry) {
  338. if l.listType != e.Type {
  339. return
  340. }
  341. if !l.isInMemory.Load() {
  342. return
  343. }
  344. entry, err := e.getAsRangerEntry()
  345. if err != nil {
  346. providerLog(logger.LevelError, "unable to get entry to update %q for list type %d, disabling memory mode, err: %v",
  347. e.IPOrNet, l.listType, err)
  348. l.isInMemory.Store(false)
  349. return
  350. }
  351. l.mu.Lock()
  352. defer l.mu.Unlock()
  353. if _, err := l.Ranges.Remove(entry.Network()); err != nil {
  354. providerLog(logger.LevelError, "unable to remove entry to update %q for list type %d, disabling memory mode, err: %v",
  355. e.IPOrNet, l.listType, err)
  356. l.isInMemory.Store(false)
  357. return
  358. }
  359. if err := l.Ranges.Insert(entry); err != nil {
  360. providerLog(logger.LevelError, "unable to add entry to update %q for list type %d, disabling memory mode, err: %v",
  361. e.IPOrNet, l.listType, err)
  362. l.isInMemory.Store(false)
  363. }
  364. if l.Ranges.Len() >= ipListMemoryLimit {
  365. providerLog(logger.LevelError, "memory limit exceeded for list type %d, disabling memory mode", l.listType)
  366. l.isInMemory.Store(false)
  367. }
  368. }
// DisableMemoryMode disables memory mode forcing database queries.
// It only clears the in-memory flag, the entries already stored in
// Ranges are not removed
func (l *IPList) DisableMemoryMode() {
	l.isInMemory.Store(false)
}
  373. // IsListed checks if there is a match for the specified IP and protocol.
  374. // If there are multiple matches, the first one is returned, in no particular order,
  375. // so the behavior is undefined
  376. func (l *IPList) IsListed(ip, protocol string) (bool, int, error) {
  377. if l.isInMemory.Load() {
  378. l.mu.RLock()
  379. defer l.mu.RUnlock()
  380. parsedIP := net.ParseIP(ip)
  381. if parsedIP == nil {
  382. return false, 0, fmt.Errorf("invalid IP %s", ip)
  383. }
  384. entries, err := l.Ranges.ContainingNetworks(parsedIP)
  385. if err != nil {
  386. return false, 0, fmt.Errorf("unable to find containing networks for ip %q: %w", ip, err)
  387. }
  388. for _, e := range entries {
  389. entry, ok := e.(*rangerEntry)
  390. if ok {
  391. if entry.entry.Protocols == 0 || entry.entry.HasProtocol(protocol) {
  392. return true, entry.entry.Mode, nil
  393. }
  394. }
  395. }
  396. return false, 0, nil
  397. }
  398. entries, err := provider.getListEntriesForIP(ip, l.listType)
  399. if err != nil {
  400. return false, 0, err
  401. }
  402. for _, e := range entries {
  403. if e.Protocols == 0 || e.HasProtocol(protocol) {
  404. return true, e.Mode, nil
  405. }
  406. }
  407. return false, 0, nil
  408. }
  409. // NewIPList returns a new IP list for the specified type
  410. func NewIPList(listType IPListType) (*IPList, error) {
  411. delete(inMemoryLists, listType)
  412. count, err := provider.countIPListEntries(listType)
  413. if err != nil {
  414. return nil, err
  415. }
  416. if count < ipListMemoryLimit {
  417. providerLog(logger.LevelInfo, "using in-memory matching for list type %d, num entries: %d", listType, count)
  418. entries, err := provider.getIPListEntries(listType, "", "", OrderASC, 0)
  419. if err != nil {
  420. return nil, err
  421. }
  422. ipList := &IPList{
  423. listType: listType,
  424. Ranges: cidranger.NewPCTrieRanger(),
  425. }
  426. for idx := range entries {
  427. e := entries[idx]
  428. entry, err := e.getAsRangerEntry()
  429. if err != nil {
  430. return nil, fmt.Errorf("unable to get ranger for entry %q: %w", e.IPOrNet, err)
  431. }
  432. if err := ipList.Ranges.Insert(entry); err != nil {
  433. return nil, fmt.Errorf("unable to add ranger for entry %q: %w", e.IPOrNet, err)
  434. }
  435. }
  436. ipList.isInMemory.Store(true)
  437. inMemoryLists[listType] = ipList
  438. return ipList, nil
  439. }
  440. providerLog(logger.LevelInfo, "list type %d has %d entries, in-memory matching disabled", listType, count)
  441. ipList := &IPList{
  442. listType: listType,
  443. Ranges: nil,
  444. }
  445. ipList.isInMemory.Store(false)
  446. return ipList, nil
  447. }