oidc_test.go 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161
  1. package httpd
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "net/http/httptest"
  8. "os"
  9. "path/filepath"
  10. "reflect"
  11. "runtime"
  12. "testing"
  13. "time"
  14. "unsafe"
  15. "github.com/coreos/go-oidc/v3/oidc"
  16. "github.com/go-chi/jwtauth/v5"
  17. "github.com/rs/xid"
  18. "github.com/sftpgo/sdk"
  19. "github.com/stretchr/testify/assert"
  20. "github.com/stretchr/testify/require"
  21. "golang.org/x/oauth2"
  22. "github.com/drakkan/sftpgo/v2/common"
  23. "github.com/drakkan/sftpgo/v2/dataprovider"
  24. "github.com/drakkan/sftpgo/v2/kms"
  25. "github.com/drakkan/sftpgo/v2/util"
  26. "github.com/drakkan/sftpgo/v2/vfs"
  27. )
  28. const (
  29. oidcMockAddr = "127.0.0.1:11111"
  30. )
  31. type mockTokenSource struct {
  32. token *oauth2.Token
  33. err error
  34. }
  35. func (t *mockTokenSource) Token() (*oauth2.Token, error) {
  36. return t.token, t.err
  37. }
  38. type mockOAuth2Config struct {
  39. tokenSource *mockTokenSource
  40. authCodeURL string
  41. token *oauth2.Token
  42. err error
  43. }
  44. func (c *mockOAuth2Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
  45. return c.authCodeURL
  46. }
  47. func (c *mockOAuth2Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
  48. return c.token, c.err
  49. }
  50. func (c *mockOAuth2Config) TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource {
  51. return c.tokenSource
  52. }
  53. type mockOIDCVerifier struct {
  54. token *oidc.IDToken
  55. err error
  56. }
  57. func (v *mockOIDCVerifier) Verify(ctx context.Context, rawIDToken string) (*oidc.IDToken, error) {
  58. return v.token, v.err
  59. }
  60. // hack because the field is unexported
  61. func setIDTokenClaims(idToken *oidc.IDToken, claims []byte) {
  62. pointerVal := reflect.ValueOf(idToken)
  63. val := reflect.Indirect(pointerVal)
  64. member := val.FieldByName("claims")
  65. ptr := unsafe.Pointer(member.UnsafeAddr())
  66. realPtr := (*[]byte)(ptr)
  67. *realPtr = claims
  68. }
  69. func TestOIDCInitialization(t *testing.T) {
  70. config := OIDC{}
  71. err := config.initialize()
  72. assert.NoError(t, err)
  73. config = OIDC{
  74. ClientID: "sftpgo-client",
  75. ClientSecret: "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c",
  76. ConfigURL: fmt.Sprintf("http://%v/", oidcMockAddr),
  77. RedirectBaseURL: "http://127.0.0.1:8081/",
  78. UsernameField: "preferred_username",
  79. RoleField: "sftpgo_role",
  80. }
  81. err = config.initialize()
  82. if assert.Error(t, err) {
  83. assert.Contains(t, err.Error(), "oidc: unable to initialize provider")
  84. }
  85. config.ConfigURL = fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr)
  86. err = config.initialize()
  87. assert.NoError(t, err)
  88. assert.Equal(t, "http://127.0.0.1:8081"+webOIDCRedirectPath, config.getRedirectURL())
  89. }
  90. func TestOIDCLoginLogout(t *testing.T) {
  91. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  92. require.True(t, ok)
  93. server := getTestOIDCServer()
  94. err := server.binding.OIDC.initialize()
  95. assert.NoError(t, err)
  96. server.initializeRouter()
  97. rr := httptest.NewRecorder()
  98. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath, nil)
  99. assert.NoError(t, err)
  100. server.router.ServeHTTP(rr, r)
  101. assert.Equal(t, http.StatusBadRequest, rr.Code)
  102. assert.Contains(t, rr.Body.String(), "Authentication state did not match")
  103. expiredAuthReq := oidcPendingAuth{
  104. State: xid.New().String(),
  105. Nonce: xid.New().String(),
  106. Audience: tokenAudienceWebClient,
  107. IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
  108. }
  109. oidcMgr.addPendingAuth(expiredAuthReq)
  110. rr = httptest.NewRecorder()
  111. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+expiredAuthReq.State, nil)
  112. assert.NoError(t, err)
  113. server.router.ServeHTTP(rr, r)
  114. assert.Equal(t, http.StatusBadRequest, rr.Code)
  115. assert.Contains(t, rr.Body.String(), "Authentication state did not match")
  116. oidcMgr.removePendingAuth(expiredAuthReq.State)
  117. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  118. tokenSource: &mockTokenSource{},
  119. authCodeURL: webOIDCRedirectPath,
  120. err: common.ErrGenericFailure,
  121. }
  122. server.binding.OIDC.verifier = &mockOIDCVerifier{
  123. err: common.ErrGenericFailure,
  124. }
  125. rr = httptest.NewRecorder()
  126. r, err = http.NewRequest(http.MethodGet, webAdminOIDCLoginPath, nil)
  127. assert.NoError(t, err)
  128. server.router.ServeHTTP(rr, r)
  129. assert.Equal(t, http.StatusFound, rr.Code)
  130. assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
  131. require.Len(t, oidcMgr.pendingAuths, 1)
  132. var state string
  133. for k := range oidcMgr.pendingAuths {
  134. state = k
  135. }
  136. rr = httptest.NewRecorder()
  137. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
  138. assert.NoError(t, err)
  139. server.router.ServeHTTP(rr, r)
  140. assert.Equal(t, http.StatusFound, rr.Code)
  141. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  142. require.Len(t, oidcMgr.pendingAuths, 0)
  143. rr = httptest.NewRecorder()
  144. r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
  145. assert.NoError(t, err)
  146. server.router.ServeHTTP(rr, r)
  147. assert.Equal(t, http.StatusOK, rr.Code)
  148. // now the same for the web client
  149. rr = httptest.NewRecorder()
  150. r, err = http.NewRequest(http.MethodGet, webClientOIDCLoginPath, nil)
  151. assert.NoError(t, err)
  152. server.router.ServeHTTP(rr, r)
  153. assert.Equal(t, http.StatusFound, rr.Code)
  154. assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
  155. require.Len(t, oidcMgr.pendingAuths, 1)
  156. for k := range oidcMgr.pendingAuths {
  157. state = k
  158. }
  159. rr = httptest.NewRecorder()
  160. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
  161. assert.NoError(t, err)
  162. server.router.ServeHTTP(rr, r)
  163. assert.Equal(t, http.StatusFound, rr.Code)
  164. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  165. require.Len(t, oidcMgr.pendingAuths, 0)
  166. rr = httptest.NewRecorder()
  167. r, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil)
  168. assert.NoError(t, err)
  169. server.router.ServeHTTP(rr, r)
  170. assert.Equal(t, http.StatusOK, rr.Code)
  171. // now return an OAuth2 token without the id_token
  172. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  173. tokenSource: &mockTokenSource{},
  174. authCodeURL: webOIDCRedirectPath,
  175. token: &oauth2.Token{
  176. AccessToken: "123",
  177. Expiry: time.Now().Add(5 * time.Minute),
  178. },
  179. err: nil,
  180. }
  181. authReq := newOIDCPendingAuth(tokenAudienceWebClient)
  182. oidcMgr.addPendingAuth(authReq)
  183. rr = httptest.NewRecorder()
  184. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  185. assert.NoError(t, err)
  186. server.router.ServeHTTP(rr, r)
  187. assert.Equal(t, http.StatusFound, rr.Code)
  188. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  189. require.Len(t, oidcMgr.pendingAuths, 0)
  190. // now fail to verify the id token
  191. token := &oauth2.Token{
  192. AccessToken: "123",
  193. Expiry: time.Now().Add(5 * time.Minute),
  194. }
  195. token = token.WithExtra(map[string]any{
  196. "id_token": "id_token_val",
  197. })
  198. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  199. tokenSource: &mockTokenSource{},
  200. authCodeURL: webOIDCRedirectPath,
  201. token: token,
  202. err: nil,
  203. }
  204. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  205. oidcMgr.addPendingAuth(authReq)
  206. rr = httptest.NewRecorder()
  207. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  208. assert.NoError(t, err)
  209. server.router.ServeHTTP(rr, r)
  210. assert.Equal(t, http.StatusFound, rr.Code)
  211. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  212. require.Len(t, oidcMgr.pendingAuths, 0)
  213. // id token nonce does not match
  214. server.binding.OIDC.verifier = &mockOIDCVerifier{
  215. err: nil,
  216. token: &oidc.IDToken{},
  217. }
  218. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  219. oidcMgr.addPendingAuth(authReq)
  220. rr = httptest.NewRecorder()
  221. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  222. assert.NoError(t, err)
  223. server.router.ServeHTTP(rr, r)
  224. assert.Equal(t, http.StatusFound, rr.Code)
  225. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  226. require.Len(t, oidcMgr.pendingAuths, 0)
  227. // null id token claims
  228. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  229. oidcMgr.addPendingAuth(authReq)
  230. server.binding.OIDC.verifier = &mockOIDCVerifier{
  231. err: nil,
  232. token: &oidc.IDToken{
  233. Nonce: authReq.Nonce,
  234. },
  235. }
  236. rr = httptest.NewRecorder()
  237. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  238. assert.NoError(t, err)
  239. server.router.ServeHTTP(rr, r)
  240. assert.Equal(t, http.StatusFound, rr.Code)
  241. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  242. require.Len(t, oidcMgr.pendingAuths, 0)
  243. // invalid id token claims (no username)
  244. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  245. oidcMgr.addPendingAuth(authReq)
  246. idToken := &oidc.IDToken{
  247. Nonce: authReq.Nonce,
  248. Expiry: time.Now().Add(5 * time.Minute),
  249. }
  250. setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id"}`))
  251. server.binding.OIDC.verifier = &mockOIDCVerifier{
  252. err: nil,
  253. token: idToken,
  254. }
  255. rr = httptest.NewRecorder()
  256. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  257. assert.NoError(t, err)
  258. server.router.ServeHTTP(rr, r)
  259. assert.Equal(t, http.StatusFound, rr.Code)
  260. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  261. require.Len(t, oidcMgr.pendingAuths, 0)
  262. // invalid audience
  263. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  264. oidcMgr.addPendingAuth(authReq)
  265. idToken = &oidc.IDToken{
  266. Nonce: authReq.Nonce,
  267. Expiry: time.Now().Add(5 * time.Minute),
  268. }
  269. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
  270. server.binding.OIDC.verifier = &mockOIDCVerifier{
  271. err: nil,
  272. token: idToken,
  273. }
  274. rr = httptest.NewRecorder()
  275. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  276. assert.NoError(t, err)
  277. server.router.ServeHTTP(rr, r)
  278. assert.Equal(t, http.StatusFound, rr.Code)
  279. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  280. require.Len(t, oidcMgr.pendingAuths, 0)
  281. // invalid audience
  282. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  283. oidcMgr.addPendingAuth(authReq)
  284. idToken = &oidc.IDToken{
  285. Nonce: authReq.Nonce,
  286. Expiry: time.Now().Add(5 * time.Minute),
  287. }
  288. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test"}`))
  289. server.binding.OIDC.verifier = &mockOIDCVerifier{
  290. err: nil,
  291. token: idToken,
  292. }
  293. rr = httptest.NewRecorder()
  294. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  295. assert.NoError(t, err)
  296. server.router.ServeHTTP(rr, r)
  297. assert.Equal(t, http.StatusFound, rr.Code)
  298. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  299. require.Len(t, oidcMgr.pendingAuths, 0)
  300. // mapped user not found
  301. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  302. oidcMgr.addPendingAuth(authReq)
  303. idToken = &oidc.IDToken{
  304. Nonce: authReq.Nonce,
  305. Expiry: time.Now().Add(5 * time.Minute),
  306. }
  307. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
  308. server.binding.OIDC.verifier = &mockOIDCVerifier{
  309. err: nil,
  310. token: idToken,
  311. }
  312. rr = httptest.NewRecorder()
  313. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  314. assert.NoError(t, err)
  315. server.router.ServeHTTP(rr, r)
  316. assert.Equal(t, http.StatusFound, rr.Code)
  317. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  318. require.Len(t, oidcMgr.pendingAuths, 0)
  319. // admin login ok
  320. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  321. oidcMgr.addPendingAuth(authReq)
  322. idToken = &oidc.IDToken{
  323. Nonce: authReq.Nonce,
  324. Expiry: time.Now().Add(5 * time.Minute),
  325. }
  326. setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sftpgo_role":"admin","sid":"sid123"}`))
  327. server.binding.OIDC.verifier = &mockOIDCVerifier{
  328. err: nil,
  329. token: idToken,
  330. }
  331. rr = httptest.NewRecorder()
  332. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  333. assert.NoError(t, err)
  334. server.router.ServeHTTP(rr, r)
  335. assert.Equal(t, http.StatusFound, rr.Code)
  336. assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
  337. require.Len(t, oidcMgr.pendingAuths, 0)
  338. require.Len(t, oidcMgr.tokens, 1)
  339. // admin profile is not available
  340. var tokenCookie string
  341. for k := range oidcMgr.tokens {
  342. tokenCookie = k
  343. }
  344. oidcToken, err := oidcMgr.getToken(tokenCookie)
  345. assert.NoError(t, err)
  346. assert.Equal(t, "sid123", oidcToken.SessionID)
  347. assert.True(t, oidcToken.isAdmin())
  348. assert.False(t, oidcToken.isExpired())
  349. rr = httptest.NewRecorder()
  350. r, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil)
  351. assert.NoError(t, err)
  352. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  353. server.router.ServeHTTP(rr, r)
  354. assert.Equal(t, http.StatusForbidden, rr.Code)
  355. // the admin can access the allowed pages
  356. rr = httptest.NewRecorder()
  357. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  358. assert.NoError(t, err)
  359. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  360. server.router.ServeHTTP(rr, r)
  361. assert.Equal(t, http.StatusOK, rr.Code)
  362. // try with an invalid cookie
  363. rr = httptest.NewRecorder()
  364. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  365. assert.NoError(t, err)
  366. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
  367. server.router.ServeHTTP(rr, r)
  368. assert.Equal(t, http.StatusFound, rr.Code)
  369. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  370. // Web Client is not available with an admin token
  371. rr = httptest.NewRecorder()
  372. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  373. assert.NoError(t, err)
  374. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  375. server.router.ServeHTTP(rr, r)
  376. assert.Equal(t, http.StatusFound, rr.Code)
  377. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  378. // logout the admin user
  379. rr = httptest.NewRecorder()
  380. r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
  381. assert.NoError(t, err)
  382. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  383. server.router.ServeHTTP(rr, r)
  384. assert.Equal(t, http.StatusFound, rr.Code)
  385. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  386. require.Len(t, oidcMgr.pendingAuths, 0)
  387. require.Len(t, oidcMgr.tokens, 0)
  388. // now login and logout a user
  389. username := "test_oidc_user"
  390. user := dataprovider.User{
  391. BaseUser: sdk.BaseUser{
  392. Username: username,
  393. Password: "pwd",
  394. HomeDir: filepath.Join(os.TempDir(), username),
  395. Status: 1,
  396. Permissions: map[string][]string{
  397. "/": {dataprovider.PermAny},
  398. },
  399. },
  400. Filters: dataprovider.UserFilters{
  401. BaseUserFilters: sdk.BaseUserFilters{
  402. WebClient: []string{sdk.WebClientSharesDisabled},
  403. },
  404. },
  405. }
  406. err = dataprovider.AddUser(&user, "", "")
  407. assert.NoError(t, err)
  408. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  409. oidcMgr.addPendingAuth(authReq)
  410. idToken = &oidc.IDToken{
  411. Nonce: authReq.Nonce,
  412. Expiry: time.Now().Add(5 * time.Minute),
  413. }
  414. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test_oidc_user"}`))
  415. server.binding.OIDC.verifier = &mockOIDCVerifier{
  416. err: nil,
  417. token: idToken,
  418. }
  419. rr = httptest.NewRecorder()
  420. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  421. assert.NoError(t, err)
  422. server.router.ServeHTTP(rr, r)
  423. assert.Equal(t, http.StatusFound, rr.Code)
  424. assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
  425. require.Len(t, oidcMgr.pendingAuths, 0)
  426. require.Len(t, oidcMgr.tokens, 1)
  427. // user profile is not available
  428. for k := range oidcMgr.tokens {
  429. tokenCookie = k
  430. }
  431. oidcToken, err = oidcMgr.getToken(tokenCookie)
  432. assert.NoError(t, err)
  433. assert.Empty(t, oidcToken.SessionID)
  434. assert.False(t, oidcToken.isAdmin())
  435. assert.False(t, oidcToken.isExpired())
  436. if assert.Len(t, oidcToken.Permissions, 1) {
  437. assert.Equal(t, sdk.WebClientSharesDisabled, oidcToken.Permissions[0])
  438. }
  439. rr = httptest.NewRecorder()
  440. r, err = http.NewRequest(http.MethodGet, webClientProfilePath, nil)
  441. assert.NoError(t, err)
  442. r.RequestURI = webClientProfilePath
  443. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  444. server.router.ServeHTTP(rr, r)
  445. assert.Equal(t, http.StatusForbidden, rr.Code)
  446. // the user can access the allowed pages
  447. rr = httptest.NewRecorder()
  448. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  449. assert.NoError(t, err)
  450. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  451. server.router.ServeHTTP(rr, r)
  452. assert.Equal(t, http.StatusOK, rr.Code)
  453. // try with an invalid cookie
  454. rr = httptest.NewRecorder()
  455. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  456. assert.NoError(t, err)
  457. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
  458. server.router.ServeHTTP(rr, r)
  459. assert.Equal(t, http.StatusFound, rr.Code)
  460. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  461. // Web Admin is not available with a client cookie
  462. rr = httptest.NewRecorder()
  463. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  464. assert.NoError(t, err)
  465. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  466. server.router.ServeHTTP(rr, r)
  467. assert.Equal(t, http.StatusFound, rr.Code)
  468. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  469. // logout the user
  470. rr = httptest.NewRecorder()
  471. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  472. assert.NoError(t, err)
  473. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  474. server.router.ServeHTTP(rr, r)
  475. assert.Equal(t, http.StatusFound, rr.Code)
  476. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  477. require.Len(t, oidcMgr.pendingAuths, 0)
  478. require.Len(t, oidcMgr.tokens, 0)
  479. err = os.RemoveAll(user.GetHomeDir())
  480. assert.NoError(t, err)
  481. err = dataprovider.DeleteUser(username, "", "")
  482. assert.NoError(t, err)
  483. }
  484. func TestOIDCRefreshToken(t *testing.T) {
  485. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  486. require.True(t, ok)
  487. token := oidcToken{
  488. Cookie: xid.New().String(),
  489. AccessToken: xid.New().String(),
  490. TokenType: "Bearer",
  491. ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-1 * time.Minute)),
  492. Nonce: xid.New().String(),
  493. }
  494. config := mockOAuth2Config{
  495. tokenSource: &mockTokenSource{
  496. err: common.ErrGenericFailure,
  497. },
  498. }
  499. verifier := mockOIDCVerifier{
  500. err: common.ErrGenericFailure,
  501. }
  502. err := token.refresh(&config, &verifier)
  503. if assert.Error(t, err) {
  504. assert.Contains(t, err.Error(), "refresh token not set")
  505. }
  506. token.RefreshToken = xid.New().String()
  507. err = token.refresh(&config, &verifier)
  508. assert.ErrorIs(t, err, common.ErrGenericFailure)
  509. newToken := &oauth2.Token{
  510. AccessToken: xid.New().String(),
  511. RefreshToken: xid.New().String(),
  512. Expiry: time.Now().Add(5 * time.Minute),
  513. }
  514. config = mockOAuth2Config{
  515. tokenSource: &mockTokenSource{
  516. token: newToken,
  517. },
  518. }
  519. verifier = mockOIDCVerifier{
  520. token: &oidc.IDToken{},
  521. }
  522. err = token.refresh(&config, &verifier)
  523. if assert.Error(t, err) {
  524. assert.Contains(t, err.Error(), "the refreshed token has no id token")
  525. }
  526. newToken = newToken.WithExtra(map[string]any{
  527. "id_token": "id_token_val",
  528. })
  529. newToken.Expiry = time.Time{}
  530. config = mockOAuth2Config{
  531. tokenSource: &mockTokenSource{
  532. token: newToken,
  533. },
  534. }
  535. verifier = mockOIDCVerifier{
  536. err: common.ErrGenericFailure,
  537. }
  538. err = token.refresh(&config, &verifier)
  539. assert.ErrorIs(t, err, common.ErrGenericFailure)
  540. newToken = newToken.WithExtra(map[string]any{
  541. "id_token": "id_token_val",
  542. })
  543. newToken.Expiry = time.Now().Add(5 * time.Minute)
  544. config = mockOAuth2Config{
  545. tokenSource: &mockTokenSource{
  546. token: newToken,
  547. },
  548. }
  549. verifier = mockOIDCVerifier{
  550. token: &oidc.IDToken{},
  551. }
  552. err = token.refresh(&config, &verifier)
  553. if assert.Error(t, err) {
  554. assert.Contains(t, err.Error(), "the refreshed token nonce mismatch")
  555. }
  556. verifier = mockOIDCVerifier{
  557. token: &oidc.IDToken{
  558. Nonce: token.Nonce,
  559. },
  560. }
  561. err = token.refresh(&config, &verifier)
  562. if assert.Error(t, err) {
  563. assert.Contains(t, err.Error(), "oidc: claims not set")
  564. }
  565. idToken := &oidc.IDToken{
  566. Nonce: token.Nonce,
  567. }
  568. setIDTokenClaims(idToken, []byte(`{"sid":"id_token_sid"}`))
  569. verifier = mockOIDCVerifier{
  570. token: idToken,
  571. }
  572. err = token.refresh(&config, &verifier)
  573. assert.NoError(t, err)
  574. require.Len(t, oidcMgr.tokens, 1)
  575. oidcMgr.removeToken(token.Cookie)
  576. require.Len(t, oidcMgr.tokens, 0)
  577. }
  578. func TestValidateOIDCToken(t *testing.T) {
  579. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  580. require.True(t, ok)
  581. server := getTestOIDCServer()
  582. err := server.binding.OIDC.initialize()
  583. assert.NoError(t, err)
  584. server.initializeRouter()
  585. rr := httptest.NewRecorder()
  586. r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  587. assert.NoError(t, err)
  588. _, err = server.validateOIDCToken(rr, r, false)
  589. assert.ErrorIs(t, err, errInvalidToken)
  590. // expired token and refresh error
  591. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  592. tokenSource: &mockTokenSource{
  593. err: common.ErrGenericFailure,
  594. },
  595. }
  596. token := oidcToken{
  597. Cookie: xid.New().String(),
  598. AccessToken: xid.New().String(),
  599. ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
  600. }
  601. oidcMgr.addToken(token)
  602. rr = httptest.NewRecorder()
  603. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  604. assert.NoError(t, err)
  605. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
  606. _, err = server.validateOIDCToken(rr, r, false)
  607. assert.ErrorIs(t, err, errInvalidToken)
  608. oidcMgr.removeToken(token.Cookie)
  609. assert.Len(t, oidcMgr.tokens, 0)
  610. server.tokenAuth = jwtauth.New("PS256", util.GenerateRandomBytes(32), nil)
  611. token = oidcToken{
  612. Cookie: xid.New().String(),
  613. AccessToken: xid.New().String(),
  614. }
  615. oidcMgr.addToken(token)
  616. rr = httptest.NewRecorder()
  617. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  618. assert.NoError(t, err)
  619. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
  620. server.router.ServeHTTP(rr, r)
  621. assert.Equal(t, http.StatusFound, rr.Code)
  622. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  623. oidcMgr.removeToken(token.Cookie)
  624. assert.Len(t, oidcMgr.tokens, 0)
  625. token = oidcToken{
  626. Cookie: xid.New().String(),
  627. AccessToken: xid.New().String(),
  628. Role: "admin",
  629. }
  630. oidcMgr.addToken(token)
  631. rr = httptest.NewRecorder()
  632. r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
  633. assert.NoError(t, err)
  634. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
  635. server.router.ServeHTTP(rr, r)
  636. assert.Equal(t, http.StatusFound, rr.Code)
  637. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  638. oidcMgr.removeToken(token.Cookie)
  639. assert.Len(t, oidcMgr.tokens, 0)
  640. }
  641. func TestSkipOIDCAuth(t *testing.T) {
  642. server := getTestOIDCServer()
  643. err := server.binding.OIDC.initialize()
  644. assert.NoError(t, err)
  645. server.initializeRouter()
  646. jwtTokenClaims := jwtTokenClaims{
  647. Username: "user",
  648. }
  649. _, tokenString, err := jwtTokenClaims.createToken(server.tokenAuth, tokenAudienceWebClient, "")
  650. assert.NoError(t, err)
  651. rr := httptest.NewRecorder()
  652. r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  653. assert.NoError(t, err)
  654. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", jwtCookieKey, tokenString))
  655. server.router.ServeHTTP(rr, r)
  656. assert.Equal(t, http.StatusFound, rr.Code)
  657. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  658. }
  659. func TestOIDCLogoutErrors(t *testing.T) {
  660. server := getTestOIDCServer()
  661. assert.Empty(t, server.binding.OIDC.providerLogoutURL)
  662. server.logoutFromOIDCOP("")
  663. server.binding.OIDC.providerLogoutURL = "http://foo\x7f.com/"
  664. server.doOIDCFromLogout("")
  665. server.binding.OIDC.providerLogoutURL = "http://127.0.0.1:11234"
  666. server.doOIDCFromLogout("")
  667. }
  668. func TestOIDCToken(t *testing.T) {
  669. admin := dataprovider.Admin{
  670. Username: "test_oidc_admin",
  671. Password: "p",
  672. Permissions: []string{dataprovider.PermAdminAny},
  673. Status: 0,
  674. }
  675. err := dataprovider.AddAdmin(&admin, "", "")
  676. assert.NoError(t, err)
  677. token := oidcToken{
  678. Username: admin.Username,
  679. Role: "admin",
  680. }
  681. req, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
  682. assert.NoError(t, err)
  683. err = token.getUser(req)
  684. if assert.Error(t, err) {
  685. assert.Contains(t, err.Error(), "is disabled")
  686. }
  687. err = dataprovider.DeleteAdmin(admin.Username, "", "")
  688. assert.NoError(t, err)
  689. username := "test_oidc_user"
  690. token.Username = username
  691. token.Role = ""
  692. err = token.getUser(req)
  693. if assert.Error(t, err) {
  694. _, ok := err.(*util.RecordNotFoundError)
  695. assert.True(t, ok)
  696. }
  697. user := dataprovider.User{
  698. BaseUser: sdk.BaseUser{
  699. Username: username,
  700. Password: "p",
  701. HomeDir: filepath.Join(os.TempDir(), username),
  702. Status: 0,
  703. Permissions: map[string][]string{
  704. "/": {dataprovider.PermAny},
  705. },
  706. },
  707. Filters: dataprovider.UserFilters{
  708. BaseUserFilters: sdk.BaseUserFilters{
  709. DeniedProtocols: []string{common.ProtocolHTTP},
  710. },
  711. },
  712. }
  713. err = dataprovider.AddUser(&user, "", "")
  714. assert.NoError(t, err)
  715. err = token.getUser(req)
  716. if assert.Error(t, err) {
  717. assert.Contains(t, err.Error(), "is disabled")
  718. }
  719. user, err = dataprovider.UserExists(username)
  720. assert.NoError(t, err)
  721. user.Status = 1
  722. user.Password = "np"
  723. err = dataprovider.UpdateUser(&user, "", "")
  724. assert.NoError(t, err)
  725. err = token.getUser(req)
  726. if assert.Error(t, err) {
  727. assert.Contains(t, err.Error(), "protocol HTTP is not allowed")
  728. }
  729. user.Filters.DeniedProtocols = nil
  730. user.FsConfig.Provider = sdk.SFTPFilesystemProvider
  731. user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
  732. BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
  733. Endpoint: "127.0.0.1:8022",
  734. Username: username,
  735. },
  736. Password: kms.NewPlainSecret("np"),
  737. }
  738. err = dataprovider.UpdateUser(&user, "", "")
  739. assert.NoError(t, err)
  740. err = token.getUser(req)
  741. if assert.Error(t, err) {
  742. assert.Contains(t, err.Error(), "SFTP loop")
  743. }
  744. common.Config.PostConnectHook = fmt.Sprintf("http://%v/404", oidcMockAddr)
  745. err = token.getUser(req)
  746. if assert.Error(t, err) {
  747. assert.Contains(t, err.Error(), "access denied by post connect hook")
  748. }
  749. common.Config.PostConnectHook = ""
  750. err = os.RemoveAll(user.GetHomeDir())
  751. assert.NoError(t, err)
  752. err = dataprovider.DeleteUser(username, "", "")
  753. assert.NoError(t, err)
  754. }
  755. func TestMemoryOIDCManager(t *testing.T) {
  756. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  757. require.True(t, ok)
  758. require.Len(t, oidcMgr.pendingAuths, 0)
  759. authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
  760. oidcMgr.addPendingAuth(authReq)
  761. require.Len(t, oidcMgr.pendingAuths, 1)
  762. _, err := oidcMgr.getPendingAuth(authReq.State)
  763. assert.NoError(t, err)
  764. oidcMgr.removePendingAuth(authReq.State)
  765. require.Len(t, oidcMgr.pendingAuths, 0)
  766. authReq.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-61 * time.Second))
  767. oidcMgr.addPendingAuth(authReq)
  768. require.Len(t, oidcMgr.pendingAuths, 1)
  769. _, err = oidcMgr.getPendingAuth(authReq.State)
  770. if assert.Error(t, err) {
  771. assert.Contains(t, err.Error(), "too old")
  772. }
  773. oidcMgr.cleanup()
  774. require.Len(t, oidcMgr.pendingAuths, 0)
  775. token := oidcToken{
  776. AccessToken: xid.New().String(),
  777. Nonce: xid.New().String(),
  778. SessionID: xid.New().String(),
  779. Cookie: xid.New().String(),
  780. Username: xid.New().String(),
  781. Role: "admin",
  782. Permissions: []string{dataprovider.PermAdminAny},
  783. }
  784. require.Len(t, oidcMgr.tokens, 0)
  785. oidcMgr.addToken(token)
  786. require.Len(t, oidcMgr.tokens, 1)
  787. _, err = oidcMgr.getToken(xid.New().String())
  788. assert.Error(t, err)
  789. storedToken, err := oidcMgr.getToken(token.Cookie)
  790. assert.NoError(t, err)
  791. token.UsedAt = 0 // ensure we don't modify the stored token
  792. assert.Greater(t, storedToken.UsedAt, int64(0))
  793. token.UsedAt = storedToken.UsedAt
  794. assert.Equal(t, token, storedToken)
  795. // the usage will not be updated, it is recent
  796. oidcMgr.updateTokenUsage(storedToken)
  797. storedToken, err = oidcMgr.getToken(token.Cookie)
  798. assert.NoError(t, err)
  799. assert.Equal(t, token, storedToken)
  800. usedAt := util.GetTimeAsMsSinceEpoch(time.Now().Add(-5 * time.Minute))
  801. storedToken.UsedAt = usedAt
  802. oidcMgr.tokens[token.Cookie] = storedToken
  803. storedToken, err = oidcMgr.getToken(token.Cookie)
  804. assert.NoError(t, err)
  805. assert.Equal(t, usedAt, storedToken.UsedAt)
  806. token.UsedAt = storedToken.UsedAt
  807. assert.Equal(t, token, storedToken)
  808. oidcMgr.updateTokenUsage(storedToken)
  809. storedToken, err = oidcMgr.getToken(token.Cookie)
  810. assert.NoError(t, err)
  811. assert.Greater(t, storedToken.UsedAt, usedAt)
  812. token.UsedAt = storedToken.UsedAt
  813. assert.Equal(t, token, storedToken)
  814. storedToken.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) - tokenDeleteInterval - 1
  815. oidcMgr.tokens[token.Cookie] = storedToken
  816. storedToken, err = oidcMgr.getToken(token.Cookie)
  817. if assert.Error(t, err) {
  818. assert.Contains(t, err.Error(), "token is too old")
  819. }
  820. oidcMgr.removeToken(xid.New().String())
  821. require.Len(t, oidcMgr.tokens, 1)
  822. oidcMgr.removeToken(token.Cookie)
  823. require.Len(t, oidcMgr.tokens, 0)
  824. oidcMgr.addToken(token)
  825. usedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-6 * time.Hour))
  826. token.UsedAt = usedAt
  827. oidcMgr.tokens[token.Cookie] = token
  828. newToken := oidcToken{
  829. Cookie: xid.New().String(),
  830. }
  831. oidcMgr.addToken(newToken)
  832. oidcMgr.cleanup()
  833. require.Len(t, oidcMgr.tokens, 1)
  834. _, err = oidcMgr.getToken(token.Cookie)
  835. assert.Error(t, err)
  836. _, err = oidcMgr.getToken(newToken.Cookie)
  837. assert.NoError(t, err)
  838. oidcMgr.removeToken(newToken.Cookie)
  839. require.Len(t, oidcMgr.tokens, 0)
  840. }
  841. func TestOIDCPreLoginHook(t *testing.T) {
  842. if runtime.GOOS == osWindows {
  843. t.Skip("this test is not available on Windows")
  844. }
  845. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  846. require.True(t, ok)
  847. username := "test_oidc_user_prelogin"
  848. u := dataprovider.User{
  849. BaseUser: sdk.BaseUser{
  850. Username: username,
  851. HomeDir: filepath.Join(os.TempDir(), username),
  852. Status: 1,
  853. Permissions: map[string][]string{
  854. "/": {dataprovider.PermAny},
  855. },
  856. },
  857. }
  858. preLoginPath := filepath.Join(os.TempDir(), "prelogin.sh")
  859. providerConf := dataprovider.GetProviderConfig()
  860. err := dataprovider.Close()
  861. assert.NoError(t, err)
  862. err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm)
  863. assert.NoError(t, err)
  864. newProviderConf := providerConf
  865. newProviderConf.PreLoginHook = preLoginPath
  866. err = dataprovider.Initialize(newProviderConf, "..", true)
  867. assert.NoError(t, err)
  868. server := getTestOIDCServer()
  869. server.binding.OIDC.CustomFields = []string{"field1", "field2"}
  870. err = server.binding.OIDC.initialize()
  871. assert.NoError(t, err)
  872. server.initializeRouter()
  873. _, err = dataprovider.UserExists(username)
  874. _, ok = err.(*util.RecordNotFoundError)
  875. assert.True(t, ok)
  876. // now login with OIDC
  877. authReq := newOIDCPendingAuth(tokenAudienceWebClient)
  878. oidcMgr.addPendingAuth(authReq)
  879. token := &oauth2.Token{
  880. AccessToken: "1234",
  881. Expiry: time.Now().Add(5 * time.Minute),
  882. }
  883. token = token.WithExtra(map[string]any{
  884. "id_token": "id_token_val",
  885. })
  886. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  887. tokenSource: &mockTokenSource{},
  888. authCodeURL: webOIDCRedirectPath,
  889. token: token,
  890. }
  891. idToken := &oidc.IDToken{
  892. Nonce: authReq.Nonce,
  893. Expiry: time.Now().Add(5 * time.Minute),
  894. }
  895. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`"}`))
  896. server.binding.OIDC.verifier = &mockOIDCVerifier{
  897. err: nil,
  898. token: idToken,
  899. }
  900. rr := httptest.NewRecorder()
  901. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  902. assert.NoError(t, err)
  903. server.router.ServeHTTP(rr, r)
  904. assert.Equal(t, http.StatusFound, rr.Code)
  905. assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
  906. _, err = dataprovider.UserExists(username)
  907. assert.NoError(t, err)
  908. err = dataprovider.DeleteUser(username, "", "")
  909. assert.NoError(t, err)
  910. err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, true), os.ModePerm)
  911. assert.NoError(t, err)
  912. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  913. oidcMgr.addPendingAuth(authReq)
  914. idToken = &oidc.IDToken{
  915. Nonce: authReq.Nonce,
  916. Expiry: time.Now().Add(5 * time.Minute),
  917. }
  918. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`","field1":"value1","field2":"value2","field3":"value3"}`))
  919. server.binding.OIDC.verifier = &mockOIDCVerifier{
  920. err: nil,
  921. token: idToken,
  922. }
  923. rr = httptest.NewRecorder()
  924. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  925. assert.NoError(t, err)
  926. server.router.ServeHTTP(rr, r)
  927. assert.Equal(t, http.StatusFound, rr.Code)
  928. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  929. _, err = dataprovider.UserExists(username)
  930. _, ok = err.(*util.RecordNotFoundError)
  931. assert.True(t, ok)
  932. if assert.Len(t, oidcMgr.tokens, 1) {
  933. for k := range oidcMgr.tokens {
  934. oidcMgr.removeToken(k)
  935. }
  936. }
  937. require.Len(t, oidcMgr.pendingAuths, 0)
  938. require.Len(t, oidcMgr.tokens, 0)
  939. err = dataprovider.Close()
  940. assert.NoError(t, err)
  941. err = dataprovider.Initialize(providerConf, "..", true)
  942. assert.NoError(t, err)
  943. err = os.Remove(preLoginPath)
  944. assert.NoError(t, err)
  945. }
  946. func TestOIDCIsAdmin(t *testing.T) {
  947. type test struct {
  948. input any
  949. want bool
  950. }
  951. emptySlice := make([]any, 0)
  952. tests := []test{
  953. {input: "admin", want: true},
  954. {input: append(emptySlice, "admin"), want: true},
  955. {input: append(emptySlice, "user", "admin"), want: true},
  956. {input: "user", want: false},
  957. {input: emptySlice, want: false},
  958. {input: append(emptySlice, 1), want: false},
  959. {input: 1, want: false},
  960. {input: nil, want: false},
  961. }
  962. for _, tc := range tests {
  963. token := oidcToken{
  964. Role: tc.input,
  965. }
  966. assert.Equal(t, tc.want, token.isAdmin(), "%v should return %t", tc.input, tc.want)
  967. }
  968. }
  969. func TestDbOIDCManager(t *testing.T) {
  970. if !isSharedProviderSupported() {
  971. t.Skip("this test it is not available with this provider")
  972. }
  973. mgr := newOIDCManager(1)
  974. pendingAuth := newOIDCPendingAuth(tokenAudienceWebAdmin)
  975. mgr.addPendingAuth(pendingAuth)
  976. authReq, err := mgr.getPendingAuth(pendingAuth.State)
  977. assert.NoError(t, err)
  978. assert.Equal(t, pendingAuth, authReq)
  979. pendingAuth.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
  980. mgr.addPendingAuth(pendingAuth)
  981. _, err = mgr.getPendingAuth(pendingAuth.State)
  982. if assert.Error(t, err) {
  983. assert.Contains(t, err.Error(), "auth request is too old")
  984. }
  985. mgr.removePendingAuth(pendingAuth.State)
  986. _, err = mgr.getPendingAuth(pendingAuth.State)
  987. if assert.Error(t, err) {
  988. assert.Contains(t, err.Error(), "unable to get the auth request for the specified state")
  989. }
  990. mgr.addPendingAuth(pendingAuth)
  991. _, err = mgr.getPendingAuth(pendingAuth.State)
  992. if assert.Error(t, err) {
  993. assert.Contains(t, err.Error(), "auth request is too old")
  994. }
  995. mgr.cleanup()
  996. _, err = mgr.getPendingAuth(pendingAuth.State)
  997. if assert.Error(t, err) {
  998. assert.Contains(t, err.Error(), "unable to get the auth request for the specified state")
  999. }
  1000. token := oidcToken{
  1001. Cookie: xid.New().String(),
  1002. AccessToken: xid.New().String(),
  1003. TokenType: "Bearer",
  1004. RefreshToken: xid.New().String(),
  1005. ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
  1006. SessionID: xid.New().String(),
  1007. IDToken: xid.New().String(),
  1008. Nonce: xid.New().String(),
  1009. Username: xid.New().String(),
  1010. Permissions: []string{dataprovider.PermAdminAny},
  1011. Role: "admin",
  1012. }
  1013. mgr.addToken(token)
  1014. tokenGet, err := mgr.getToken(token.Cookie)
  1015. assert.NoError(t, err)
  1016. assert.Greater(t, tokenGet.UsedAt, int64(0))
  1017. token.UsedAt = tokenGet.UsedAt
  1018. assert.Equal(t, token, tokenGet)
  1019. time.Sleep(100 * time.Millisecond)
  1020. mgr.updateTokenUsage(token)
  1021. // no change
  1022. tokenGet, err = mgr.getToken(token.Cookie)
  1023. assert.NoError(t, err)
  1024. assert.Equal(t, token.UsedAt, tokenGet.UsedAt)
  1025. tokenGet.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
  1026. tokenGet.RefreshToken = xid.New().String()
  1027. mgr.updateTokenUsage(tokenGet)
  1028. tokenGet, err = mgr.getToken(token.Cookie)
  1029. assert.NoError(t, err)
  1030. assert.NotEmpty(t, tokenGet.RefreshToken)
  1031. assert.NotEqual(t, token.RefreshToken, tokenGet.RefreshToken)
  1032. assert.Greater(t, tokenGet.UsedAt, token.UsedAt)
  1033. mgr.removeToken(token.Cookie)
  1034. tokenGet, err = mgr.getToken(token.Cookie)
  1035. if assert.Error(t, err) {
  1036. assert.Contains(t, err.Error(), "unable to get the token for the specified session")
  1037. }
  1038. // add an expired token
  1039. token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
  1040. session := dataprovider.Session{
  1041. Key: token.Cookie,
  1042. Data: token,
  1043. Type: dataprovider.SessionTypeOIDCToken,
  1044. Timestamp: token.UsedAt + tokenDeleteInterval,
  1045. }
  1046. err = dataprovider.AddSharedSession(session)
  1047. assert.NoError(t, err)
  1048. _, err = mgr.getToken(token.Cookie)
  1049. if assert.Error(t, err) {
  1050. assert.Contains(t, err.Error(), "token is too old")
  1051. }
  1052. mgr.cleanup()
  1053. _, err = mgr.getToken(token.Cookie)
  1054. if assert.Error(t, err) {
  1055. assert.Contains(t, err.Error(), "unable to get the token for the specified session")
  1056. }
  1057. // adding a session without a key should fail
  1058. session.Key = ""
  1059. err = dataprovider.AddSharedSession(session)
  1060. if assert.Error(t, err) {
  1061. assert.Contains(t, err.Error(), "unable to save a session with an empty key")
  1062. }
  1063. session.Key = xid.New().String()
  1064. session.Type = 1000
  1065. err = dataprovider.AddSharedSession(session)
  1066. if assert.Error(t, err) {
  1067. assert.Contains(t, err.Error(), "invalid session type")
  1068. }
  1069. dbMgr, ok := mgr.(*dbOIDCManager)
  1070. if assert.True(t, ok) {
  1071. _, err = dbMgr.decodePendingAuthData(2)
  1072. assert.Error(t, err)
  1073. _, err = dbMgr.decodeTokenData(true)
  1074. assert.Error(t, err)
  1075. }
  1076. }
  1077. func getTestOIDCServer() *httpdServer {
  1078. return &httpdServer{
  1079. binding: Binding{
  1080. OIDC: OIDC{
  1081. ClientID: "sftpgo-client",
  1082. ClientSecret: "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c",
  1083. ConfigURL: fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr),
  1084. RedirectBaseURL: "http://127.0.0.1:8081/",
  1085. UsernameField: "preferred_username",
  1086. RoleField: "sftpgo_role",
  1087. CustomFields: nil,
  1088. },
  1089. },
  1090. enableWebAdmin: true,
  1091. enableWebClient: true,
  1092. }
  1093. }
  1094. func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte {
  1095. content := []byte("#!/bin/sh\n\n")
  1096. if nonJSONResponse {
  1097. content = append(content, []byte("echo 'text response'\n")...)
  1098. return content
  1099. }
  1100. if len(user.Username) > 0 {
  1101. u, _ := json.Marshal(user)
  1102. content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...)
  1103. }
  1104. return content
  1105. }