oauth2.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. // Copyright (C) 2019 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package httpd
  15. import (
  16. "crypto/sha256"
  17. "encoding/hex"
  18. "encoding/json"
  19. "errors"
  20. "sync"
  21. "time"
  22. "github.com/drakkan/sftpgo/v2/internal/dataprovider"
  23. "github.com/drakkan/sftpgo/v2/internal/kms"
  24. "github.com/drakkan/sftpgo/v2/internal/logger"
  25. "github.com/drakkan/sftpgo/v2/internal/util"
  26. )
  27. var (
  28. oauth2Mgr oauth2Manager
  29. )
  30. func newOAuth2Manager(isShared int) oauth2Manager {
  31. if isShared == 1 {
  32. logger.Info(logSender, "", "using provider OAuth2 manager")
  33. return &dbOAuth2Manager{}
  34. }
  35. logger.Info(logSender, "", "using memory OAuth2 manager")
  36. return &memoryOAuth2Manager{
  37. pendingAuths: make(map[string]oauth2PendingAuth),
  38. }
  39. }
  40. type oauth2PendingAuth struct {
  41. State string `json:"state"`
  42. Provider int `json:"provider"`
  43. ClientID string `json:"client_id"`
  44. ClientSecret *kms.Secret `json:"client_secret"`
  45. RedirectURL string `json:"redirect_url"`
  46. IssuedAt int64 `json:"issued_at"`
  47. }
  48. func newOAuth2PendingAuth(provider int, redirectURL, clientID string, clientSecret *kms.Secret) oauth2PendingAuth {
  49. state := sha256.Sum256(util.GenerateRandomBytes(32))
  50. return oauth2PendingAuth{
  51. State: hex.EncodeToString(state[:]),
  52. Provider: provider,
  53. ClientID: clientID,
  54. ClientSecret: clientSecret,
  55. RedirectURL: redirectURL,
  56. IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
  57. }
  58. }
  59. type oauth2Manager interface {
  60. addPendingAuth(pendingAuth oauth2PendingAuth)
  61. removePendingAuth(state string)
  62. getPendingAuth(state string) (oauth2PendingAuth, error)
  63. cleanup()
  64. }
  65. type memoryOAuth2Manager struct {
  66. mu sync.RWMutex
  67. pendingAuths map[string]oauth2PendingAuth
  68. }
  69. func (o *memoryOAuth2Manager) addPendingAuth(pendingAuth oauth2PendingAuth) {
  70. o.mu.Lock()
  71. defer o.mu.Unlock()
  72. o.pendingAuths[pendingAuth.State] = pendingAuth
  73. }
  74. func (o *memoryOAuth2Manager) removePendingAuth(state string) {
  75. o.mu.Lock()
  76. defer o.mu.Unlock()
  77. delete(o.pendingAuths, state)
  78. }
  79. func (o *memoryOAuth2Manager) getPendingAuth(state string) (oauth2PendingAuth, error) {
  80. o.mu.RLock()
  81. defer o.mu.RUnlock()
  82. authReq, ok := o.pendingAuths[state]
  83. if !ok {
  84. return oauth2PendingAuth{}, errors.New("oauth2: no auth request found for the specified state")
  85. }
  86. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - authReq.IssuedAt
  87. if diff > authStateValidity {
  88. return oauth2PendingAuth{}, errors.New("oauth2: auth request is too old")
  89. }
  90. return authReq, nil
  91. }
  92. func (o *memoryOAuth2Manager) cleanup() {
  93. o.mu.Lock()
  94. defer o.mu.Unlock()
  95. for k, auth := range o.pendingAuths {
  96. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - auth.IssuedAt
  97. // remove old pending auth requests
  98. if diff < 0 || diff > authStateValidity {
  99. delete(o.pendingAuths, k)
  100. }
  101. }
  102. }
  // dbOAuth2Manager stores pending OAuth2 auth requests as data provider
  // shared sessions of type SessionTypeOAuth2Auth. It has no state of its own.
  // NOTE(review): presumably this lets several instances that use the same
  // provider see each other's requests — confirm.
  type dbOAuth2Manager struct{}
  104. func (o *dbOAuth2Manager) addPendingAuth(pendingAuth oauth2PendingAuth) {
  105. if err := pendingAuth.ClientSecret.Encrypt(); err != nil {
  106. logger.Error(logSender, "", "unable to encrypt oauth2 secret: %v", err)
  107. return
  108. }
  109. session := dataprovider.Session{
  110. Key: pendingAuth.State,
  111. Data: pendingAuth,
  112. Type: dataprovider.SessionTypeOAuth2Auth,
  113. Timestamp: pendingAuth.IssuedAt + authStateValidity,
  114. }
  115. dataprovider.AddSharedSession(session) //nolint:errcheck
  116. }
  117. func (o *dbOAuth2Manager) removePendingAuth(state string) {
  118. dataprovider.DeleteSharedSession(state) //nolint:errcheck
  119. }
  120. func (o *dbOAuth2Manager) getPendingAuth(state string) (oauth2PendingAuth, error) {
  121. session, err := dataprovider.GetSharedSession(state)
  122. if err != nil {
  123. return oauth2PendingAuth{}, errors.New("oauth2: unable to get the auth request for the specified state")
  124. }
  125. if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) {
  126. // expired
  127. return oauth2PendingAuth{}, errors.New("oauth2: auth request is too old")
  128. }
  129. return o.decodePendingAuthData(session.Data)
  130. }
  131. func (o *dbOAuth2Manager) decodePendingAuthData(data any) (oauth2PendingAuth, error) {
  132. if val, ok := data.([]byte); ok {
  133. authReq := oauth2PendingAuth{}
  134. err := json.Unmarshal(val, &authReq)
  135. if err != nil {
  136. return authReq, err
  137. }
  138. err = authReq.ClientSecret.TryDecrypt()
  139. return authReq, err
  140. }
  141. logger.Error(logSender, "", "invalid oauth2 auth request data type %T", data)
  142. return oauth2PendingAuth{}, errors.New("oauth2: invalid auth request data")
  143. }
  144. func (o *dbOAuth2Manager) cleanup() {
  145. dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOAuth2Auth, time.Now()) //nolint:errcheck
  146. }