oidc_test.go 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776
  1. // Copyright (C) 2019 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package httpd
  15. import (
  16. "bytes"
  17. "context"
  18. "encoding/json"
  19. "fmt"
  20. "io/fs"
  21. "net/http"
  22. "net/http/httptest"
  23. "net/url"
  24. "os"
  25. "path/filepath"
  26. "reflect"
  27. "runtime"
  28. "testing"
  29. "time"
  30. "unsafe"
  31. "github.com/coreos/go-oidc/v3/oidc"
  32. "github.com/go-chi/jwtauth/v5"
  33. "github.com/rs/xid"
  34. "github.com/sftpgo/sdk"
  35. "github.com/stretchr/testify/assert"
  36. "github.com/stretchr/testify/require"
  37. "golang.org/x/oauth2"
  38. "github.com/drakkan/sftpgo/v2/internal/common"
  39. "github.com/drakkan/sftpgo/v2/internal/dataprovider"
  40. "github.com/drakkan/sftpgo/v2/internal/kms"
  41. "github.com/drakkan/sftpgo/v2/internal/util"
  42. "github.com/drakkan/sftpgo/v2/internal/vfs"
  43. )
  44. const (
  45. oidcMockAddr = "127.0.0.1:11111"
  46. )
  47. type mockTokenSource struct {
  48. token *oauth2.Token
  49. err error
  50. }
  51. func (t *mockTokenSource) Token() (*oauth2.Token, error) {
  52. return t.token, t.err
  53. }
  54. type mockOAuth2Config struct {
  55. tokenSource *mockTokenSource
  56. authCodeURL string
  57. token *oauth2.Token
  58. err error
  59. }
  60. func (c *mockOAuth2Config) AuthCodeURL(_ string, _ ...oauth2.AuthCodeOption) string {
  61. return c.authCodeURL
  62. }
  63. func (c *mockOAuth2Config) Exchange(_ context.Context, _ string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
  64. return c.token, c.err
  65. }
  66. func (c *mockOAuth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
  67. return c.tokenSource
  68. }
  69. type mockOIDCVerifier struct {
  70. token *oidc.IDToken
  71. err error
  72. }
  73. func (v *mockOIDCVerifier) Verify(_ context.Context, _ string) (*oidc.IDToken, error) {
  74. return v.token, v.err
  75. }
  76. // hack because the field is unexported
  77. func setIDTokenClaims(idToken *oidc.IDToken, claims []byte) {
  78. pointerVal := reflect.ValueOf(idToken)
  79. val := reflect.Indirect(pointerVal)
  80. member := val.FieldByName("claims")
  81. ptr := unsafe.Pointer(member.UnsafeAddr())
  82. realPtr := (*[]byte)(ptr)
  83. *realPtr = claims
  84. }
  85. func TestOIDCInitialization(t *testing.T) {
  86. config := OIDC{}
  87. err := config.initialize()
  88. assert.NoError(t, err)
  89. secret := "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c"
  90. config = OIDC{
  91. ClientID: "sftpgo-client",
  92. ClientSecret: util.GenerateUniqueID(),
  93. ConfigURL: fmt.Sprintf("http://%v/", oidcMockAddr),
  94. RedirectBaseURL: "http://127.0.0.1:8081/",
  95. UsernameField: "preferred_username",
  96. RoleField: "sftpgo_role",
  97. }
  98. err = config.initialize()
  99. if assert.Error(t, err) {
  100. assert.Contains(t, err.Error(), "oidc: required scope \"openid\" is not set")
  101. }
  102. config.Scopes = []string{oidc.ScopeOpenID}
  103. config.ClientSecretFile = "missing file"
  104. err = config.initialize()
  105. assert.ErrorIs(t, err, fs.ErrNotExist)
  106. secretFile := filepath.Join(os.TempDir(), util.GenerateUniqueID())
  107. defer os.Remove(secretFile)
  108. err = os.WriteFile(secretFile, []byte(secret), 0600)
  109. assert.NoError(t, err)
  110. config.ClientSecretFile = secretFile
  111. err = config.initialize()
  112. if assert.Error(t, err) {
  113. assert.Contains(t, err.Error(), "oidc: unable to initialize provider")
  114. }
  115. assert.Equal(t, secret, config.ClientSecret)
  116. config.ConfigURL = fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr)
  117. err = config.initialize()
  118. assert.NoError(t, err)
  119. assert.Equal(t, "http://127.0.0.1:8081"+webOIDCRedirectPath, config.getRedirectURL())
  120. }
  121. func TestOIDCLoginLogout(t *testing.T) {
  122. tokenValidationMode = 2
  123. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  124. require.True(t, ok)
  125. server := getTestOIDCServer()
  126. err := server.binding.OIDC.initialize()
  127. assert.NoError(t, err)
  128. server.initializeRouter()
  129. rr := httptest.NewRecorder()
  130. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath, nil)
  131. assert.NoError(t, err)
  132. server.router.ServeHTTP(rr, r)
  133. assert.Equal(t, http.StatusBadRequest, rr.Code)
  134. assert.Contains(t, rr.Body.String(), util.I18nInvalidAuth)
  135. expiredAuthReq := oidcPendingAuth{
  136. State: util.GenerateOpaqueString(),
  137. Nonce: util.GenerateOpaqueString(),
  138. Audience: tokenAudienceWebClient,
  139. IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
  140. }
  141. oidcMgr.addPendingAuth(expiredAuthReq)
  142. rr = httptest.NewRecorder()
  143. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+expiredAuthReq.State, nil)
  144. assert.NoError(t, err)
  145. server.router.ServeHTTP(rr, r)
  146. assert.Equal(t, http.StatusBadRequest, rr.Code)
  147. assert.Contains(t, rr.Body.String(), util.I18nInvalidAuth)
  148. oidcMgr.removePendingAuth(expiredAuthReq.State)
  149. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  150. tokenSource: &mockTokenSource{},
  151. authCodeURL: webOIDCRedirectPath,
  152. err: common.ErrGenericFailure,
  153. }
  154. server.binding.OIDC.verifier = &mockOIDCVerifier{
  155. err: common.ErrGenericFailure,
  156. }
  157. rr = httptest.NewRecorder()
  158. r, err = http.NewRequest(http.MethodGet, webAdminOIDCLoginPath, nil)
  159. assert.NoError(t, err)
  160. server.router.ServeHTTP(rr, r)
  161. assert.Equal(t, http.StatusFound, rr.Code)
  162. assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
  163. require.Len(t, oidcMgr.pendingAuths, 1)
  164. var state string
  165. for k := range oidcMgr.pendingAuths {
  166. state = k
  167. }
  168. rr = httptest.NewRecorder()
  169. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
  170. assert.NoError(t, err)
  171. server.router.ServeHTTP(rr, r)
  172. assert.Equal(t, http.StatusFound, rr.Code)
  173. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  174. require.Len(t, oidcMgr.pendingAuths, 0)
  175. rr = httptest.NewRecorder()
  176. r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
  177. assert.NoError(t, err)
  178. server.router.ServeHTTP(rr, r)
  179. assert.Equal(t, http.StatusOK, rr.Code)
  180. // now the same for the web client
  181. rr = httptest.NewRecorder()
  182. r, err = http.NewRequest(http.MethodGet, webClientOIDCLoginPath, nil)
  183. assert.NoError(t, err)
  184. server.router.ServeHTTP(rr, r)
  185. assert.Equal(t, http.StatusFound, rr.Code)
  186. assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
  187. require.Len(t, oidcMgr.pendingAuths, 1)
  188. for k := range oidcMgr.pendingAuths {
  189. state = k
  190. }
  191. rr = httptest.NewRecorder()
  192. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
  193. assert.NoError(t, err)
  194. server.router.ServeHTTP(rr, r)
  195. assert.Equal(t, http.StatusFound, rr.Code)
  196. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  197. require.Len(t, oidcMgr.pendingAuths, 0)
  198. rr = httptest.NewRecorder()
  199. r, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil)
  200. assert.NoError(t, err)
  201. server.router.ServeHTTP(rr, r)
  202. assert.Equal(t, http.StatusOK, rr.Code)
  203. // now return an OAuth2 token without the id_token
  204. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  205. tokenSource: &mockTokenSource{},
  206. authCodeURL: webOIDCRedirectPath,
  207. token: &oauth2.Token{
  208. AccessToken: "123",
  209. Expiry: time.Now().Add(5 * time.Minute),
  210. },
  211. err: nil,
  212. }
  213. authReq := newOIDCPendingAuth(tokenAudienceWebClient)
  214. oidcMgr.addPendingAuth(authReq)
  215. rr = httptest.NewRecorder()
  216. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  217. assert.NoError(t, err)
  218. server.router.ServeHTTP(rr, r)
  219. assert.Equal(t, http.StatusFound, rr.Code)
  220. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  221. require.Len(t, oidcMgr.pendingAuths, 0)
  222. // now fail to verify the id token
  223. token := &oauth2.Token{
  224. AccessToken: "123",
  225. Expiry: time.Now().Add(5 * time.Minute),
  226. }
  227. token = token.WithExtra(map[string]any{
  228. "id_token": "id_token_val",
  229. })
  230. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  231. tokenSource: &mockTokenSource{},
  232. authCodeURL: webOIDCRedirectPath,
  233. token: token,
  234. err: nil,
  235. }
  236. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  237. oidcMgr.addPendingAuth(authReq)
  238. rr = httptest.NewRecorder()
  239. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  240. assert.NoError(t, err)
  241. server.router.ServeHTTP(rr, r)
  242. assert.Equal(t, http.StatusFound, rr.Code)
  243. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  244. require.Len(t, oidcMgr.pendingAuths, 0)
  245. // id token nonce does not match
  246. server.binding.OIDC.verifier = &mockOIDCVerifier{
  247. err: nil,
  248. token: &oidc.IDToken{},
  249. }
  250. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  251. oidcMgr.addPendingAuth(authReq)
  252. rr = httptest.NewRecorder()
  253. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  254. assert.NoError(t, err)
  255. server.router.ServeHTTP(rr, r)
  256. assert.Equal(t, http.StatusFound, rr.Code)
  257. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  258. require.Len(t, oidcMgr.pendingAuths, 0)
  259. // null id token claims
  260. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  261. oidcMgr.addPendingAuth(authReq)
  262. server.binding.OIDC.verifier = &mockOIDCVerifier{
  263. err: nil,
  264. token: &oidc.IDToken{
  265. Nonce: authReq.Nonce,
  266. },
  267. }
  268. rr = httptest.NewRecorder()
  269. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  270. assert.NoError(t, err)
  271. server.router.ServeHTTP(rr, r)
  272. assert.Equal(t, http.StatusFound, rr.Code)
  273. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  274. require.Len(t, oidcMgr.pendingAuths, 0)
  275. // invalid id token claims: no username
  276. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  277. oidcMgr.addPendingAuth(authReq)
  278. idToken := &oidc.IDToken{
  279. Nonce: authReq.Nonce,
  280. Expiry: time.Now().Add(5 * time.Minute),
  281. }
  282. setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id"}`))
  283. server.binding.OIDC.verifier = &mockOIDCVerifier{
  284. err: nil,
  285. token: idToken,
  286. }
  287. rr = httptest.NewRecorder()
  288. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  289. assert.NoError(t, err)
  290. server.router.ServeHTTP(rr, r)
  291. assert.Equal(t, http.StatusFound, rr.Code)
  292. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  293. require.Len(t, oidcMgr.pendingAuths, 0)
  294. // invalid id token clamims: username not a string
  295. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  296. oidcMgr.addPendingAuth(authReq)
  297. idToken = &oidc.IDToken{
  298. Nonce: authReq.Nonce,
  299. Expiry: time.Now().Add(5 * time.Minute),
  300. }
  301. setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id","preferred_username": 1}`))
  302. server.binding.OIDC.verifier = &mockOIDCVerifier{
  303. err: nil,
  304. token: idToken,
  305. }
  306. rr = httptest.NewRecorder()
  307. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  308. assert.NoError(t, err)
  309. server.router.ServeHTTP(rr, r)
  310. assert.Equal(t, http.StatusFound, rr.Code)
  311. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  312. require.Len(t, oidcMgr.pendingAuths, 0)
  313. // invalid audience
  314. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  315. oidcMgr.addPendingAuth(authReq)
  316. idToken = &oidc.IDToken{
  317. Nonce: authReq.Nonce,
  318. Expiry: time.Now().Add(5 * time.Minute),
  319. }
  320. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
  321. server.binding.OIDC.verifier = &mockOIDCVerifier{
  322. err: nil,
  323. token: idToken,
  324. }
  325. rr = httptest.NewRecorder()
  326. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  327. assert.NoError(t, err)
  328. server.router.ServeHTTP(rr, r)
  329. assert.Equal(t, http.StatusFound, rr.Code)
  330. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  331. require.Len(t, oidcMgr.pendingAuths, 0)
  332. // invalid audience
  333. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  334. oidcMgr.addPendingAuth(authReq)
  335. idToken = &oidc.IDToken{
  336. Nonce: authReq.Nonce,
  337. Expiry: time.Now().Add(5 * time.Minute),
  338. }
  339. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test"}`))
  340. server.binding.OIDC.verifier = &mockOIDCVerifier{
  341. err: nil,
  342. token: idToken,
  343. }
  344. rr = httptest.NewRecorder()
  345. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  346. assert.NoError(t, err)
  347. server.router.ServeHTTP(rr, r)
  348. assert.Equal(t, http.StatusFound, rr.Code)
  349. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  350. require.Len(t, oidcMgr.pendingAuths, 0)
  351. // mapped user not found
  352. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  353. oidcMgr.addPendingAuth(authReq)
  354. idToken = &oidc.IDToken{
  355. Nonce: authReq.Nonce,
  356. Expiry: time.Now().Add(5 * time.Minute),
  357. }
  358. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
  359. server.binding.OIDC.verifier = &mockOIDCVerifier{
  360. err: nil,
  361. token: idToken,
  362. }
  363. rr = httptest.NewRecorder()
  364. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  365. assert.NoError(t, err)
  366. server.router.ServeHTTP(rr, r)
  367. assert.Equal(t, http.StatusFound, rr.Code)
  368. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  369. require.Len(t, oidcMgr.pendingAuths, 0)
  370. // admin login ok
  371. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  372. oidcMgr.addPendingAuth(authReq)
  373. idToken = &oidc.IDToken{
  374. Nonce: authReq.Nonce,
  375. Expiry: time.Now().Add(5 * time.Minute),
  376. }
  377. setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sftpgo_role":"admin","sid":"sid123"}`))
  378. server.binding.OIDC.verifier = &mockOIDCVerifier{
  379. err: nil,
  380. token: idToken,
  381. }
  382. rr = httptest.NewRecorder()
  383. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  384. assert.NoError(t, err)
  385. server.router.ServeHTTP(rr, r)
  386. assert.Equal(t, http.StatusFound, rr.Code)
  387. assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
  388. require.Len(t, oidcMgr.pendingAuths, 0)
  389. require.Len(t, oidcMgr.tokens, 1)
  390. // admin profile is not available
  391. var tokenCookie string
  392. for k := range oidcMgr.tokens {
  393. tokenCookie = k
  394. }
  395. oidcToken, err := oidcMgr.getToken(tokenCookie)
  396. assert.NoError(t, err)
  397. assert.Equal(t, "sid123", oidcToken.SessionID)
  398. assert.True(t, oidcToken.isAdmin())
  399. assert.False(t, oidcToken.isExpired())
  400. rr = httptest.NewRecorder()
  401. r, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil)
  402. assert.NoError(t, err)
  403. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  404. server.router.ServeHTTP(rr, r)
  405. assert.Equal(t, http.StatusForbidden, rr.Code)
  406. // the admin can access the allowed pages
  407. rr = httptest.NewRecorder()
  408. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  409. assert.NoError(t, err)
  410. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  411. server.router.ServeHTTP(rr, r)
  412. assert.Equal(t, http.StatusOK, rr.Code)
  413. // try with an invalid cookie
  414. rr = httptest.NewRecorder()
  415. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  416. assert.NoError(t, err)
  417. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
  418. server.router.ServeHTTP(rr, r)
  419. assert.Equal(t, http.StatusFound, rr.Code)
  420. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  421. // Web Client is not available with an admin token
  422. rr = httptest.NewRecorder()
  423. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  424. assert.NoError(t, err)
  425. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  426. server.router.ServeHTTP(rr, r)
  427. assert.Equal(t, http.StatusFound, rr.Code)
  428. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  429. // logout the admin user
  430. rr = httptest.NewRecorder()
  431. r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
  432. assert.NoError(t, err)
  433. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  434. server.router.ServeHTTP(rr, r)
  435. assert.Equal(t, http.StatusFound, rr.Code)
  436. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  437. require.Len(t, oidcMgr.pendingAuths, 0)
  438. require.Len(t, oidcMgr.tokens, 0)
  439. // now login and logout a user
  440. username := "test_oidc_user"
  441. user := dataprovider.User{
  442. BaseUser: sdk.BaseUser{
  443. Username: username,
  444. Password: "pwd",
  445. HomeDir: filepath.Join(os.TempDir(), username),
  446. Status: 1,
  447. Permissions: map[string][]string{
  448. "/": {dataprovider.PermAny},
  449. },
  450. },
  451. Filters: dataprovider.UserFilters{
  452. BaseUserFilters: sdk.BaseUserFilters{
  453. WebClient: []string{sdk.WebClientSharesDisabled},
  454. },
  455. },
  456. }
  457. err = dataprovider.AddUser(&user, "", "", "")
  458. assert.NoError(t, err)
  459. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  460. oidcMgr.addPendingAuth(authReq)
  461. idToken = &oidc.IDToken{
  462. Nonce: authReq.Nonce,
  463. Expiry: time.Now().Add(5 * time.Minute),
  464. }
  465. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test_oidc_user"}`))
  466. server.binding.OIDC.verifier = &mockOIDCVerifier{
  467. err: nil,
  468. token: idToken,
  469. }
  470. rr = httptest.NewRecorder()
  471. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  472. assert.NoError(t, err)
  473. server.router.ServeHTTP(rr, r)
  474. assert.Equal(t, http.StatusFound, rr.Code)
  475. assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
  476. require.Len(t, oidcMgr.pendingAuths, 0)
  477. require.Len(t, oidcMgr.tokens, 1)
  478. // user profile is not available
  479. for k := range oidcMgr.tokens {
  480. tokenCookie = k
  481. }
  482. oidcToken, err = oidcMgr.getToken(tokenCookie)
  483. assert.NoError(t, err)
  484. assert.Empty(t, oidcToken.SessionID)
  485. assert.False(t, oidcToken.isAdmin())
  486. assert.False(t, oidcToken.isExpired())
  487. if assert.Len(t, oidcToken.Permissions, 1) {
  488. assert.Equal(t, sdk.WebClientSharesDisabled, oidcToken.Permissions[0])
  489. }
  490. rr = httptest.NewRecorder()
  491. r, err = http.NewRequest(http.MethodGet, webClientProfilePath, nil)
  492. assert.NoError(t, err)
  493. r.RequestURI = webClientProfilePath
  494. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  495. server.router.ServeHTTP(rr, r)
  496. assert.Equal(t, http.StatusOK, rr.Code)
  497. // the user can access the allowed pages
  498. rr = httptest.NewRecorder()
  499. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  500. assert.NoError(t, err)
  501. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  502. server.router.ServeHTTP(rr, r)
  503. assert.Equal(t, http.StatusOK, rr.Code)
  504. // try with an invalid cookie
  505. rr = httptest.NewRecorder()
  506. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  507. assert.NoError(t, err)
  508. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
  509. server.router.ServeHTTP(rr, r)
  510. assert.Equal(t, http.StatusFound, rr.Code)
  511. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  512. // Web Admin is not available with a client cookie
  513. rr = httptest.NewRecorder()
  514. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  515. assert.NoError(t, err)
  516. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  517. server.router.ServeHTTP(rr, r)
  518. assert.Equal(t, http.StatusFound, rr.Code)
  519. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  520. // logout the user
  521. rr = httptest.NewRecorder()
  522. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  523. assert.NoError(t, err)
  524. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  525. server.router.ServeHTTP(rr, r)
  526. assert.Equal(t, http.StatusFound, rr.Code)
  527. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  528. require.Len(t, oidcMgr.pendingAuths, 0)
  529. require.Len(t, oidcMgr.tokens, 0)
  530. err = os.RemoveAll(user.GetHomeDir())
  531. assert.NoError(t, err)
  532. err = dataprovider.DeleteUser(username, "", "", "")
  533. assert.NoError(t, err)
  534. tokenValidationMode = 0
  535. }
  536. func TestOIDCRefreshToken(t *testing.T) {
  537. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  538. require.True(t, ok)
  539. r, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
  540. assert.NoError(t, err)
  541. token := oidcToken{
  542. Cookie: util.GenerateOpaqueString(),
  543. AccessToken: xid.New().String(),
  544. TokenType: "Bearer",
  545. ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-1 * time.Minute)),
  546. Nonce: xid.New().String(),
  547. Role: adminRoleFieldValue,
  548. Username: defaultAdminUsername,
  549. }
  550. config := mockOAuth2Config{
  551. tokenSource: &mockTokenSource{
  552. err: common.ErrGenericFailure,
  553. },
  554. }
  555. verifier := mockOIDCVerifier{
  556. err: common.ErrGenericFailure,
  557. }
  558. err = token.refresh(context.Background(), &config, &verifier, r)
  559. if assert.Error(t, err) {
  560. assert.Contains(t, err.Error(), "refresh token not set")
  561. }
  562. token.RefreshToken = xid.New().String()
  563. err = token.refresh(context.Background(), &config, &verifier, r)
  564. assert.ErrorIs(t, err, common.ErrGenericFailure)
  565. newToken := &oauth2.Token{
  566. AccessToken: xid.New().String(),
  567. RefreshToken: xid.New().String(),
  568. Expiry: time.Now().Add(5 * time.Minute),
  569. }
  570. config = mockOAuth2Config{
  571. tokenSource: &mockTokenSource{
  572. token: newToken,
  573. },
  574. }
  575. verifier = mockOIDCVerifier{
  576. token: &oidc.IDToken{},
  577. }
  578. err = token.refresh(context.Background(), &config, &verifier, r)
  579. if assert.Error(t, err) {
  580. assert.Contains(t, err.Error(), "the refreshed token has no id token")
  581. }
  582. newToken = newToken.WithExtra(map[string]any{
  583. "id_token": "id_token_val",
  584. })
  585. newToken.Expiry = time.Time{}
  586. config = mockOAuth2Config{
  587. tokenSource: &mockTokenSource{
  588. token: newToken,
  589. },
  590. }
  591. verifier = mockOIDCVerifier{
  592. err: common.ErrGenericFailure,
  593. }
  594. err = token.refresh(context.Background(), &config, &verifier, r)
  595. assert.ErrorIs(t, err, common.ErrGenericFailure)
  596. newToken = newToken.WithExtra(map[string]any{
  597. "id_token": "id_token_val",
  598. })
  599. newToken.Expiry = time.Now().Add(5 * time.Minute)
  600. config = mockOAuth2Config{
  601. tokenSource: &mockTokenSource{
  602. token: newToken,
  603. },
  604. }
  605. verifier = mockOIDCVerifier{
  606. token: &oidc.IDToken{
  607. Nonce: xid.New().String(), // nonce is different from the expected one
  608. },
  609. }
  610. err = token.refresh(context.Background(), &config, &verifier, r)
  611. if assert.Error(t, err) {
  612. assert.Contains(t, err.Error(), "the refreshed token nonce mismatch")
  613. }
  614. verifier = mockOIDCVerifier{
  615. token: &oidc.IDToken{
  616. Nonce: "", // empty token is fine on refresh but claims are not set
  617. },
  618. }
  619. err = token.refresh(context.Background(), &config, &verifier, r)
  620. if assert.Error(t, err) {
  621. assert.Contains(t, err.Error(), "oidc: claims not set")
  622. }
  623. idToken := &oidc.IDToken{
  624. Nonce: token.Nonce,
  625. }
  626. setIDTokenClaims(idToken, []byte(`{"sid":"id_token_sid"}`))
  627. verifier = mockOIDCVerifier{
  628. token: idToken,
  629. }
  630. err = token.refresh(context.Background(), &config, &verifier, r)
  631. assert.NoError(t, err)
  632. assert.Len(t, token.Permissions, 1)
  633. token.Role = nil
  634. // user does not exist
  635. err = token.refresh(context.Background(), &config, &verifier, r)
  636. assert.Error(t, err)
  637. require.Len(t, oidcMgr.tokens, 1)
  638. oidcMgr.removeToken(token.Cookie)
  639. require.Len(t, oidcMgr.tokens, 0)
  640. }
// TestOIDCRefreshUser exercises oidcToken.refreshUser against the data
// provider. It covers, in order:
//   - a token naming a non-existent account fails;
//   - a disabled admin is rejected;
//   - an enabled admin copies its Permissions and HideUserPageSections
//     preference into the token;
//   - a disabled user is rejected;
//   - a user denied the HTTP protocol is rejected;
//   - an allowed user copies its WebClient filters into token.Permissions.
func TestOIDCRefreshUser(t *testing.T) {
	token := oidcToken{
		Cookie:      util.GenerateOpaqueString(),
		AccessToken: xid.New().String(),
		TokenType:   "Bearer",
		ExpiresAt:   util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Minute)),
		Nonce:       xid.New().String(),
		Role:        adminRoleFieldValue,
		Username:    "missing username",
	}
	r, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	// no account with this username exists yet
	err = token.refreshUser(r)
	assert.Error(t, err)
	admin := dataprovider.Admin{
		Username:    "test_oidc_admin_refresh",
		Password:    "p",
		Permissions: []string{dataprovider.PermAdminAny},
		Status:      0,
		Filters: dataprovider.AdminFilters{
			Preferences: dataprovider.AdminPreferences{
				HideUserPageSections: 1 + 2 + 4,
			},
		},
	}
	err = dataprovider.AddAdmin(&admin, "", "", "")
	assert.NoError(t, err)
	token.Username = admin.Username
	// the admin exists but was created with Status 0: refresh must report it as disabled
	err = token.refreshUser(r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is disabled")
	}
	admin.Status = 1
	err = dataprovider.UpdateAdmin(&admin, "", "", "")
	assert.NoError(t, err)
	// once enabled, the admin's permissions and page preferences are copied into the token
	err = token.refreshUser(r)
	assert.NoError(t, err)
	assert.Equal(t, admin.Permissions, token.Permissions)
	assert.Equal(t, admin.Filters.Preferences.HideUserPageSections, token.HideUserPageSections)
	err = dataprovider.DeleteAdmin(admin.Username, "", "", "")
	assert.NoError(t, err)
	// second part: the same token, switched to a regular (non-admin) user
	username := "test_oidc_user_refresh_token"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Password: "p",
			HomeDir:  filepath.Join(os.TempDir(), username),
			Status:   0,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
		Filters: dataprovider.UserFilters{
			BaseUserFilters: sdk.BaseUserFilters{
				DeniedProtocols: []string{common.ProtocolHTTP},
				WebClient:       []string{sdk.WebClientSharesDisabled, sdk.WebClientWriteDisabled},
			},
		},
	}
	err = dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)
	r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	// clearing the role makes the token a non-admin one (asserted just below)
	token.Role = nil
	token.Username = username
	assert.False(t, token.isAdmin())
	// the user was created with Status 0
	err = token.refreshUser(r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is disabled")
	}
	user, err = dataprovider.UserExists(username, "")
	assert.NoError(t, err)
	user.Status = 1
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	// enabled, but HTTP is still in the denied protocols
	err = token.refreshUser(r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "protocol HTTP is not allowed")
	}
	user.Filters.DeniedProtocols = []string{common.ProtocolFTP}
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	// now allowed: for users, token.Permissions mirrors the WebClient filters
	err = token.refreshUser(r)
	assert.NoError(t, err)
	assert.Equal(t, user.Filters.WebClient, token.Permissions)
	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
}
// TestValidateOIDCToken covers the failure paths of validateOIDCToken and of
// the router when an OIDC cookie cannot be turned into a session:
//   - a request without an OIDC cookie gets errInvalidToken;
//   - an expired token whose refresh fails (the mocked token source returns an
//     error) gets errInvalidToken;
//   - with a replaced server.tokenAuth, both a client token and an admin token
//     are redirected to their respective login pages on logout.
//
// Every token added to the in-memory manager is removed again.
func TestValidateOIDCToken(t *testing.T) {
	// shadow the package-level manager with its concrete type to reach its internal maps
	oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
	require.True(t, ok)
	server := getTestOIDCServer()
	err := server.binding.OIDC.initialize()
	assert.NoError(t, err)
	server.initializeRouter()
	// no OIDC cookie in the request
	rr := httptest.NewRecorder()
	r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
	assert.NoError(t, err)
	_, err = server.validateOIDCToken(rr, r, false)
	assert.ErrorIs(t, err, errInvalidToken)
	// expired token and refresh error
	server.binding.OIDC.oauth2Config = &mockOAuth2Config{
		tokenSource: &mockTokenSource{
			err: common.ErrGenericFailure,
		},
	}
	token := oidcToken{
		Cookie:      util.GenerateOpaqueString(),
		AccessToken: xid.New().String(),
		ExpiresAt:   util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
	}
	oidcMgr.addToken(token)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
	_, err = server.validateOIDCToken(rr, r, false)
	assert.ErrorIs(t, err, errInvalidToken)
	oidcMgr.removeToken(token.Cookie)
	assert.Len(t, oidcMgr.tokens, 0)
	// NOTE(review): a PS256 signer built from 32 random bytes (PS256 normally
	// needs an RSA key) presumably cannot sign the session JWT. The request is
	// then expected to end on the client login page — confirm against the handler.
	server.tokenAuth = jwtauth.New("PS256", util.GenerateRandomBytes(32), nil)
	token = oidcToken{
		Cookie:      util.GenerateOpaqueString(),
		AccessToken: util.GenerateUniqueID(),
	}
	oidcMgr.addToken(token)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	oidcMgr.removeToken(token.Cookie)
	assert.Len(t, oidcMgr.tokens, 0)
	// same replaced signer, admin-role token on the admin logout path
	token = oidcToken{
		Cookie:      util.GenerateOpaqueString(),
		AccessToken: xid.New().String(),
		Role:        "admin",
	}
	oidcMgr.addToken(token)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
	oidcMgr.removeToken(token.Cookie)
	assert.Len(t, oidcMgr.tokens, 0)
}
  792. func TestSkipOIDCAuth(t *testing.T) {
  793. server := getTestOIDCServer()
  794. err := server.binding.OIDC.initialize()
  795. assert.NoError(t, err)
  796. server.initializeRouter()
  797. jwtTokenClaims := jwtTokenClaims{
  798. Username: "user",
  799. }
  800. _, tokenString, err := jwtTokenClaims.createToken(server.tokenAuth, tokenAudienceWebClient, "")
  801. assert.NoError(t, err)
  802. rr := httptest.NewRecorder()
  803. r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  804. assert.NoError(t, err)
  805. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", jwtCookieKey, tokenString))
  806. server.router.ServeHTTP(rr, r)
  807. assert.Equal(t, http.StatusFound, rr.Code)
  808. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  809. }
  810. func TestOIDCLogoutErrors(t *testing.T) {
  811. server := getTestOIDCServer()
  812. assert.Empty(t, server.binding.OIDC.providerLogoutURL)
  813. server.logoutFromOIDCOP("")
  814. server.binding.OIDC.providerLogoutURL = "http://foo\x7f.com/"
  815. server.doOIDCFromLogout("")
  816. server.binding.OIDC.providerLogoutURL = "http://127.0.0.1:11234"
  817. server.doOIDCFromLogout("")
  818. }
// TestOIDCToken exercises oidcToken.getUser for both admins and users.
//
// Expected rejections, in order:
//   - no role set, and no user with that username: util.ErrNotFound;
//   - admin role, but the admin is disabled;
//   - unknown user;
//   - disabled user;
//   - user denied the HTTP protocol;
//   - user whose SFTP backend points back to itself ("SFTP loop");
//   - a failing post-connect hook ("access denied").
//
// The global post-connect hook, the home dir and the user are reset at the end.
func TestOIDCToken(t *testing.T) {
	admin := dataprovider.Admin{
		Username:    "test_oidc_admin",
		Password:    "p",
		Permissions: []string{dataprovider.PermAdminAny},
		Status:      0,
	}
	err := dataprovider.AddAdmin(&admin, "", "", "")
	assert.NoError(t, err)
	token := oidcToken{
		Username: admin.Username,
	}
	// role not initialized, user with the specified username does not exist
	req, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	err = token.getUser(req)
	assert.ErrorIs(t, err, util.ErrNotFound)
	// admin role: the admin exists but was created with Status 0
	token.Role = "admin"
	req, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	err = token.getUser(req)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is disabled")
	}
	err = dataprovider.DeleteAdmin(admin.Username, "", "", "")
	assert.NoError(t, err)
	// switch to a (not yet existing) user; the same request is reused below
	username := "test_oidc_user"
	token.Username = username
	token.Role = ""
	err = token.getUser(req)
	if assert.Error(t, err) {
		assert.ErrorIs(t, err, util.ErrNotFound)
	}
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Password: "p",
			HomeDir:  filepath.Join(os.TempDir(), username),
			Status:   0,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
		Filters: dataprovider.UserFilters{
			BaseUserFilters: sdk.BaseUserFilters{
				DeniedProtocols: []string{common.ProtocolHTTP},
			},
		},
	}
	err = dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)
	// the user was created with Status 0
	err = token.getUser(req)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is disabled")
	}
	user, err = dataprovider.UserExists(username, "")
	assert.NoError(t, err)
	user.Status = 1
	user.Password = "np"
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	// enabled, but HTTP is still denied
	err = token.getUser(req)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "protocol HTTP is not allowed")
	}
	// an SFTP backend on 127.0.0.1:8022 using the same username is expected
	// to be rejected as an "SFTP loop"
	user.Filters.DeniedProtocols = nil
	user.FsConfig.Provider = sdk.SFTPFilesystemProvider
	user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
		BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
			Endpoint: "127.0.0.1:8022",
			Username: username,
		},
		Password: kms.NewPlainSecret("np"),
	}
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	err = token.getUser(req)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "SFTP loop")
	}
	// global post-connect hook pointing to a /404 path on the mock server;
	// presumably a non-success response, which must produce "access denied"
	common.Config.PostConnectHook = fmt.Sprintf("http://%v/404", oidcMockAddr)
	err = token.getUser(req)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "access denied")
	}
	// restore the global hook for the other tests
	common.Config.PostConnectHook = ""
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
}
// TestOIDCImplicitRoles exercises OIDC logins with ImplicitRoles enabled. The
// ID tokens below carry no role claim, so whether a login is for the web admin
// or the web client follows the audience of the pending auth request.
//
// An admin login redirects to the users page, and its token is not accepted by
// the web client. A user login redirects to the client files page. Logging out
// removes the stored OIDC token in both cases.
func TestOIDCImplicitRoles(t *testing.T) {
	// shadow the package-level manager with its concrete type to inspect its maps
	oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
	require.True(t, ok)
	server := getTestOIDCServer()
	server.binding.OIDC.ImplicitRoles = true
	err := server.binding.OIDC.initialize()
	assert.NoError(t, err)
	server.initializeRouter()
	// admin login: pending auth created for the web admin audience
	authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
	oidcMgr.addPendingAuth(authReq)
	token := &oauth2.Token{
		AccessToken: "1234",
		Expiry:      time.Now().Add(5 * time.Minute),
	}
	token = token.WithExtra(map[string]any{
		"id_token": "id_token_val",
	})
	server.binding.OIDC.oauth2Config = &mockOAuth2Config{
		tokenSource: &mockTokenSource{},
		authCodeURL: webOIDCRedirectPath,
		token:       token,
	}
	// the nonce must match the pending auth request
	idToken := &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sid":"sid456"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr := httptest.NewRecorder()
	r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
	// the pending auth is consumed and exactly one OIDC token is stored
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 1)
	var tokenCookie string
	for k := range oidcMgr.tokens {
		tokenCookie = k
	}
	// Web Client is not available with an admin token
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	// logout the admin user
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 0)
	// now login and logout a user
	username := "test_oidc_implicit_user"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Password: "pwd",
			HomeDir:  filepath.Join(os.TempDir(), username),
			Status:   1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
		Filters: dataprovider.UserFilters{
			BaseUserFilters: sdk.BaseUserFilters{
				WebClient: []string{sdk.WebClientSharesDisabled},
			},
		},
	}
	err = dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)
	// user login: pending auth created for the web client audience
	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
	oidcMgr.addPendingAuth(authReq)
	idToken = &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"test_oidc_implicit_user"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 1)
	for k := range oidcMgr.tokens {
		tokenCookie = k
	}
	// client logout removes the stored OIDC token
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 0)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
}
// TestMemoryOIDCManager unit-tests the in-memory OIDC manager.
//
// Pending auths are stored, retrieved and removed. One issued 61 seconds ago
// is refused as "too old" and dropped by cleanup.
//
// Tokens behave as follows:
//   - getToken returns a copy stamped with a non-zero UsedAt;
//   - getToken itself does not move UsedAt;
//   - updateTokenUsage skips a recently used token but refreshes one used
//     5 minutes ago;
//   - a token unused for longer than tokenDeleteInterval is refused;
//   - cleanup drops a token unused for 6 hours while keeping a fresh one.
func TestMemoryOIDCManager(t *testing.T) {
	// shadow the package-level manager with its concrete type to reach its internal maps
	oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
	require.True(t, ok)
	require.Len(t, oidcMgr.pendingAuths, 0)
	authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
	oidcMgr.addPendingAuth(authReq)
	require.Len(t, oidcMgr.pendingAuths, 1)
	_, err := oidcMgr.getPendingAuth(authReq.State)
	assert.NoError(t, err)
	oidcMgr.removePendingAuth(authReq.State)
	require.Len(t, oidcMgr.pendingAuths, 0)
	// a pending auth issued 61 seconds ago is expected to be refused as too old
	authReq.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-61 * time.Second))
	oidcMgr.addPendingAuth(authReq)
	require.Len(t, oidcMgr.pendingAuths, 1)
	_, err = oidcMgr.getPendingAuth(authReq.State)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "too old")
	}
	// cleanup drops the expired pending auth
	oidcMgr.cleanup()
	require.Len(t, oidcMgr.pendingAuths, 0)
	token := oidcToken{
		AccessToken: xid.New().String(),
		Nonce:       xid.New().String(),
		SessionID:   xid.New().String(),
		Cookie:      util.GenerateOpaqueString(),
		Username:    xid.New().String(),
		Role:        "admin",
		Permissions: []string{dataprovider.PermAdminAny},
	}
	require.Len(t, oidcMgr.tokens, 0)
	oidcMgr.addToken(token)
	require.Len(t, oidcMgr.tokens, 1)
	// unknown cookie
	_, err = oidcMgr.getToken(xid.New().String())
	assert.Error(t, err)
	storedToken, err := oidcMgr.getToken(token.Cookie)
	assert.NoError(t, err)
	// the local copy keeps UsedAt 0 while the stored copy has been stamped:
	// the manager must not share state with the caller's value
	token.UsedAt = 0 // ensure we don't modify the stored token
	assert.Greater(t, storedToken.UsedAt, int64(0))
	token.UsedAt = storedToken.UsedAt
	assert.Equal(t, token, storedToken)
	// the usage will not be updated, it is recent
	oidcMgr.updateTokenUsage(storedToken)
	storedToken, err = oidcMgr.getToken(token.Cookie)
	assert.NoError(t, err)
	assert.Equal(t, token, storedToken)
	// backdate UsedAt by 5 minutes directly in the map; getToken must return it unchanged
	usedAt := util.GetTimeAsMsSinceEpoch(time.Now().Add(-5 * time.Minute))
	storedToken.UsedAt = usedAt
	oidcMgr.tokens[token.Cookie] = storedToken
	storedToken, err = oidcMgr.getToken(token.Cookie)
	assert.NoError(t, err)
	assert.Equal(t, usedAt, storedToken.UsedAt)
	token.UsedAt = storedToken.UsedAt
	assert.Equal(t, token, storedToken)
	// a 5 minutes old usage is refreshed by updateTokenUsage
	oidcMgr.updateTokenUsage(storedToken)
	storedToken, err = oidcMgr.getToken(token.Cookie)
	assert.NoError(t, err)
	assert.Greater(t, storedToken.UsedAt, usedAt)
	token.UsedAt = storedToken.UsedAt
	assert.Equal(t, token, storedToken)
	// unused for just over tokenDeleteInterval: getToken refuses it
	storedToken.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) - tokenDeleteInterval - 1
	oidcMgr.tokens[token.Cookie] = storedToken
	storedToken, err = oidcMgr.getToken(token.Cookie)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "token is too old")
	}
	// removing an unknown cookie is a no-op
	oidcMgr.removeToken(xid.New().String())
	require.Len(t, oidcMgr.tokens, 1)
	oidcMgr.removeToken(token.Cookie)
	require.Len(t, oidcMgr.tokens, 0)
	// cleanup: a token used 6 hours ago is dropped, a freshly added one is kept
	oidcMgr.addToken(token)
	usedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-6 * time.Hour))
	token.UsedAt = usedAt
	oidcMgr.tokens[token.Cookie] = token
	newToken := oidcToken{
		Cookie: util.GenerateOpaqueString(),
	}
	oidcMgr.addToken(newToken)
	oidcMgr.cleanup()
	require.Len(t, oidcMgr.tokens, 1)
	_, err = oidcMgr.getToken(token.Cookie)
	assert.Error(t, err)
	_, err = oidcMgr.getToken(newToken.Cookie)
	assert.NoError(t, err)
	oidcMgr.removeToken(newToken.Cookie)
	require.Len(t, oidcMgr.tokens, 0)
}
  1113. func TestOIDCEvMgrIntegration(t *testing.T) {
  1114. providerConf := dataprovider.GetProviderConfig()
  1115. err := dataprovider.Close()
  1116. assert.NoError(t, err)
  1117. newProviderConf := providerConf
  1118. newProviderConf.NamingRules = 5
  1119. err = dataprovider.Initialize(newProviderConf, configDir, true)
  1120. assert.NoError(t, err)
  1121. // add a special chars to check json replacer
  1122. username := `test_"oidc_eventmanager`
  1123. u := map[string]any{
  1124. "username": "{{Name}}",
  1125. "status": 1,
  1126. "home_dir": filepath.Join(os.TempDir(), "{{IDPFieldcustom1.sub}}"),
  1127. "permissions": map[string][]string{
  1128. "/": {dataprovider.PermAny},
  1129. },
  1130. "description": "{{IDPFieldcustom2}}",
  1131. }
  1132. userTmpl, err := json.Marshal(u)
  1133. require.NoError(t, err)
  1134. a := map[string]any{
  1135. "username": "{{Name}}",
  1136. "status": 1,
  1137. "permissions": []string{dataprovider.PermAdminAny},
  1138. }
  1139. adminTmpl, err := json.Marshal(a)
  1140. require.NoError(t, err)
  1141. action := &dataprovider.BaseEventAction{
  1142. Name: "a",
  1143. Type: dataprovider.ActionTypeIDPAccountCheck,
  1144. Options: dataprovider.BaseEventActionOptions{
  1145. IDPConfig: dataprovider.EventActionIDPAccountCheck{
  1146. Mode: 0,
  1147. TemplateUser: string(userTmpl),
  1148. TemplateAdmin: string(adminTmpl),
  1149. },
  1150. },
  1151. }
  1152. err = dataprovider.AddEventAction(action, "", "", "")
  1153. assert.NoError(t, err)
  1154. rule := &dataprovider.EventRule{
  1155. Name: "r",
  1156. Status: 1,
  1157. Trigger: dataprovider.EventTriggerIDPLogin,
  1158. Conditions: dataprovider.EventConditions{
  1159. IDPLoginEvent: 0,
  1160. },
  1161. Actions: []dataprovider.EventAction{
  1162. {
  1163. BaseEventAction: dataprovider.BaseEventAction{
  1164. Name: action.Name,
  1165. },
  1166. Options: dataprovider.EventActionOptions{
  1167. ExecuteSync: true,
  1168. },
  1169. },
  1170. },
  1171. }
  1172. err = dataprovider.AddEventRule(rule, "", "", "")
  1173. assert.NoError(t, err)
  1174. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  1175. require.True(t, ok)
  1176. server := getTestOIDCServer()
  1177. server.binding.OIDC.ImplicitRoles = true
  1178. server.binding.OIDC.CustomFields = []string{"custom1.sub", "custom2"}
  1179. err = server.binding.OIDC.initialize()
  1180. assert.NoError(t, err)
  1181. server.initializeRouter()
  1182. // login a user with OIDC
  1183. _, err = dataprovider.UserExists(username, "")
  1184. assert.ErrorIs(t, err, util.ErrNotFound)
  1185. authReq := newOIDCPendingAuth(tokenAudienceWebClient)
  1186. oidcMgr.addPendingAuth(authReq)
  1187. token := &oauth2.Token{
  1188. AccessToken: "1234",
  1189. Expiry: time.Now().Add(5 * time.Minute),
  1190. }
  1191. token = token.WithExtra(map[string]any{
  1192. "id_token": "id_token_val",
  1193. })
  1194. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  1195. tokenSource: &mockTokenSource{},
  1196. authCodeURL: webOIDCRedirectPath,
  1197. token: token,
  1198. }
  1199. idToken := &oidc.IDToken{
  1200. Nonce: authReq.Nonce,
  1201. Expiry: time.Now().Add(5 * time.Minute),
  1202. }
  1203. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`","custom1":{"sub":"val1"},"custom2":"desc"}`)) //nolint:goconst
  1204. server.binding.OIDC.verifier = &mockOIDCVerifier{
  1205. err: nil,
  1206. token: idToken,
  1207. }
  1208. rr := httptest.NewRecorder()
  1209. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  1210. assert.NoError(t, err)
  1211. server.router.ServeHTTP(rr, r)
  1212. assert.Equal(t, http.StatusFound, rr.Code)
  1213. assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
  1214. user, err := dataprovider.UserExists(username, "")
  1215. assert.NoError(t, err)
  1216. assert.Equal(t, filepath.Join(os.TempDir(), "val1"), user.GetHomeDir())
  1217. assert.Equal(t, "desc", user.Description)
  1218. err = dataprovider.DeleteUser(username, "", "", "")
  1219. assert.NoError(t, err)
  1220. err = os.RemoveAll(user.GetHomeDir())
  1221. assert.NoError(t, err)
  1222. // login an admin with OIDC
  1223. _, err = dataprovider.AdminExists(username)
  1224. assert.ErrorIs(t, err, util.ErrNotFound)
  1225. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  1226. oidcMgr.addPendingAuth(authReq)
  1227. idToken = &oidc.IDToken{
  1228. Nonce: authReq.Nonce,
  1229. Expiry: time.Now().Add(5 * time.Minute),
  1230. }
  1231. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`"}`))
  1232. server.binding.OIDC.verifier = &mockOIDCVerifier{
  1233. err: nil,
  1234. token: idToken,
  1235. }
  1236. rr = httptest.NewRecorder()
  1237. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  1238. assert.NoError(t, err)
  1239. server.router.ServeHTTP(rr, r)
  1240. assert.Equal(t, http.StatusFound, rr.Code)
  1241. assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
  1242. _, err = dataprovider.AdminExists(username)
  1243. assert.NoError(t, err)
  1244. err = dataprovider.DeleteAdmin(username, "", "", "")
  1245. assert.NoError(t, err)
  1246. // set invalid templates and try again
  1247. action.Options.IDPConfig.TemplateUser = `{}`
  1248. action.Options.IDPConfig.TemplateAdmin = `{}`
  1249. err = dataprovider.UpdateEventAction(action, "", "", "")
  1250. assert.NoError(t, err)
  1251. for _, audience := range []string{tokenAudienceWebAdmin, tokenAudienceWebClient} {
  1252. authReq = newOIDCPendingAuth(audience)
  1253. oidcMgr.addPendingAuth(authReq)
  1254. idToken = &oidc.IDToken{
  1255. Nonce: authReq.Nonce,
  1256. Expiry: time.Now().Add(5 * time.Minute),
  1257. }
  1258. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`"}`))
  1259. server.binding.OIDC.verifier = &mockOIDCVerifier{
  1260. err: nil,
  1261. token: idToken,
  1262. }
  1263. rr = httptest.NewRecorder()
  1264. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  1265. assert.NoError(t, err)
  1266. server.router.ServeHTTP(rr, r)
  1267. assert.Equal(t, http.StatusFound, rr.Code)
  1268. }
  1269. for k := range oidcMgr.tokens {
  1270. oidcMgr.removeToken(k)
  1271. }
  1272. err = dataprovider.DeleteEventRule(rule.Name, "", "", "")
  1273. assert.NoError(t, err)
  1274. err = dataprovider.DeleteEventAction(action.Name, "", "", "")
  1275. assert.NoError(t, err)
  1276. err = dataprovider.Close()
  1277. assert.NoError(t, err)
  1278. err = dataprovider.Initialize(providerConf, configDir, true)
  1279. assert.NoError(t, err)
  1280. }
  1281. func TestOIDCPreLoginHook(t *testing.T) {
  1282. if runtime.GOOS == osWindows {
  1283. t.Skip("this test is not available on Windows")
  1284. }
  1285. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  1286. require.True(t, ok)
  1287. username := "test_oidc_user_prelogin"
  1288. u := dataprovider.User{
  1289. BaseUser: sdk.BaseUser{
  1290. Username: username,
  1291. HomeDir: filepath.Join(os.TempDir(), username),
  1292. Status: 1,
  1293. Permissions: map[string][]string{
  1294. "/": {dataprovider.PermAny},
  1295. },
  1296. },
  1297. }
  1298. preLoginPath := filepath.Join(os.TempDir(), "prelogin.sh")
  1299. providerConf := dataprovider.GetProviderConfig()
  1300. err := dataprovider.Close()
  1301. assert.NoError(t, err)
  1302. err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm)
  1303. assert.NoError(t, err)
  1304. newProviderConf := providerConf
  1305. newProviderConf.PreLoginHook = preLoginPath
  1306. err = dataprovider.Initialize(newProviderConf, configDir, true)
  1307. assert.NoError(t, err)
  1308. server := getTestOIDCServer()
  1309. server.binding.OIDC.CustomFields = []string{"field1", "field2"}
  1310. err = server.binding.OIDC.initialize()
  1311. assert.NoError(t, err)
  1312. server.initializeRouter()
  1313. _, err = dataprovider.UserExists(username, "")
  1314. assert.ErrorIs(t, err, util.ErrNotFound)
  1315. // now login with OIDC
  1316. authReq := newOIDCPendingAuth(tokenAudienceWebClient)
  1317. oidcMgr.addPendingAuth(authReq)
  1318. token := &oauth2.Token{
  1319. AccessToken: "1234",
  1320. Expiry: time.Now().Add(5 * time.Minute),
  1321. }
  1322. token = token.WithExtra(map[string]any{
  1323. "id_token": "id_token_val",
  1324. })
  1325. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  1326. tokenSource: &mockTokenSource{},
  1327. authCodeURL: webOIDCRedirectPath,
  1328. token: token,
  1329. }
  1330. idToken := &oidc.IDToken{
  1331. Nonce: authReq.Nonce,
  1332. Expiry: time.Now().Add(5 * time.Minute),
  1333. }
  1334. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`"}`))
  1335. server.binding.OIDC.verifier = &mockOIDCVerifier{
  1336. err: nil,
  1337. token: idToken,
  1338. }
  1339. rr := httptest.NewRecorder()
  1340. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  1341. assert.NoError(t, err)
  1342. server.router.ServeHTTP(rr, r)
  1343. assert.Equal(t, http.StatusFound, rr.Code)
  1344. assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
  1345. _, err = dataprovider.UserExists(username, "")
  1346. assert.NoError(t, err)
  1347. err = dataprovider.DeleteUser(username, "", "", "")
  1348. assert.NoError(t, err)
  1349. err = os.RemoveAll(u.HomeDir)
  1350. assert.NoError(t, err)
  1351. err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, true), os.ModePerm)
  1352. assert.NoError(t, err)
  1353. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  1354. oidcMgr.addPendingAuth(authReq)
  1355. idToken = &oidc.IDToken{
  1356. Nonce: authReq.Nonce,
  1357. Expiry: time.Now().Add(5 * time.Minute),
  1358. }
  1359. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`","field1":"value1","field2":"value2","field3":"value3"}`))
  1360. server.binding.OIDC.verifier = &mockOIDCVerifier{
  1361. err: nil,
  1362. token: idToken,
  1363. }
  1364. rr = httptest.NewRecorder()
  1365. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  1366. assert.NoError(t, err)
  1367. server.router.ServeHTTP(rr, r)
  1368. assert.Equal(t, http.StatusFound, rr.Code)
  1369. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  1370. _, err = dataprovider.UserExists(username, "")
  1371. assert.ErrorIs(t, err, util.ErrNotFound)
  1372. if assert.Len(t, oidcMgr.tokens, 1) {
  1373. for k := range oidcMgr.tokens {
  1374. oidcMgr.removeToken(k)
  1375. }
  1376. }
  1377. require.Len(t, oidcMgr.pendingAuths, 0)
  1378. require.Len(t, oidcMgr.tokens, 0)
  1379. err = dataprovider.Close()
  1380. assert.NoError(t, err)
  1381. err = dataprovider.Initialize(providerConf, configDir, true)
  1382. assert.NoError(t, err)
  1383. err = os.Remove(preLoginPath)
  1384. assert.NoError(t, err)
  1385. }
  1386. func TestOIDCIsAdmin(t *testing.T) {
  1387. type test struct {
  1388. input any
  1389. want bool
  1390. }
  1391. emptySlice := make([]any, 0)
  1392. tests := []test{
  1393. {input: "admin", want: true},
  1394. {input: append(emptySlice, "admin"), want: true},
  1395. {input: append(emptySlice, "user", "admin"), want: true},
  1396. {input: "user", want: false},
  1397. {input: emptySlice, want: false},
  1398. {input: append(emptySlice, 1), want: false},
  1399. {input: 1, want: false},
  1400. {input: nil, want: false},
  1401. {input: map[string]string{"admin": "admin"}, want: false},
  1402. }
  1403. for _, tc := range tests {
  1404. token := oidcToken{
  1405. Role: tc.input,
  1406. }
  1407. assert.Equal(t, tc.want, token.isAdmin(), "%v should return %t", tc.input, tc.want)
  1408. }
  1409. }
  1410. func TestParseAdminRole(t *testing.T) {
  1411. claims := make(map[string]any)
  1412. rawClaims := []byte(`{
  1413. "sub": "35666371",
  1414. "email": "[email protected]",
  1415. "preferred_username": "Sally",
  1416. "name": "Sally Tyler",
  1417. "updated_at": "2018-04-13T22:08:45Z",
  1418. "given_name": "Sally",
  1419. "family_name": "Tyler",
  1420. "params": {
  1421. "sftpgo_role": "admin",
  1422. "subparams": {
  1423. "sftpgo_role": "admin",
  1424. "inner": {
  1425. "sftpgo_role": ["user","admin"]
  1426. }
  1427. }
  1428. },
  1429. "at_hash": "lPLhxI2wjEndc-WfyroDZA",
  1430. "rt_hash": "mCmxPtA04N-55AxlEUbq-A",
  1431. "aud": "78d1d040-20c9-0136-5146-067351775fae92920",
  1432. "exp": 1523664997,
  1433. "iat": 1523657797
  1434. }`)
  1435. err := json.Unmarshal(rawClaims, &claims)
  1436. assert.NoError(t, err)
  1437. type test struct {
  1438. input string
  1439. want bool
  1440. val any
  1441. }
  1442. tests := []test{
  1443. {input: "", want: false},
  1444. {input: "sftpgo_role", want: false},
  1445. {input: "params.sftpgo_role", want: true, val: "admin"},
  1446. {input: "params.subparams.sftpgo_role", want: true, val: "admin"},
  1447. {input: "params.subparams.inner.sftpgo_role", want: true, val: []any{"user", "admin"}},
  1448. {input: "email", want: false},
  1449. {input: "missing", want: false},
  1450. {input: "params.email", want: false},
  1451. {input: "missing.sftpgo_role", want: false},
  1452. {input: "params", want: false},
  1453. {input: "params.subparams.inner.sftpgo_role.missing", want: false},
  1454. }
  1455. for _, tc := range tests {
  1456. token := oidcToken{}
  1457. token.getRoleFromField(claims, tc.input)
  1458. assert.Equal(t, tc.want, token.isAdmin(), "%q should return %t", tc.input, tc.want)
  1459. if tc.want {
  1460. assert.Equal(t, tc.val, token.Role)
  1461. }
  1462. }
  1463. }
  1464. func TestOIDCWithLoginFormsDisabled(t *testing.T) {
  1465. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  1466. require.True(t, ok)
  1467. server := getTestOIDCServer()
  1468. server.binding.OIDC.ImplicitRoles = true
  1469. server.binding.DisabledLoginMethods = 12
  1470. server.binding.EnableWebAdmin = true
  1471. server.binding.EnableWebClient = true
  1472. err := server.binding.OIDC.initialize()
  1473. assert.NoError(t, err)
  1474. server.initializeRouter()
  1475. // login with an admin user
  1476. authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
  1477. oidcMgr.addPendingAuth(authReq)
  1478. token := &oauth2.Token{
  1479. AccessToken: "1234",
  1480. Expiry: time.Now().Add(5 * time.Minute),
  1481. }
  1482. token = token.WithExtra(map[string]any{
  1483. "id_token": "id_token_val",
  1484. })
  1485. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  1486. tokenSource: &mockTokenSource{},
  1487. authCodeURL: webOIDCRedirectPath,
  1488. token: token,
  1489. }
  1490. idToken := &oidc.IDToken{
  1491. Nonce: authReq.Nonce,
  1492. Expiry: time.Now().Add(5 * time.Minute),
  1493. }
  1494. setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sid":"sid456"}`))
  1495. server.binding.OIDC.verifier = &mockOIDCVerifier{
  1496. err: nil,
  1497. token: idToken,
  1498. }
  1499. rr := httptest.NewRecorder()
  1500. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  1501. assert.NoError(t, err)
  1502. server.router.ServeHTTP(rr, r)
  1503. assert.Equal(t, http.StatusFound, rr.Code)
  1504. assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
  1505. var tokenCookie string
  1506. for k := range oidcMgr.tokens {
  1507. tokenCookie = k
  1508. }
  1509. // we should be able to create admins without setting a password
  1510. adminUsername := "testAdmin"
  1511. form := make(url.Values)
  1512. form.Set(csrfFormToken, createCSRFToken(rr, r, server.csrfTokenAuth, tokenCookie, webBaseAdminPath))
  1513. form.Set("username", adminUsername)
  1514. form.Set("password", "")
  1515. form.Set("status", "1")
  1516. form.Set("permissions", "*")
  1517. rr = httptest.NewRecorder()
  1518. r, err = http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode())))
  1519. assert.NoError(t, err)
  1520. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  1521. r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  1522. server.router.ServeHTTP(rr, r)
  1523. assert.Equal(t, http.StatusSeeOther, rr.Code)
  1524. _, err = dataprovider.AdminExists(adminUsername)
  1525. assert.NoError(t, err)
  1526. err = dataprovider.DeleteAdmin(adminUsername, "", "", "")
  1527. assert.NoError(t, err)
  1528. // login and password related routes are disabled
  1529. rr = httptest.NewRecorder()
  1530. r, err = http.NewRequest(http.MethodPost, webAdminLoginPath, nil)
  1531. assert.NoError(t, err)
  1532. server.router.ServeHTTP(rr, r)
  1533. assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
  1534. rr = httptest.NewRecorder()
  1535. r, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, nil)
  1536. assert.NoError(t, err)
  1537. server.router.ServeHTTP(rr, r)
  1538. assert.Equal(t, http.StatusNotFound, rr.Code)
  1539. rr = httptest.NewRecorder()
  1540. r, err = http.NewRequest(http.MethodPost, webClientLoginPath, nil)
  1541. assert.NoError(t, err)
  1542. server.router.ServeHTTP(rr, r)
  1543. assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
  1544. rr = httptest.NewRecorder()
  1545. r, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, nil)
  1546. assert.NoError(t, err)
  1547. server.router.ServeHTTP(rr, r)
  1548. assert.Equal(t, http.StatusNotFound, rr.Code)
  1549. }
  1550. func TestDbOIDCManager(t *testing.T) {
  1551. if !isSharedProviderSupported() {
  1552. t.Skip("this test it is not available with this provider")
  1553. }
  1554. mgr := newOIDCManager(1)
  1555. pendingAuth := newOIDCPendingAuth(tokenAudienceWebAdmin)
  1556. mgr.addPendingAuth(pendingAuth)
  1557. authReq, err := mgr.getPendingAuth(pendingAuth.State)
  1558. assert.NoError(t, err)
  1559. assert.Equal(t, pendingAuth, authReq)
  1560. pendingAuth.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
  1561. mgr.addPendingAuth(pendingAuth)
  1562. _, err = mgr.getPendingAuth(pendingAuth.State)
  1563. if assert.Error(t, err) {
  1564. assert.Contains(t, err.Error(), "auth request is too old")
  1565. }
  1566. mgr.removePendingAuth(pendingAuth.State)
  1567. _, err = mgr.getPendingAuth(pendingAuth.State)
  1568. if assert.Error(t, err) {
  1569. assert.Contains(t, err.Error(), "unable to get the auth request for the specified state")
  1570. }
  1571. mgr.addPendingAuth(pendingAuth)
  1572. _, err = mgr.getPendingAuth(pendingAuth.State)
  1573. if assert.Error(t, err) {
  1574. assert.Contains(t, err.Error(), "auth request is too old")
  1575. }
  1576. mgr.cleanup()
  1577. _, err = mgr.getPendingAuth(pendingAuth.State)
  1578. if assert.Error(t, err) {
  1579. assert.Contains(t, err.Error(), "unable to get the auth request for the specified state")
  1580. }
  1581. token := oidcToken{
  1582. Cookie: util.GenerateOpaqueString(),
  1583. AccessToken: xid.New().String(),
  1584. TokenType: "Bearer",
  1585. RefreshToken: xid.New().String(),
  1586. ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
  1587. SessionID: xid.New().String(),
  1588. IDToken: xid.New().String(),
  1589. Nonce: xid.New().String(),
  1590. Username: xid.New().String(),
  1591. Permissions: []string{dataprovider.PermAdminAny},
  1592. Role: "admin",
  1593. }
  1594. mgr.addToken(token)
  1595. tokenGet, err := mgr.getToken(token.Cookie)
  1596. assert.NoError(t, err)
  1597. assert.Greater(t, tokenGet.UsedAt, int64(0))
  1598. token.UsedAt = tokenGet.UsedAt
  1599. assert.Equal(t, token, tokenGet)
  1600. time.Sleep(100 * time.Millisecond)
  1601. mgr.updateTokenUsage(token)
  1602. // no change
  1603. tokenGet, err = mgr.getToken(token.Cookie)
  1604. assert.NoError(t, err)
  1605. assert.Equal(t, token.UsedAt, tokenGet.UsedAt)
  1606. tokenGet.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
  1607. tokenGet.RefreshToken = xid.New().String()
  1608. mgr.updateTokenUsage(tokenGet)
  1609. tokenGet, err = mgr.getToken(token.Cookie)
  1610. assert.NoError(t, err)
  1611. assert.NotEmpty(t, tokenGet.RefreshToken)
  1612. assert.NotEqual(t, token.RefreshToken, tokenGet.RefreshToken)
  1613. assert.Greater(t, tokenGet.UsedAt, token.UsedAt)
  1614. mgr.removeToken(token.Cookie)
  1615. tokenGet, err = mgr.getToken(token.Cookie)
  1616. if assert.Error(t, err) {
  1617. assert.Contains(t, err.Error(), "unable to get the token for the specified session")
  1618. }
  1619. // add an expired token
  1620. token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
  1621. session := dataprovider.Session{
  1622. Key: token.Cookie,
  1623. Data: token,
  1624. Type: dataprovider.SessionTypeOIDCToken,
  1625. Timestamp: token.UsedAt + tokenDeleteInterval,
  1626. }
  1627. err = dataprovider.AddSharedSession(session)
  1628. assert.NoError(t, err)
  1629. _, err = mgr.getToken(token.Cookie)
  1630. if assert.Error(t, err) {
  1631. assert.Contains(t, err.Error(), "token is too old")
  1632. }
  1633. mgr.cleanup()
  1634. _, err = mgr.getToken(token.Cookie)
  1635. if assert.Error(t, err) {
  1636. assert.Contains(t, err.Error(), "unable to get the token for the specified session")
  1637. }
  1638. // adding a session without a key should fail
  1639. session.Key = ""
  1640. err = dataprovider.AddSharedSession(session)
  1641. if assert.Error(t, err) {
  1642. assert.Contains(t, err.Error(), "unable to save a session with an empty key")
  1643. }
  1644. session.Key = xid.New().String()
  1645. session.Type = 1000
  1646. err = dataprovider.AddSharedSession(session)
  1647. if assert.Error(t, err) {
  1648. assert.Contains(t, err.Error(), "invalid session type")
  1649. }
  1650. dbMgr, ok := mgr.(*dbOIDCManager)
  1651. if assert.True(t, ok) {
  1652. _, err = dbMgr.decodePendingAuthData(2)
  1653. assert.Error(t, err)
  1654. _, err = dbMgr.decodeTokenData(true)
  1655. assert.Error(t, err)
  1656. }
  1657. }
  1658. func getTestOIDCServer() *httpdServer {
  1659. return &httpdServer{
  1660. binding: Binding{
  1661. OIDC: OIDC{
  1662. ClientID: "sftpgo-client",
  1663. ClientSecret: "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c",
  1664. ConfigURL: fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr),
  1665. RedirectBaseURL: "http://127.0.0.1:8081/",
  1666. UsernameField: "preferred_username",
  1667. RoleField: "sftpgo_role",
  1668. ImplicitRoles: false,
  1669. Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
  1670. CustomFields: nil,
  1671. Debug: true,
  1672. },
  1673. },
  1674. enableWebAdmin: true,
  1675. enableWebClient: true,
  1676. }
  1677. }
  1678. func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte {
  1679. content := []byte("#!/bin/sh\n\n")
  1680. if nonJSONResponse {
  1681. content = append(content, []byte("echo 'text response'\n")...)
  1682. return content
  1683. }
  1684. if len(user.Username) > 0 {
  1685. u, _ := json.Marshal(user)
  1686. content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...)
  1687. }
  1688. return content
  1689. }