oidcmanager.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. // Copyright (C) 2019-2023 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package httpd
  15. import (
  16. "encoding/json"
  17. "errors"
  18. "sync"
  19. "time"
  20. "github.com/drakkan/sftpgo/v2/internal/dataprovider"
  21. "github.com/drakkan/sftpgo/v2/internal/logger"
  22. "github.com/drakkan/sftpgo/v2/internal/util"
  23. )
  24. var (
  25. oidcMgr oidcManager
  26. )
  27. func newOIDCManager(isShared int) oidcManager {
  28. if isShared == 1 {
  29. logger.Info(logSender, "", "using provider OIDC manager")
  30. return &dbOIDCManager{}
  31. }
  32. logger.Info(logSender, "", "using memory OIDC manager")
  33. return &memoryOIDCManager{
  34. pendingAuths: make(map[string]oidcPendingAuth),
  35. tokens: make(map[string]oidcToken),
  36. }
  37. }
  38. type oidcManager interface {
  39. addPendingAuth(pendingAuth oidcPendingAuth)
  40. removePendingAuth(state string)
  41. getPendingAuth(state string) (oidcPendingAuth, error)
  42. addToken(token oidcToken)
  43. getToken(cookie string) (oidcToken, error)
  44. removeToken(cookie string)
  45. updateTokenUsage(token oidcToken)
  46. cleanup()
  47. }
// memoryOIDCManager keeps pending OIDC auth requests and OIDC tokens in
// process memory. Each map is guarded by its own RWMutex.
type memoryOIDCManager struct {
	// authMutex guards pendingAuths.
	authMutex sync.RWMutex
	// pendingAuths maps an auth request's State to the request.
	pendingAuths map[string]oidcPendingAuth
	// tokenMutex guards tokens.
	tokenMutex sync.RWMutex
	// tokens maps a token's Cookie to the token.
	tokens map[string]oidcToken
}
  54. func (o *memoryOIDCManager) addPendingAuth(pendingAuth oidcPendingAuth) {
  55. o.authMutex.Lock()
  56. o.pendingAuths[pendingAuth.State] = pendingAuth
  57. o.authMutex.Unlock()
  58. }
  59. func (o *memoryOIDCManager) removePendingAuth(state string) {
  60. o.authMutex.Lock()
  61. defer o.authMutex.Unlock()
  62. delete(o.pendingAuths, state)
  63. }
  64. func (o *memoryOIDCManager) getPendingAuth(state string) (oidcPendingAuth, error) {
  65. o.authMutex.RLock()
  66. defer o.authMutex.RUnlock()
  67. authReq, ok := o.pendingAuths[state]
  68. if !ok {
  69. return oidcPendingAuth{}, errors.New("oidc: no auth request found for the specified state")
  70. }
  71. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - authReq.IssuedAt
  72. if diff > authStateValidity {
  73. return oidcPendingAuth{}, errors.New("oidc: auth request is too old")
  74. }
  75. return authReq, nil
  76. }
  77. func (o *memoryOIDCManager) addToken(token oidcToken) {
  78. o.tokenMutex.Lock()
  79. token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now())
  80. o.tokens[token.Cookie] = token
  81. o.tokenMutex.Unlock()
  82. }
  83. func (o *memoryOIDCManager) getToken(cookie string) (oidcToken, error) {
  84. o.tokenMutex.RLock()
  85. defer o.tokenMutex.RUnlock()
  86. token, ok := o.tokens[cookie]
  87. if !ok {
  88. return oidcToken{}, errors.New("oidc: no token found for the specified session")
  89. }
  90. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt
  91. if diff > tokenDeleteInterval {
  92. return oidcToken{}, errors.New("oidc: token is too old")
  93. }
  94. return token, nil
  95. }
  96. func (o *memoryOIDCManager) removeToken(cookie string) {
  97. o.tokenMutex.Lock()
  98. defer o.tokenMutex.Unlock()
  99. delete(o.tokens, cookie)
  100. }
  101. func (o *memoryOIDCManager) updateTokenUsage(token oidcToken) {
  102. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt
  103. if diff > tokenUpdateInterval {
  104. o.addToken(token)
  105. }
  106. }
  107. func (o *memoryOIDCManager) cleanup() {
  108. o.cleanupAuthRequests()
  109. o.cleanupTokens()
  110. }
  111. func (o *memoryOIDCManager) cleanupAuthRequests() {
  112. o.authMutex.Lock()
  113. defer o.authMutex.Unlock()
  114. for k, auth := range o.pendingAuths {
  115. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - auth.IssuedAt
  116. // remove old pending auth requests
  117. if diff < 0 || diff > authStateValidity {
  118. delete(o.pendingAuths, k)
  119. }
  120. }
  121. }
  122. func (o *memoryOIDCManager) cleanupTokens() {
  123. o.tokenMutex.Lock()
  124. defer o.tokenMutex.Unlock()
  125. for k, token := range o.tokens {
  126. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt
  127. // remove tokens unused from more than tokenDeleteInterval
  128. if diff > tokenDeleteInterval {
  129. delete(o.tokens, k)
  130. }
  131. }
  132. }
// dbOIDCManager stores pending OIDC auth requests and OIDC tokens as data
// provider shared sessions. It holds no state of its own.
type dbOIDCManager struct{}
  134. func (o *dbOIDCManager) addPendingAuth(pendingAuth oidcPendingAuth) {
  135. session := dataprovider.Session{
  136. Key: pendingAuth.State,
  137. Data: pendingAuth,
  138. Type: dataprovider.SessionTypeOIDCAuth,
  139. Timestamp: pendingAuth.IssuedAt + authStateValidity,
  140. }
  141. dataprovider.AddSharedSession(session) //nolint:errcheck
  142. }
  143. func (o *dbOIDCManager) removePendingAuth(state string) {
  144. dataprovider.DeleteSharedSession(state) //nolint:errcheck
  145. }
  146. func (o *dbOIDCManager) getPendingAuth(state string) (oidcPendingAuth, error) {
  147. session, err := dataprovider.GetSharedSession(state)
  148. if err != nil {
  149. return oidcPendingAuth{}, errors.New("oidc: unable to get the auth request for the specified state")
  150. }
  151. if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) {
  152. // expired
  153. return oidcPendingAuth{}, errors.New("oidc: auth request is too old")
  154. }
  155. return o.decodePendingAuthData(session.Data)
  156. }
  157. func (o *dbOIDCManager) decodePendingAuthData(data any) (oidcPendingAuth, error) {
  158. if val, ok := data.([]byte); ok {
  159. authReq := oidcPendingAuth{}
  160. err := json.Unmarshal(val, &authReq)
  161. return authReq, err
  162. }
  163. logger.Error(logSender, "", "invalid oidc auth request data type %T", data)
  164. return oidcPendingAuth{}, errors.New("oidc: invalid auth request data")
  165. }
  166. func (o *dbOIDCManager) addToken(token oidcToken) {
  167. token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now())
  168. session := dataprovider.Session{
  169. Key: token.Cookie,
  170. Data: token,
  171. Type: dataprovider.SessionTypeOIDCToken,
  172. Timestamp: token.UsedAt + tokenDeleteInterval,
  173. }
  174. dataprovider.AddSharedSession(session) //nolint:errcheck
  175. }
  176. func (o *dbOIDCManager) removeToken(cookie string) {
  177. dataprovider.DeleteSharedSession(cookie) //nolint:errcheck
  178. }
  179. func (o *dbOIDCManager) updateTokenUsage(token oidcToken) {
  180. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt
  181. if diff > tokenUpdateInterval {
  182. o.addToken(token)
  183. }
  184. }
  185. func (o *dbOIDCManager) getToken(cookie string) (oidcToken, error) {
  186. session, err := dataprovider.GetSharedSession(cookie)
  187. if err != nil {
  188. return oidcToken{}, errors.New("oidc: unable to get the token for the specified session")
  189. }
  190. if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) {
  191. // expired
  192. return oidcToken{}, errors.New("oidc: token is too old")
  193. }
  194. return o.decodeTokenData(session.Data)
  195. }
  196. func (o *dbOIDCManager) decodeTokenData(data any) (oidcToken, error) {
  197. if val, ok := data.([]byte); ok {
  198. token := oidcToken{}
  199. err := json.Unmarshal(val, &token)
  200. return token, err
  201. }
  202. logger.Error(logSender, "", "invalid oidc token data type %T", data)
  203. return oidcToken{}, errors.New("oidc: invalid token data")
  204. }
  205. func (o *dbOIDCManager) cleanup() {
  206. dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCAuth, time.Now()) //nolint:errcheck
  207. dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCToken, time.Now()) //nolint:errcheck
  208. }