123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244 |
- // Copyright (C) 2019-2023 Nicola Murino
- //
- // This program is free software: you can redistribute it and/or modify
- // it under the terms of the GNU Affero General Public License as published
- // by the Free Software Foundation, version 3.
- //
- // This program is distributed in the hope that it will be useful,
- // but WITHOUT ANY WARRANTY; without even the implied warranty of
- // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- // GNU Affero General Public License for more details.
- //
- // You should have received a copy of the GNU Affero General Public License
- // along with this program. If not, see <https://www.gnu.org/licenses/>.
- package httpd
- import (
- "encoding/json"
- "errors"
- "sync"
- "time"
- "github.com/drakkan/sftpgo/v2/internal/dataprovider"
- "github.com/drakkan/sftpgo/v2/internal/logger"
- "github.com/drakkan/sftpgo/v2/internal/util"
- )
- var (
- oidcMgr oidcManager
- )
- func newOIDCManager(isShared int) oidcManager {
- if isShared == 1 {
- logger.Info(logSender, "", "using provider OIDC manager")
- return &dbOIDCManager{}
- }
- logger.Info(logSender, "", "using memory OIDC manager")
- return &memoryOIDCManager{
- pendingAuths: make(map[string]oidcPendingAuth),
- tokens: make(map[string]oidcToken),
- }
- }
- type oidcManager interface {
- addPendingAuth(pendingAuth oidcPendingAuth)
- removePendingAuth(state string)
- getPendingAuth(state string) (oidcPendingAuth, error)
- addToken(token oidcToken)
- getToken(cookie string) (oidcToken, error)
- removeToken(cookie string)
- updateTokenUsage(token oidcToken)
- cleanup()
- }
- type memoryOIDCManager struct {
- authMutex sync.RWMutex
- pendingAuths map[string]oidcPendingAuth
- tokenMutex sync.RWMutex
- tokens map[string]oidcToken
- }
- func (o *memoryOIDCManager) addPendingAuth(pendingAuth oidcPendingAuth) {
- o.authMutex.Lock()
- o.pendingAuths[pendingAuth.State] = pendingAuth
- o.authMutex.Unlock()
- }
- func (o *memoryOIDCManager) removePendingAuth(state string) {
- o.authMutex.Lock()
- defer o.authMutex.Unlock()
- delete(o.pendingAuths, state)
- }
- func (o *memoryOIDCManager) getPendingAuth(state string) (oidcPendingAuth, error) {
- o.authMutex.RLock()
- defer o.authMutex.RUnlock()
- authReq, ok := o.pendingAuths[state]
- if !ok {
- return oidcPendingAuth{}, errors.New("oidc: no auth request found for the specified state")
- }
- diff := util.GetTimeAsMsSinceEpoch(time.Now()) - authReq.IssuedAt
- if diff > authStateValidity {
- return oidcPendingAuth{}, errors.New("oidc: auth request is too old")
- }
- return authReq, nil
- }
- func (o *memoryOIDCManager) addToken(token oidcToken) {
- o.tokenMutex.Lock()
- token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now())
- o.tokens[token.Cookie] = token
- o.tokenMutex.Unlock()
- }
- func (o *memoryOIDCManager) getToken(cookie string) (oidcToken, error) {
- o.tokenMutex.RLock()
- defer o.tokenMutex.RUnlock()
- token, ok := o.tokens[cookie]
- if !ok {
- return oidcToken{}, errors.New("oidc: no token found for the specified session")
- }
- diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt
- if diff > tokenDeleteInterval {
- return oidcToken{}, errors.New("oidc: token is too old")
- }
- return token, nil
- }
- func (o *memoryOIDCManager) removeToken(cookie string) {
- o.tokenMutex.Lock()
- defer o.tokenMutex.Unlock()
- delete(o.tokens, cookie)
- }
- func (o *memoryOIDCManager) updateTokenUsage(token oidcToken) {
- diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt
- if diff > tokenUpdateInterval {
- o.addToken(token)
- }
- }
- func (o *memoryOIDCManager) cleanup() {
- logger.Debug(logSender, "", "oidc manager cleanup")
- o.cleanupAuthRequests()
- o.cleanupTokens()
- }
- func (o *memoryOIDCManager) cleanupAuthRequests() {
- o.authMutex.Lock()
- defer o.authMutex.Unlock()
- for k, auth := range o.pendingAuths {
- diff := util.GetTimeAsMsSinceEpoch(time.Now()) - auth.IssuedAt
- // remove old pending auth requests
- if diff < 0 || diff > authStateValidity {
- delete(o.pendingAuths, k)
- }
- }
- }
- func (o *memoryOIDCManager) cleanupTokens() {
- o.tokenMutex.Lock()
- defer o.tokenMutex.Unlock()
- for k, token := range o.tokens {
- diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt
- // remove tokens unused from more than tokenDeleteInterval
- if diff > tokenDeleteInterval {
- delete(o.tokens, k)
- }
- }
- }
// dbOIDCManager is an oidcManager that keeps no state of its own: pending
// auth requests and tokens are stored as shared sessions through the
// dataprovider package.
type dbOIDCManager struct{}
- func (o *dbOIDCManager) addPendingAuth(pendingAuth oidcPendingAuth) {
- session := dataprovider.Session{
- Key: pendingAuth.State,
- Data: pendingAuth,
- Type: dataprovider.SessionTypeOIDCAuth,
- Timestamp: pendingAuth.IssuedAt + authStateValidity,
- }
- dataprovider.AddSharedSession(session) //nolint:errcheck
- }
- func (o *dbOIDCManager) removePendingAuth(state string) {
- dataprovider.DeleteSharedSession(state) //nolint:errcheck
- }
- func (o *dbOIDCManager) getPendingAuth(state string) (oidcPendingAuth, error) {
- session, err := dataprovider.GetSharedSession(state)
- if err != nil {
- return oidcPendingAuth{}, errors.New("oidc: unable to get the auth request for the specified state")
- }
- if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) {
- // expired
- return oidcPendingAuth{}, errors.New("oidc: auth request is too old")
- }
- return o.decodePendingAuthData(session.Data)
- }
- func (o *dbOIDCManager) decodePendingAuthData(data any) (oidcPendingAuth, error) {
- if val, ok := data.([]byte); ok {
- authReq := oidcPendingAuth{}
- err := json.Unmarshal(val, &authReq)
- return authReq, err
- }
- logger.Error(logSender, "", "invalid oidc auth request data type %T", data)
- return oidcPendingAuth{}, errors.New("oidc: invalid auth request data")
- }
- func (o *dbOIDCManager) addToken(token oidcToken) {
- token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now())
- session := dataprovider.Session{
- Key: token.Cookie,
- Data: token,
- Type: dataprovider.SessionTypeOIDCToken,
- Timestamp: token.UsedAt + tokenDeleteInterval,
- }
- dataprovider.AddSharedSession(session) //nolint:errcheck
- }
- func (o *dbOIDCManager) removeToken(cookie string) {
- dataprovider.DeleteSharedSession(cookie) //nolint:errcheck
- }
- func (o *dbOIDCManager) updateTokenUsage(token oidcToken) {
- diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt
- if diff > tokenUpdateInterval {
- o.addToken(token)
- }
- }
- func (o *dbOIDCManager) getToken(cookie string) (oidcToken, error) {
- session, err := dataprovider.GetSharedSession(cookie)
- if err != nil {
- return oidcToken{}, errors.New("oidc: unable to get the token for the specified session")
- }
- if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) {
- // expired
- return oidcToken{}, errors.New("oidc: token is too old")
- }
- return o.decodeTokenData(session.Data)
- }
- func (o *dbOIDCManager) decodeTokenData(data any) (oidcToken, error) {
- if val, ok := data.([]byte); ok {
- token := oidcToken{}
- err := json.Unmarshal(val, &token)
- return token, err
- }
- logger.Error(logSender, "", "invalid oidc token data type %T", data)
- return oidcToken{}, errors.New("oidc: invalid token data")
- }
- func (o *dbOIDCManager) cleanup() {
- logger.Debug(logSender, "", "oidc manager cleanup")
- dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCAuth, time.Now()) //nolint:errcheck
- dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCToken, time.Now()) //nolint:errcheck
- }
|