oidcmanager.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. // Copyright (C) 2019-2022 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package httpd
  15. import (
  16. "encoding/json"
  17. "errors"
  18. "sync"
  19. "time"
  20. "github.com/drakkan/sftpgo/v2/internal/dataprovider"
  21. "github.com/drakkan/sftpgo/v2/internal/logger"
  22. "github.com/drakkan/sftpgo/v2/internal/util"
  23. )
  24. var (
  25. oidcMgr oidcManager
  26. )
  27. func newOIDCManager(isShared int) oidcManager {
  28. if isShared == 1 {
  29. logger.Info(logSender, "", "using provider OIDC manager")
  30. return &dbOIDCManager{}
  31. }
  32. logger.Info(logSender, "", "using memory OIDC manager")
  33. return &memoryOIDCManager{
  34. pendingAuths: make(map[string]oidcPendingAuth),
  35. tokens: make(map[string]oidcToken),
  36. }
  37. }
  38. type oidcManager interface {
  39. addPendingAuth(pendingAuth oidcPendingAuth)
  40. removePendingAuth(state string)
  41. getPendingAuth(state string) (oidcPendingAuth, error)
  42. addToken(token oidcToken)
  43. getToken(cookie string) (oidcToken, error)
  44. removeToken(cookie string)
  45. updateTokenUsage(token oidcToken)
  46. cleanup()
  47. }
  48. type memoryOIDCManager struct {
  49. authMutex sync.RWMutex
  50. pendingAuths map[string]oidcPendingAuth
  51. tokenMutex sync.RWMutex
  52. tokens map[string]oidcToken
  53. }
  54. func (o *memoryOIDCManager) addPendingAuth(pendingAuth oidcPendingAuth) {
  55. o.authMutex.Lock()
  56. o.pendingAuths[pendingAuth.State] = pendingAuth
  57. o.authMutex.Unlock()
  58. }
  59. func (o *memoryOIDCManager) removePendingAuth(state string) {
  60. o.authMutex.Lock()
  61. defer o.authMutex.Unlock()
  62. delete(o.pendingAuths, state)
  63. }
  64. func (o *memoryOIDCManager) getPendingAuth(state string) (oidcPendingAuth, error) {
  65. o.authMutex.RLock()
  66. defer o.authMutex.RUnlock()
  67. authReq, ok := o.pendingAuths[state]
  68. if !ok {
  69. return oidcPendingAuth{}, errors.New("oidc: no auth request found for the specified state")
  70. }
  71. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - authReq.IssuedAt
  72. if diff > authStateValidity {
  73. return oidcPendingAuth{}, errors.New("oidc: auth request is too old")
  74. }
  75. return authReq, nil
  76. }
  77. func (o *memoryOIDCManager) addToken(token oidcToken) {
  78. o.tokenMutex.Lock()
  79. token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now())
  80. o.tokens[token.Cookie] = token
  81. o.tokenMutex.Unlock()
  82. }
  83. func (o *memoryOIDCManager) getToken(cookie string) (oidcToken, error) {
  84. o.tokenMutex.RLock()
  85. defer o.tokenMutex.RUnlock()
  86. token, ok := o.tokens[cookie]
  87. if !ok {
  88. return oidcToken{}, errors.New("oidc: no token found for the specified session")
  89. }
  90. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt
  91. if diff > tokenDeleteInterval {
  92. return oidcToken{}, errors.New("oidc: token is too old")
  93. }
  94. return token, nil
  95. }
  96. func (o *memoryOIDCManager) removeToken(cookie string) {
  97. o.tokenMutex.Lock()
  98. defer o.tokenMutex.Unlock()
  99. delete(o.tokens, cookie)
  100. }
  101. func (o *memoryOIDCManager) updateTokenUsage(token oidcToken) {
  102. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt
  103. if diff > tokenUpdateInterval {
  104. o.addToken(token)
  105. }
  106. }
  107. func (o *memoryOIDCManager) cleanup() {
  108. logger.Debug(logSender, "", "oidc manager cleanup")
  109. o.cleanupAuthRequests()
  110. o.cleanupTokens()
  111. }
  112. func (o *memoryOIDCManager) cleanupAuthRequests() {
  113. o.authMutex.Lock()
  114. defer o.authMutex.Unlock()
  115. for k, auth := range o.pendingAuths {
  116. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - auth.IssuedAt
  117. // remove old pending auth requests
  118. if diff < 0 || diff > authStateValidity {
  119. delete(o.pendingAuths, k)
  120. }
  121. }
  122. }
  123. func (o *memoryOIDCManager) cleanupTokens() {
  124. o.tokenMutex.Lock()
  125. defer o.tokenMutex.Unlock()
  126. for k, token := range o.tokens {
  127. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt
  128. // remove tokens unused from more than tokenDeleteInterval
  129. if diff > tokenDeleteInterval {
  130. delete(o.tokens, k)
  131. }
  132. }
  133. }
  134. type dbOIDCManager struct{}
  135. func (o *dbOIDCManager) addPendingAuth(pendingAuth oidcPendingAuth) {
  136. session := dataprovider.Session{
  137. Key: pendingAuth.State,
  138. Data: pendingAuth,
  139. Type: dataprovider.SessionTypeOIDCAuth,
  140. Timestamp: pendingAuth.IssuedAt + authStateValidity,
  141. }
  142. dataprovider.AddSharedSession(session) //nolint:errcheck
  143. }
  144. func (o *dbOIDCManager) removePendingAuth(state string) {
  145. dataprovider.DeleteSharedSession(state) //nolint:errcheck
  146. }
  147. func (o *dbOIDCManager) getPendingAuth(state string) (oidcPendingAuth, error) {
  148. session, err := dataprovider.GetSharedSession(state)
  149. if err != nil {
  150. return oidcPendingAuth{}, errors.New("oidc: unable to get the auth request for the specified state")
  151. }
  152. if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) {
  153. // expired
  154. return oidcPendingAuth{}, errors.New("oidc: auth request is too old")
  155. }
  156. return o.decodePendingAuthData(session.Data)
  157. }
  158. func (o *dbOIDCManager) decodePendingAuthData(data any) (oidcPendingAuth, error) {
  159. if val, ok := data.([]byte); ok {
  160. authReq := oidcPendingAuth{}
  161. err := json.Unmarshal(val, &authReq)
  162. return authReq, err
  163. }
  164. logger.Error(logSender, "", "invalid oidc auth request data type %T", data)
  165. return oidcPendingAuth{}, errors.New("oidc: invalid auth request data")
  166. }
  167. func (o *dbOIDCManager) addToken(token oidcToken) {
  168. token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now())
  169. session := dataprovider.Session{
  170. Key: token.Cookie,
  171. Data: token,
  172. Type: dataprovider.SessionTypeOIDCToken,
  173. Timestamp: token.UsedAt + tokenDeleteInterval,
  174. }
  175. dataprovider.AddSharedSession(session) //nolint:errcheck
  176. }
  177. func (o *dbOIDCManager) removeToken(cookie string) {
  178. dataprovider.DeleteSharedSession(cookie) //nolint:errcheck
  179. }
  180. func (o *dbOIDCManager) updateTokenUsage(token oidcToken) {
  181. diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt
  182. if diff > tokenUpdateInterval {
  183. o.addToken(token)
  184. }
  185. }
  186. func (o *dbOIDCManager) getToken(cookie string) (oidcToken, error) {
  187. session, err := dataprovider.GetSharedSession(cookie)
  188. if err != nil {
  189. return oidcToken{}, errors.New("oidc: unable to get the token for the specified session")
  190. }
  191. if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) {
  192. // expired
  193. return oidcToken{}, errors.New("oidc: token is too old")
  194. }
  195. return o.decodeTokenData(session.Data)
  196. }
  197. func (o *dbOIDCManager) decodeTokenData(data any) (oidcToken, error) {
  198. if val, ok := data.([]byte); ok {
  199. token := oidcToken{}
  200. err := json.Unmarshal(val, &token)
  201. return token, err
  202. }
  203. logger.Error(logSender, "", "invalid oidc token data type %T", data)
  204. return oidcToken{}, errors.New("oidc: invalid token data")
  205. }
  206. func (o *dbOIDCManager) cleanup() {
  207. logger.Debug(logSender, "", "oidc manager cleanup")
  208. dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCAuth, time.Now()) //nolint:errcheck
  209. dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCToken, time.Now()) //nolint:errcheck
  210. }