cachedpassword.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. // Copyright (C) 2019 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package dataprovider
import (
	"crypto/subtle"
	"sort"
	"sync"
	"sync/atomic"
	"time"

	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
)
  23. var (
  24. cachedUserPasswords credentialsCache
  25. cachedAdminPasswords credentialsCache
  26. cachedAPIKeys credentialsCache
  27. )
  28. func init() {
  29. cachedUserPasswords = credentialsCache{
  30. name: "users",
  31. sizeLimit: 500,
  32. cache: make(map[string]credentialObject),
  33. }
  34. cachedAdminPasswords = credentialsCache{
  35. name: "admins",
  36. sizeLimit: 100,
  37. cache: make(map[string]credentialObject),
  38. }
  39. cachedAPIKeys = credentialsCache{
  40. name: "API keys",
  41. sizeLimit: 500,
  42. cache: make(map[string]credentialObject),
  43. }
  44. }
  45. // CheckCachedUserPassword is an utility method used only in test cases
  46. func CheckCachedUserPassword(username, password, hash string) (bool, bool) {
  47. return cachedUserPasswords.Check(username, password, hash)
  48. }
  49. type credentialObject struct {
  50. key string
  51. hash string
  52. password string
  53. usedAt *atomic.Int64
  54. }
// credentialsCache is an in-memory map of credentials keyed by username.
// The embedded RWMutex guards cache: Add, Remove and cleanup take the write
// lock, while Check and count take the read lock.
type credentialsCache struct {
	// name identifies this cache in cleanup log messages.
	name string
	// sizeLimit is the entry count above which cleanup starts trimming; it
	// is also the size cleanup evicts down to when it falls back to
	// least-recently-used eviction.
	sizeLimit int
	sync.RWMutex
	// cache maps a username to its cached credential.
	cache map[string]credentialObject
}
  61. func (c *credentialsCache) Add(username, password, hash string) {
  62. if !config.PasswordCaching || username == "" || password == "" || hash == "" {
  63. return
  64. }
  65. c.Lock()
  66. defer c.Unlock()
  67. obj := credentialObject{
  68. key: username,
  69. hash: hash,
  70. password: password,
  71. usedAt: &atomic.Int64{},
  72. }
  73. obj.usedAt.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
  74. c.cache[username] = obj
  75. }
  76. func (c *credentialsCache) Remove(username string) {
  77. if !config.PasswordCaching {
  78. return
  79. }
  80. c.Lock()
  81. defer c.Unlock()
  82. delete(c.cache, username)
  83. }
  84. // Check returns if the username is found and if the password match
  85. func (c *credentialsCache) Check(username, password, hash string) (bool, bool) {
  86. if username == "" || password == "" || hash == "" {
  87. return false, false
  88. }
  89. c.RLock()
  90. defer c.RUnlock()
  91. creds, ok := c.cache[username]
  92. if !ok {
  93. return false, false
  94. }
  95. if creds.hash != hash {
  96. creds.usedAt.Store(0)
  97. return false, false
  98. }
  99. match := creds.password == password
  100. if match {
  101. creds.usedAt.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
  102. }
  103. return true, match
  104. }
  105. func (c *credentialsCache) count() int {
  106. c.RLock()
  107. defer c.RUnlock()
  108. return len(c.cache)
  109. }
  110. func (c *credentialsCache) cleanup() {
  111. if !config.PasswordCaching {
  112. return
  113. }
  114. if c.count() <= c.sizeLimit {
  115. return
  116. }
  117. c.Lock()
  118. defer c.Unlock()
  119. for k, v := range c.cache {
  120. if v.usedAt.Load() < util.GetTimeAsMsSinceEpoch(time.Now().Add(-60*time.Minute)) {
  121. delete(c.cache, k)
  122. }
  123. }
  124. providerLog(logger.LevelDebug, "size for credentials %q after cleanup: %d", c.name, len(c.cache))
  125. if len(c.cache) < c.sizeLimit*5 {
  126. return
  127. }
  128. numToRemove := len(c.cache) - c.sizeLimit
  129. providerLog(logger.LevelDebug, "additional item to remove from credentials %q: %d", c.name, numToRemove)
  130. credentials := make([]credentialObject, 0, len(c.cache))
  131. for _, v := range c.cache {
  132. credentials = append(credentials, v)
  133. }
  134. sort.Slice(credentials, func(i, j int) bool {
  135. return credentials[i].usedAt.Load() < credentials[j].usedAt.Load()
  136. })
  137. for idx := range credentials {
  138. if idx >= numToRemove {
  139. break
  140. }
  141. delete(c.cache, credentials[idx].key)
  142. }
  143. providerLog(logger.LevelDebug, "size for credentials %q after additional cleanup: %d", c.name, len(c.cache))
  144. }