oauth2.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. // Copyright (C) 2019-2023 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package httpd
  15. import (
  16. "encoding/json"
  17. "errors"
  18. "sync"
  19. "time"
  20. "github.com/rs/xid"
  21. "github.com/drakkan/sftpgo/v2/internal/dataprovider"
  22. "github.com/drakkan/sftpgo/v2/internal/kms"
  23. "github.com/drakkan/sftpgo/v2/internal/logger"
  24. "github.com/drakkan/sftpgo/v2/internal/util"
  25. )
  26. var (
  27. oauth2Mgr oauth2Manager
  28. )
  29. func newOAuth2Manager(isShared int) oauth2Manager {
  30. if isShared == 1 {
  31. logger.Info(logSender, "", "using provider OAuth2 manager")
  32. return &dbOAuth2Manager{}
  33. }
  34. logger.Info(logSender, "", "using memory OAuth2 manager")
  35. return &memoryOAuth2Manager{
  36. pendingAuths: make(map[string]oauth2PendingAuth),
  37. }
  38. }
  39. type oauth2PendingAuth struct {
  40. State string `json:"state"`
  41. Provider int `json:"provider"`
  42. ClientID string `json:"client_id"`
  43. ClientSecret *kms.Secret `json:"client_secret"`
  44. RedirectURL string `json:"redirect_url"`
  45. IssuedAt int64 `json:"issued_at"`
  46. }
  47. func newOAuth2PendingAuth(provider int, redirectURL, clientID string, clientSecret *kms.Secret) oauth2PendingAuth {
  48. return oauth2PendingAuth{
  49. State: xid.New().String(),
  50. Provider: provider,
  51. ClientID: clientID,
  52. ClientSecret: clientSecret,
  53. RedirectURL: redirectURL,
  54. IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
  55. }
  56. }
  57. type oauth2Manager interface {
  58. addPendingAuth(pendingAuth oauth2PendingAuth)
  59. removePendingAuth(state string)
  60. getPendingAuth(state string) (oauth2PendingAuth, error)
  61. cleanup()
  62. }
  63. type memoryOAuth2Manager struct {
  64. mu sync.RWMutex
  65. pendingAuths map[string]oauth2PendingAuth
  66. }
  67. func (o *memoryOAuth2Manager) addPendingAuth(pendingAuth oauth2PendingAuth) {
  68. o.mu.Lock()
  69. defer o.mu.Unlock()
  70. o.pendingAuths[pendingAuth.State] = pendingAuth
  71. }
  72. func (o *memoryOAuth2Manager) removePendingAuth(state string) {
  73. o.mu.Lock()
  74. defer o.mu.Unlock()
  75. delete(o.pendingAuths, state)
  76. }
  77. func (o *memoryOAuth2Manager) getPendingAuth(state string) (oauth2PendingAuth, error) {
  78. o.mu.RLock()
  79. defer o.mu.RUnlock()
  80. authReq, ok := o.pendingAuths[state]
  81. if !ok {
  82. return oauth2PendingAuth{}, errors.New("oauth2: no auth request found for the specified state")
  83. }
  84. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - authReq.IssuedAt
  85. if diff > authStateValidity {
  86. return oauth2PendingAuth{}, errors.New("oauth2: auth request is too old")
  87. }
  88. return authReq, nil
  89. }
  90. func (o *memoryOAuth2Manager) cleanup() {
  91. o.mu.Lock()
  92. defer o.mu.Unlock()
  93. for k, auth := range o.pendingAuths {
  94. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - auth.IssuedAt
  95. // remove old pending auth requests
  96. if diff < 0 || diff > authStateValidity {
  97. delete(o.pendingAuths, k)
  98. }
  99. }
  100. }
  101. type dbOAuth2Manager struct{}
  102. func (o *dbOAuth2Manager) addPendingAuth(pendingAuth oauth2PendingAuth) {
  103. if err := pendingAuth.ClientSecret.Encrypt(); err != nil {
  104. logger.Error(logSender, "", "unable to encrypt oauth2 secret: %v", err)
  105. return
  106. }
  107. session := dataprovider.Session{
  108. Key: pendingAuth.State,
  109. Data: pendingAuth,
  110. Type: dataprovider.SessionTypeOAuth2Auth,
  111. Timestamp: pendingAuth.IssuedAt + authStateValidity,
  112. }
  113. dataprovider.AddSharedSession(session) //nolint:errcheck
  114. }
  115. func (o *dbOAuth2Manager) removePendingAuth(state string) {
  116. dataprovider.DeleteSharedSession(state) //nolint:errcheck
  117. }
  118. func (o *dbOAuth2Manager) getPendingAuth(state string) (oauth2PendingAuth, error) {
  119. session, err := dataprovider.GetSharedSession(state)
  120. if err != nil {
  121. return oauth2PendingAuth{}, errors.New("oauth2: unable to get the auth request for the specified state")
  122. }
  123. if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) {
  124. // expired
  125. return oauth2PendingAuth{}, errors.New("oauth2: auth request is too old")
  126. }
  127. return o.decodePendingAuthData(session.Data)
  128. }
  129. func (o *dbOAuth2Manager) decodePendingAuthData(data any) (oauth2PendingAuth, error) {
  130. if val, ok := data.([]byte); ok {
  131. authReq := oauth2PendingAuth{}
  132. err := json.Unmarshal(val, &authReq)
  133. if err != nil {
  134. return authReq, err
  135. }
  136. err = authReq.ClientSecret.TryDecrypt()
  137. return authReq, err
  138. }
  139. logger.Error(logSender, "", "invalid oauth2 auth request data type %T", data)
  140. return oauth2PendingAuth{}, errors.New("oauth2: invalid auth request data")
  141. }
  142. func (o *dbOAuth2Manager) cleanup() {
  143. dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOAuth2Auth, time.Now()) //nolint:errcheck
  144. }