resetcode.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. // Copyright (C) 2019-2023 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package httpd
  15. import (
  16. "encoding/json"
  17. "sync"
  18. "time"
  19. "github.com/drakkan/sftpgo/v2/internal/dataprovider"
  20. "github.com/drakkan/sftpgo/v2/internal/logger"
  21. "github.com/drakkan/sftpgo/v2/internal/util"
  22. )
  23. var (
  24. resetCodeLifespan = 10 * time.Minute
  25. resetCodesMgr resetCodeManager
  26. )
  27. type resetCodeManager interface {
  28. Add(code *resetCode) error
  29. Get(code string) (*resetCode, error)
  30. Delete(code string) error
  31. Cleanup()
  32. }
  33. func newResetCodeManager(isShared int) resetCodeManager {
  34. if isShared == 1 {
  35. logger.Info(logSender, "", "using provider reset code manager")
  36. return &dbResetCodeManager{}
  37. }
  38. logger.Info(logSender, "", "using memory reset code manager")
  39. return &memoryResetCodeManager{}
  40. }
  // resetCode is a reset code issued to a user or an admin.
  //
  // The JSON tags define its serialized form; the provider backed manager
  // decodes stored payloads with encoding/json into this type. Field order is
  // kept stable because it determines the order of keys in marshaled output.
  type resetCode struct {
  	// Code is the opaque code value and the lookup key.
  	Code string `json:"code"`
  	// Username identifies the account the code was issued for.
  	Username string `json:"username"`
  	// IsAdmin is true when Username refers to an admin, not a regular user.
  	IsAdmin bool `json:"is_admin"`
  	// ExpiresAt is the instant after which the code is no longer valid.
  	ExpiresAt time.Time `json:"expires_at"`
  }
  47. func newResetCode(username string, isAdmin bool) *resetCode {
  48. return &resetCode{
  49. Code: util.GenerateUniqueID(),
  50. Username: username,
  51. IsAdmin: isAdmin,
  52. ExpiresAt: time.Now().Add(resetCodeLifespan).UTC(),
  53. }
  54. }
  55. func (c *resetCode) isExpired() bool {
  56. return c.ExpiresAt.Before(time.Now().UTC())
  57. }
  // memoryResetCodeManager keeps reset codes in process memory.
  //
  // Entries are keyed by the code value. sync.Map allows concurrent access, and
  // because it must not be copied after first use the manager is handled by
  // pointer.
  type memoryResetCodeManager struct {
  	resetCodes sync.Map // code value (string) -> *resetCode
  }
  61. func (m *memoryResetCodeManager) Add(code *resetCode) error {
  62. m.resetCodes.Store(code.Code, code)
  63. return nil
  64. }
  65. func (m *memoryResetCodeManager) Get(code string) (*resetCode, error) {
  66. c, ok := m.resetCodes.Load(code)
  67. if !ok {
  68. return nil, util.NewRecordNotFoundError("reset code not found")
  69. }
  70. return c.(*resetCode), nil
  71. }
  72. func (m *memoryResetCodeManager) Delete(code string) error {
  73. m.resetCodes.Delete(code)
  74. return nil
  75. }
  76. func (m *memoryResetCodeManager) Cleanup() {
  77. m.resetCodes.Range(func(key, value any) bool {
  78. c, ok := value.(*resetCode)
  79. if !ok || c.isExpired() {
  80. m.resetCodes.Delete(key)
  81. }
  82. return true
  83. })
  84. }
  // dbResetCodeManager stores reset codes as shared sessions in the data
  // provider. It holds no state of its own.
  type dbResetCodeManager struct{}
  86. func (m *dbResetCodeManager) Add(code *resetCode) error {
  87. session := dataprovider.Session{
  88. Key: code.Code,
  89. Data: code,
  90. Type: dataprovider.SessionTypeResetCode,
  91. Timestamp: util.GetTimeAsMsSinceEpoch(code.ExpiresAt),
  92. }
  93. return dataprovider.AddSharedSession(session)
  94. }
  95. func (m *dbResetCodeManager) Get(code string) (*resetCode, error) {
  96. session, err := dataprovider.GetSharedSession(code)
  97. if err != nil {
  98. return nil, err
  99. }
  100. if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) {
  101. // expired
  102. return nil, util.NewRecordNotFoundError("reset code expired")
  103. }
  104. return m.decodeData(session.Data)
  105. }
  106. func (m *dbResetCodeManager) decodeData(data any) (*resetCode, error) {
  107. if val, ok := data.([]byte); ok {
  108. c := &resetCode{}
  109. err := json.Unmarshal(val, c)
  110. return c, err
  111. }
  112. logger.Error(logSender, "", "invalid reset code data type %T", data)
  113. return nil, util.NewRecordNotFoundError("invalid reset code")
  114. }
  115. func (m *dbResetCodeManager) Delete(code string) error {
  116. return dataprovider.DeleteSharedSession(code)
  117. }
  118. func (m *dbResetCodeManager) Cleanup() {
  119. dataprovider.CleanupSharedSessions(dataprovider.SessionTypeResetCode, time.Now()) //nolint:errcheck
  120. }