1
0

oauth2.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. // Copyright (C) 2019-2023 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package httpd
  15. import (
  16. "encoding/json"
  17. "errors"
  18. "sync"
  19. "time"
  20. "github.com/rs/xid"
  21. "github.com/drakkan/sftpgo/v2/internal/dataprovider"
  22. "github.com/drakkan/sftpgo/v2/internal/kms"
  23. "github.com/drakkan/sftpgo/v2/internal/logger"
  24. "github.com/drakkan/sftpgo/v2/internal/util"
  25. )
  26. var (
  27. oauth2Mgr oauth2Manager
  28. )
  29. func newOAuth2Manager(isShared int) oauth2Manager {
  30. if isShared == 1 {
  31. logger.Info(logSender, "", "using provider OAuth2 manager")
  32. return &dbOAuth2Manager{}
  33. }
  34. logger.Info(logSender, "", "using memory OAuth2 manager")
  35. return &memoryOAuth2Manager{
  36. pendingAuths: make(map[string]oauth2PendingAuth),
  37. }
  38. }
  39. type oauth2PendingAuth struct {
  40. State string `json:"state"`
  41. Provider int `json:"provider"`
  42. ClientID string `json:"client_id"`
  43. ClientSecret *kms.Secret `json:"client_secret"`
  44. RedirectURL string `json:"redirect_url"`
  45. IssuedAt int64 `json:"issued_at"`
  46. }
  47. func newOAuth2PendingAuth(provider int, redirectURL, clientID string, clientSecret *kms.Secret) oauth2PendingAuth {
  48. return oauth2PendingAuth{
  49. State: xid.New().String(),
  50. Provider: provider,
  51. ClientID: clientID,
  52. ClientSecret: clientSecret,
  53. RedirectURL: redirectURL,
  54. IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
  55. }
  56. }
  57. type oauth2Manager interface {
  58. addPendingAuth(pendingAuth oauth2PendingAuth)
  59. removePendingAuth(state string)
  60. getPendingAuth(state string) (oauth2PendingAuth, error)
  61. cleanup()
  62. }
  63. type memoryOAuth2Manager struct {
  64. mu sync.RWMutex
  65. pendingAuths map[string]oauth2PendingAuth
  66. }
  67. func (o *memoryOAuth2Manager) addPendingAuth(pendingAuth oauth2PendingAuth) {
  68. o.mu.Lock()
  69. defer o.mu.Unlock()
  70. o.pendingAuths[pendingAuth.State] = pendingAuth
  71. }
  72. func (o *memoryOAuth2Manager) removePendingAuth(state string) {
  73. o.mu.Lock()
  74. defer o.mu.Unlock()
  75. delete(o.pendingAuths, state)
  76. }
  // getPendingAuth returns the request stored for state. It fails when no
// request exists, or when its age (now minus IssuedAt, in milliseconds)
// exceeds authStateValidity. Expired entries are not removed here; cleanup
// purges them.
func (o *memoryOAuth2Manager) getPendingAuth(state string) (oauth2PendingAuth, error) {
	// The map holds values, so the copy taken under the read lock can be
	// inspected after the lock is released.
	o.mu.RLock()
	authReq, found := o.pendingAuths[state]
	o.mu.RUnlock()

	switch {
	case !found:
		return oauth2PendingAuth{}, errors.New("oauth2: no auth request found for the specified state")
	case util.GetTimeAsMsSinceEpoch(time.Now())-authReq.IssuedAt > authStateValidity:
		return oauth2PendingAuth{}, errors.New("oauth2: auth request is too old")
	default:
		return authReq, nil
	}
}
  // cleanup deletes every pending request whose age is negative (IssuedAt
// later than now) or exceeds authStateValidity milliseconds.
func (o *memoryOAuth2Manager) cleanup() {
	logger.Debug(logSender, "", "oauth2 manager cleanup")
	o.mu.Lock()
	defer o.mu.Unlock()

	// Read the clock once per sweep instead of once per entry while the
	// write lock is held; every entry is judged against the same instant.
	now := util.GetTimeAsMsSinceEpoch(time.Now())
	for state, pending := range o.pendingAuths {
		age := now - pending.IssuedAt
		if age < 0 || age > authStateValidity {
			// Deleting the current key during a range loop is allowed in Go.
			delete(o.pendingAuths, state)
		}
	}
}
  // dbOAuth2Manager persists pending authorization requests as data provider
// shared sessions of type dataprovider.SessionTypeOAuth2Auth. It keeps no
// state of its own.
type dbOAuth2Manager struct {
}
  103. func (o *dbOAuth2Manager) addPendingAuth(pendingAuth oauth2PendingAuth) {
  104. if err := pendingAuth.ClientSecret.Encrypt(); err != nil {
  105. logger.Error(logSender, "", "unable to encrypt oauth2 secret: %v", err)
  106. return
  107. }
  108. session := dataprovider.Session{
  109. Key: pendingAuth.State,
  110. Data: pendingAuth,
  111. Type: dataprovider.SessionTypeOAuth2Auth,
  112. Timestamp: pendingAuth.IssuedAt + authStateValidity,
  113. }
  114. dataprovider.AddSharedSession(session) //nolint:errcheck
  115. }
  116. func (o *dbOAuth2Manager) removePendingAuth(state string) {
  117. dataprovider.DeleteSharedSession(state) //nolint:errcheck
  118. }
  // getPendingAuth loads the shared session stored for state and decodes it.
// It returns an error in three cases:
//   - the session cannot be loaded or is not an OAuth2 auth request;
//   - its expiry timestamp (milliseconds since the epoch) is in the past;
//   - decoding fails.
func (o *dbOAuth2Manager) getPendingAuth(state string) (oauth2PendingAuth, error) {
	session, err := dataprovider.GetSharedSession(state)
	if err != nil {
		logger.Debug(logSender, "", "unable to get oauth2 pending auth: %v", err)
		return oauth2PendingAuth{}, errors.New("oauth2: unable to get the auth request for the specified state")
	}
	// The lookup is by key only. Reject sessions of any other type instead of
	// trying to decode foreign data as an OAuth2 auth request.
	if session.Type != dataprovider.SessionTypeOAuth2Auth {
		logger.Debug(logSender, "", "unexpected session type %v for oauth2 state", session.Type)
		return oauth2PendingAuth{}, errors.New("oauth2: unable to get the auth request for the specified state")
	}
	if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) {
		// expired
		return oauth2PendingAuth{}, errors.New("oauth2: auth request is too old")
	}
	return o.decodePendingAuthData(session.Data)
}
  130. func (o *dbOAuth2Manager) decodePendingAuthData(data any) (oauth2PendingAuth, error) {
  131. if val, ok := data.([]byte); ok {
  132. authReq := oauth2PendingAuth{}
  133. err := json.Unmarshal(val, &authReq)
  134. if err != nil {
  135. return authReq, err
  136. }
  137. err = authReq.ClientSecret.TryDecrypt()
  138. return authReq, err
  139. }
  140. logger.Error(logSender, "", "invalid oauth2 auth request data type %T", data)
  141. return oauth2PendingAuth{}, errors.New("oauth2: invalid auth request data")
  142. }
  // cleanup asks the data provider to purge OAuth2 auth sessions, passing the
// current time. NOTE(review): presumably this removes sessions whose
// Timestamp (the expiry set by addPendingAuth) is before that time; confirm
// in dataprovider. Failures are logged rather than discarded.
func (o *dbOAuth2Manager) cleanup() {
	logger.Debug(logSender, "", "oauth2 manager cleanup")
	if err := dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOAuth2Auth, time.Now()); err != nil {
		logger.Warn(logSender, "", "unable to cleanup oauth2 pending auths: %v", err)
	}
}