// oidc_test.go
  1. package httpd
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "net/http/httptest"
  7. "os"
  8. "path/filepath"
  9. "reflect"
  10. "testing"
  11. "time"
  12. "unsafe"
  13. "github.com/coreos/go-oidc/v3/oidc"
  14. "github.com/go-chi/jwtauth/v5"
  15. "github.com/rs/xid"
  16. "github.com/sftpgo/sdk"
  17. "github.com/stretchr/testify/assert"
  18. "github.com/stretchr/testify/require"
  19. "golang.org/x/oauth2"
  20. "github.com/drakkan/sftpgo/v2/common"
  21. "github.com/drakkan/sftpgo/v2/dataprovider"
  22. "github.com/drakkan/sftpgo/v2/kms"
  23. "github.com/drakkan/sftpgo/v2/util"
  24. "github.com/drakkan/sftpgo/v2/vfs"
  25. )
  26. const (
  27. oidcMockAddr = "127.0.0.1:11111"
  28. )
  29. type mockTokenSource struct {
  30. token *oauth2.Token
  31. err error
  32. }
  33. func (t *mockTokenSource) Token() (*oauth2.Token, error) {
  34. return t.token, t.err
  35. }
  36. type mockOAuth2Config struct {
  37. tokenSource *mockTokenSource
  38. authCodeURL string
  39. token *oauth2.Token
  40. err error
  41. }
  42. func (c *mockOAuth2Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
  43. return c.authCodeURL
  44. }
  45. func (c *mockOAuth2Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
  46. return c.token, c.err
  47. }
  48. func (c *mockOAuth2Config) TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource {
  49. return c.tokenSource
  50. }
  51. type mockOIDCVerifier struct {
  52. token *oidc.IDToken
  53. err error
  54. }
  55. func (v *mockOIDCVerifier) Verify(ctx context.Context, rawIDToken string) (*oidc.IDToken, error) {
  56. return v.token, v.err
  57. }
  58. // hack because the field is unexported
  59. func setIDTokenClaims(idToken *oidc.IDToken, claims []byte) {
  60. pointerVal := reflect.ValueOf(idToken)
  61. val := reflect.Indirect(pointerVal)
  62. member := val.FieldByName("claims")
  63. ptr := unsafe.Pointer(member.UnsafeAddr())
  64. realPtr := (*[]byte)(ptr)
  65. *realPtr = claims
  66. }
  67. func TestOIDCInitialization(t *testing.T) {
  68. config := OIDC{
  69. ClientID: "sftpgo-client",
  70. ClientSecret: "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c",
  71. ConfigURL: fmt.Sprintf("http://%v/", oidcMockAddr),
  72. RedirectBaseURL: "http://127.0.0.1:8081/",
  73. UsernameField: "preferred_username",
  74. RoleField: "sftpgo_role",
  75. }
  76. err := config.initialize()
  77. if assert.Error(t, err) {
  78. assert.Contains(t, err.Error(), "oidc: unable to initialize provider")
  79. }
  80. config.ConfigURL = fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr)
  81. err = config.initialize()
  82. assert.NoError(t, err)
  83. assert.Equal(t, "http://127.0.0.1:8081"+webOIDCRedirectPath, config.getRedirectURL())
  84. }
  85. func TestOIDCLoginLogout(t *testing.T) {
  86. server := getTestOIDCServer()
  87. err := server.binding.OIDC.initialize()
  88. assert.NoError(t, err)
  89. server.initializeRouter()
  90. rr := httptest.NewRecorder()
  91. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath, nil)
  92. assert.NoError(t, err)
  93. server.router.ServeHTTP(rr, r)
  94. assert.Equal(t, http.StatusBadRequest, rr.Code)
  95. assert.Contains(t, rr.Body.String(), "Authentication state did not match")
  96. expiredAuthReq := oidcPendingAuth{
  97. State: xid.New().String(),
  98. Nonce: xid.New().String(),
  99. Audience: tokenAudienceWebClient,
  100. IssueAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
  101. }
  102. oidcMgr.addPendingAuth(expiredAuthReq)
  103. rr = httptest.NewRecorder()
  104. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+expiredAuthReq.State, nil)
  105. assert.NoError(t, err)
  106. server.router.ServeHTTP(rr, r)
  107. assert.Equal(t, http.StatusBadRequest, rr.Code)
  108. assert.Contains(t, rr.Body.String(), "Authentication state did not match")
  109. oidcMgr.removePendingAuth(expiredAuthReq.State)
  110. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  111. tokenSource: &mockTokenSource{},
  112. authCodeURL: webOIDCRedirectPath,
  113. err: common.ErrGenericFailure,
  114. }
  115. server.binding.OIDC.verifier = &mockOIDCVerifier{
  116. err: common.ErrGenericFailure,
  117. }
  118. rr = httptest.NewRecorder()
  119. r, err = http.NewRequest(http.MethodGet, webAdminOIDCLoginPath, nil)
  120. assert.NoError(t, err)
  121. server.router.ServeHTTP(rr, r)
  122. assert.Equal(t, http.StatusFound, rr.Code)
  123. assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
  124. require.Len(t, oidcMgr.pendingAuths, 1)
  125. var state string
  126. for k := range oidcMgr.pendingAuths {
  127. state = k
  128. }
  129. rr = httptest.NewRecorder()
  130. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
  131. assert.NoError(t, err)
  132. server.router.ServeHTTP(rr, r)
  133. assert.Equal(t, http.StatusFound, rr.Code)
  134. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  135. require.Len(t, oidcMgr.pendingAuths, 0)
  136. rr = httptest.NewRecorder()
  137. r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
  138. assert.NoError(t, err)
  139. server.router.ServeHTTP(rr, r)
  140. assert.Equal(t, http.StatusOK, rr.Code)
  141. // now the same for the web client
  142. rr = httptest.NewRecorder()
  143. r, err = http.NewRequest(http.MethodGet, webClientOIDCLoginPath, nil)
  144. assert.NoError(t, err)
  145. server.router.ServeHTTP(rr, r)
  146. assert.Equal(t, http.StatusFound, rr.Code)
  147. assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
  148. require.Len(t, oidcMgr.pendingAuths, 1)
  149. for k := range oidcMgr.pendingAuths {
  150. state = k
  151. }
  152. rr = httptest.NewRecorder()
  153. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
  154. assert.NoError(t, err)
  155. server.router.ServeHTTP(rr, r)
  156. assert.Equal(t, http.StatusFound, rr.Code)
  157. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  158. require.Len(t, oidcMgr.pendingAuths, 0)
  159. rr = httptest.NewRecorder()
  160. r, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil)
  161. assert.NoError(t, err)
  162. server.router.ServeHTTP(rr, r)
  163. assert.Equal(t, http.StatusOK, rr.Code)
  164. // now return an OAuth2 token without the id_token
  165. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  166. tokenSource: &mockTokenSource{},
  167. authCodeURL: webOIDCRedirectPath,
  168. token: &oauth2.Token{
  169. AccessToken: "123",
  170. Expiry: time.Now().Add(5 * time.Minute),
  171. },
  172. err: nil,
  173. }
  174. authReq := newOIDCPendingAuth(tokenAudienceWebClient)
  175. oidcMgr.addPendingAuth(authReq)
  176. rr = httptest.NewRecorder()
  177. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  178. assert.NoError(t, err)
  179. server.router.ServeHTTP(rr, r)
  180. assert.Equal(t, http.StatusFound, rr.Code)
  181. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  182. require.Len(t, oidcMgr.pendingAuths, 0)
  183. // now fail to verify the id token
  184. token := &oauth2.Token{
  185. AccessToken: "123",
  186. Expiry: time.Now().Add(5 * time.Minute),
  187. }
  188. token = token.WithExtra(map[string]interface{}{
  189. "id_token": "id_token_val",
  190. })
  191. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  192. tokenSource: &mockTokenSource{},
  193. authCodeURL: webOIDCRedirectPath,
  194. token: token,
  195. err: nil,
  196. }
  197. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  198. oidcMgr.addPendingAuth(authReq)
  199. rr = httptest.NewRecorder()
  200. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  201. assert.NoError(t, err)
  202. server.router.ServeHTTP(rr, r)
  203. assert.Equal(t, http.StatusFound, rr.Code)
  204. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  205. require.Len(t, oidcMgr.pendingAuths, 0)
  206. // id token nonce does not match
  207. server.binding.OIDC.verifier = &mockOIDCVerifier{
  208. err: nil,
  209. token: &oidc.IDToken{},
  210. }
  211. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  212. oidcMgr.addPendingAuth(authReq)
  213. rr = httptest.NewRecorder()
  214. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  215. assert.NoError(t, err)
  216. server.router.ServeHTTP(rr, r)
  217. assert.Equal(t, http.StatusFound, rr.Code)
  218. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  219. require.Len(t, oidcMgr.pendingAuths, 0)
  220. // null id token claims
  221. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  222. oidcMgr.addPendingAuth(authReq)
  223. server.binding.OIDC.verifier = &mockOIDCVerifier{
  224. err: nil,
  225. token: &oidc.IDToken{
  226. Nonce: authReq.Nonce,
  227. },
  228. }
  229. rr = httptest.NewRecorder()
  230. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  231. assert.NoError(t, err)
  232. server.router.ServeHTTP(rr, r)
  233. assert.Equal(t, http.StatusFound, rr.Code)
  234. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  235. require.Len(t, oidcMgr.pendingAuths, 0)
  236. // invalid id token claims (no username)
  237. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  238. oidcMgr.addPendingAuth(authReq)
  239. idToken := &oidc.IDToken{
  240. Nonce: authReq.Nonce,
  241. Expiry: time.Now().Add(5 * time.Minute),
  242. }
  243. setIDTokenClaims(idToken, []byte(`{}`))
  244. server.binding.OIDC.verifier = &mockOIDCVerifier{
  245. err: nil,
  246. token: idToken,
  247. }
  248. rr = httptest.NewRecorder()
  249. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  250. assert.NoError(t, err)
  251. server.router.ServeHTTP(rr, r)
  252. assert.Equal(t, http.StatusFound, rr.Code)
  253. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  254. require.Len(t, oidcMgr.pendingAuths, 0)
  255. // invalid audience
  256. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  257. oidcMgr.addPendingAuth(authReq)
  258. idToken = &oidc.IDToken{
  259. Nonce: authReq.Nonce,
  260. Expiry: time.Now().Add(5 * time.Minute),
  261. }
  262. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
  263. server.binding.OIDC.verifier = &mockOIDCVerifier{
  264. err: nil,
  265. token: idToken,
  266. }
  267. rr = httptest.NewRecorder()
  268. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  269. assert.NoError(t, err)
  270. server.router.ServeHTTP(rr, r)
  271. assert.Equal(t, http.StatusFound, rr.Code)
  272. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  273. require.Len(t, oidcMgr.pendingAuths, 0)
  274. // invalid audience
  275. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  276. oidcMgr.addPendingAuth(authReq)
  277. idToken = &oidc.IDToken{
  278. Nonce: authReq.Nonce,
  279. Expiry: time.Now().Add(5 * time.Minute),
  280. }
  281. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test"}`))
  282. server.binding.OIDC.verifier = &mockOIDCVerifier{
  283. err: nil,
  284. token: idToken,
  285. }
  286. rr = httptest.NewRecorder()
  287. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  288. assert.NoError(t, err)
  289. server.router.ServeHTTP(rr, r)
  290. assert.Equal(t, http.StatusFound, rr.Code)
  291. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  292. require.Len(t, oidcMgr.pendingAuths, 0)
  293. // mapped user not found
  294. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  295. oidcMgr.addPendingAuth(authReq)
  296. idToken = &oidc.IDToken{
  297. Nonce: authReq.Nonce,
  298. Expiry: time.Now().Add(5 * time.Minute),
  299. }
  300. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
  301. server.binding.OIDC.verifier = &mockOIDCVerifier{
  302. err: nil,
  303. token: idToken,
  304. }
  305. rr = httptest.NewRecorder()
  306. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  307. assert.NoError(t, err)
  308. server.router.ServeHTTP(rr, r)
  309. assert.Equal(t, http.StatusFound, rr.Code)
  310. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  311. require.Len(t, oidcMgr.pendingAuths, 0)
  312. // admin login ok
  313. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  314. oidcMgr.addPendingAuth(authReq)
  315. idToken = &oidc.IDToken{
  316. Nonce: authReq.Nonce,
  317. Expiry: time.Now().Add(5 * time.Minute),
  318. }
  319. setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sftpgo_role":"admin","sid":"sid123"}`))
  320. server.binding.OIDC.verifier = &mockOIDCVerifier{
  321. err: nil,
  322. token: idToken,
  323. }
  324. rr = httptest.NewRecorder()
  325. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  326. assert.NoError(t, err)
  327. server.router.ServeHTTP(rr, r)
  328. assert.Equal(t, http.StatusFound, rr.Code)
  329. assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
  330. require.Len(t, oidcMgr.pendingAuths, 0)
  331. require.Len(t, oidcMgr.tokens, 1)
  332. // admin profile is not available
  333. var tokenCookie string
  334. for k := range oidcMgr.tokens {
  335. tokenCookie = k
  336. }
  337. oidcToken, err := oidcMgr.getToken(tokenCookie)
  338. assert.NoError(t, err)
  339. assert.Equal(t, "sid123", oidcToken.SessionID)
  340. assert.True(t, oidcToken.isAdmin())
  341. assert.False(t, oidcToken.isExpired())
  342. rr = httptest.NewRecorder()
  343. r, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil)
  344. assert.NoError(t, err)
  345. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  346. server.router.ServeHTTP(rr, r)
  347. assert.Equal(t, http.StatusForbidden, rr.Code)
  348. // the admin can access the allowed pages
  349. rr = httptest.NewRecorder()
  350. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  351. assert.NoError(t, err)
  352. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  353. server.router.ServeHTTP(rr, r)
  354. assert.Equal(t, http.StatusOK, rr.Code)
  355. // try with an invalid cookie
  356. rr = httptest.NewRecorder()
  357. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  358. assert.NoError(t, err)
  359. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
  360. server.router.ServeHTTP(rr, r)
  361. assert.Equal(t, http.StatusFound, rr.Code)
  362. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  363. // Web Client is not available with an admin token
  364. rr = httptest.NewRecorder()
  365. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  366. assert.NoError(t, err)
  367. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  368. server.router.ServeHTTP(rr, r)
  369. assert.Equal(t, http.StatusFound, rr.Code)
  370. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  371. // logout the admin user
  372. rr = httptest.NewRecorder()
  373. r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
  374. assert.NoError(t, err)
  375. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  376. server.router.ServeHTTP(rr, r)
  377. assert.Equal(t, http.StatusFound, rr.Code)
  378. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  379. require.Len(t, oidcMgr.pendingAuths, 0)
  380. require.Len(t, oidcMgr.tokens, 0)
  381. // now login and logout a user
  382. username := "test_oidc_user"
  383. user := dataprovider.User{
  384. BaseUser: sdk.BaseUser{
  385. Username: username,
  386. Password: "pwd",
  387. HomeDir: filepath.Join(os.TempDir(), username),
  388. Status: 1,
  389. Permissions: map[string][]string{
  390. "/": {dataprovider.PermAny},
  391. },
  392. },
  393. Filters: dataprovider.UserFilters{
  394. BaseUserFilters: sdk.BaseUserFilters{
  395. WebClient: []string{sdk.WebClientSharesDisabled},
  396. },
  397. },
  398. }
  399. err = dataprovider.AddUser(&user, "", "")
  400. assert.NoError(t, err)
  401. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  402. oidcMgr.addPendingAuth(authReq)
  403. idToken = &oidc.IDToken{
  404. Nonce: authReq.Nonce,
  405. Expiry: time.Now().Add(5 * time.Minute),
  406. }
  407. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test_oidc_user"}`))
  408. server.binding.OIDC.verifier = &mockOIDCVerifier{
  409. err: nil,
  410. token: idToken,
  411. }
  412. rr = httptest.NewRecorder()
  413. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  414. assert.NoError(t, err)
  415. server.router.ServeHTTP(rr, r)
  416. assert.Equal(t, http.StatusFound, rr.Code)
  417. assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
  418. require.Len(t, oidcMgr.pendingAuths, 0)
  419. require.Len(t, oidcMgr.tokens, 1)
  420. // user profile is not available
  421. for k := range oidcMgr.tokens {
  422. tokenCookie = k
  423. }
  424. oidcToken, err = oidcMgr.getToken(tokenCookie)
  425. assert.NoError(t, err)
  426. assert.Empty(t, oidcToken.SessionID)
  427. assert.False(t, oidcToken.isAdmin())
  428. assert.False(t, oidcToken.isExpired())
  429. if assert.Len(t, oidcToken.Permissions, 1) {
  430. assert.Equal(t, sdk.WebClientSharesDisabled, oidcToken.Permissions[0])
  431. }
  432. rr = httptest.NewRecorder()
  433. r, err = http.NewRequest(http.MethodGet, webClientProfilePath, nil)
  434. assert.NoError(t, err)
  435. r.RequestURI = webClientProfilePath
  436. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  437. server.router.ServeHTTP(rr, r)
  438. assert.Equal(t, http.StatusForbidden, rr.Code)
  439. // the user can access the allowed pages
  440. rr = httptest.NewRecorder()
  441. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  442. assert.NoError(t, err)
  443. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  444. server.router.ServeHTTP(rr, r)
  445. assert.Equal(t, http.StatusOK, rr.Code)
  446. // try with an invalid cookie
  447. rr = httptest.NewRecorder()
  448. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  449. assert.NoError(t, err)
  450. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
  451. server.router.ServeHTTP(rr, r)
  452. assert.Equal(t, http.StatusFound, rr.Code)
  453. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  454. // Web Admin is not available with a client cookie
  455. rr = httptest.NewRecorder()
  456. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  457. assert.NoError(t, err)
  458. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  459. server.router.ServeHTTP(rr, r)
  460. assert.Equal(t, http.StatusFound, rr.Code)
  461. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  462. // logout the user
  463. rr = httptest.NewRecorder()
  464. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  465. assert.NoError(t, err)
  466. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  467. server.router.ServeHTTP(rr, r)
  468. assert.Equal(t, http.StatusFound, rr.Code)
  469. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  470. require.Len(t, oidcMgr.pendingAuths, 0)
  471. require.Len(t, oidcMgr.tokens, 0)
  472. err = os.RemoveAll(user.GetHomeDir())
  473. assert.NoError(t, err)
  474. err = dataprovider.DeleteUser(username, "", "")
  475. assert.NoError(t, err)
  476. }
  477. func TestOIDCRefreshToken(t *testing.T) {
  478. token := oidcToken{
  479. Cookie: xid.New().String(),
  480. AccessToken: xid.New().String(),
  481. TokenType: "Bearer",
  482. ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-1 * time.Minute)),
  483. Nonce: xid.New().String(),
  484. }
  485. config := mockOAuth2Config{
  486. tokenSource: &mockTokenSource{
  487. err: common.ErrGenericFailure,
  488. },
  489. }
  490. verifier := mockOIDCVerifier{
  491. err: common.ErrGenericFailure,
  492. }
  493. err := token.refresh(&config, &verifier)
  494. if assert.Error(t, err) {
  495. assert.Contains(t, err.Error(), "refresh token not set")
  496. }
  497. token.RefreshToken = xid.New().String()
  498. err = token.refresh(&config, &verifier)
  499. assert.ErrorIs(t, err, common.ErrGenericFailure)
  500. newToken := &oauth2.Token{
  501. AccessToken: xid.New().String(),
  502. RefreshToken: xid.New().String(),
  503. Expiry: time.Now().Add(5 * time.Minute),
  504. }
  505. config = mockOAuth2Config{
  506. tokenSource: &mockTokenSource{
  507. token: newToken,
  508. },
  509. }
  510. verifier = mockOIDCVerifier{
  511. token: &oidc.IDToken{},
  512. }
  513. err = token.refresh(&config, &verifier)
  514. if assert.Error(t, err) {
  515. assert.Contains(t, err.Error(), "the refreshed token has no id token")
  516. }
  517. newToken = newToken.WithExtra(map[string]interface{}{
  518. "id_token": "id_token_val",
  519. })
  520. newToken.Expiry = time.Time{}
  521. config = mockOAuth2Config{
  522. tokenSource: &mockTokenSource{
  523. token: newToken,
  524. },
  525. }
  526. verifier = mockOIDCVerifier{
  527. err: common.ErrGenericFailure,
  528. }
  529. err = token.refresh(&config, &verifier)
  530. assert.ErrorIs(t, err, common.ErrGenericFailure)
  531. newToken = newToken.WithExtra(map[string]interface{}{
  532. "id_token": "id_token_val",
  533. })
  534. newToken.Expiry = time.Now().Add(5 * time.Minute)
  535. config = mockOAuth2Config{
  536. tokenSource: &mockTokenSource{
  537. token: newToken,
  538. },
  539. }
  540. verifier = mockOIDCVerifier{
  541. token: &oidc.IDToken{},
  542. }
  543. err = token.refresh(&config, &verifier)
  544. if assert.Error(t, err) {
  545. assert.Contains(t, err.Error(), "the refreshed token nonce mismatch")
  546. }
  547. verifier = mockOIDCVerifier{
  548. token: &oidc.IDToken{
  549. Nonce: token.Nonce,
  550. },
  551. }
  552. err = token.refresh(&config, &verifier)
  553. if assert.Error(t, err) {
  554. assert.Contains(t, err.Error(), "oidc: claims not set")
  555. }
  556. idToken := &oidc.IDToken{
  557. Nonce: token.Nonce,
  558. }
  559. setIDTokenClaims(idToken, []byte(`{"sid":"id_token_sid"}`))
  560. verifier = mockOIDCVerifier{
  561. token: idToken,
  562. }
  563. err = token.refresh(&config, &verifier)
  564. assert.NoError(t, err)
  565. require.Len(t, oidcMgr.tokens, 1)
  566. oidcMgr.removeToken(token.Cookie)
  567. require.Len(t, oidcMgr.tokens, 0)
  568. }
  569. func TestValidateOIDCToken(t *testing.T) {
  570. server := getTestOIDCServer()
  571. err := server.binding.OIDC.initialize()
  572. assert.NoError(t, err)
  573. server.initializeRouter()
  574. rr := httptest.NewRecorder()
  575. r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  576. assert.NoError(t, err)
  577. _, err = server.validateOIDCToken(rr, r, false)
  578. assert.ErrorIs(t, err, errInvalidToken)
  579. // expired token and refresh error
  580. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  581. tokenSource: &mockTokenSource{
  582. err: common.ErrGenericFailure,
  583. },
  584. }
  585. token := oidcToken{
  586. Cookie: xid.New().String(),
  587. AccessToken: xid.New().String(),
  588. ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
  589. }
  590. oidcMgr.addToken(token)
  591. rr = httptest.NewRecorder()
  592. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  593. assert.NoError(t, err)
  594. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
  595. _, err = server.validateOIDCToken(rr, r, false)
  596. assert.ErrorIs(t, err, errInvalidToken)
  597. oidcMgr.removeToken(token.Cookie)
  598. assert.Len(t, oidcMgr.tokens, 0)
  599. server.tokenAuth = jwtauth.New("PS256", util.GenerateRandomBytes(32), nil)
  600. token = oidcToken{
  601. Cookie: xid.New().String(),
  602. AccessToken: xid.New().String(),
  603. }
  604. oidcMgr.addToken(token)
  605. rr = httptest.NewRecorder()
  606. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  607. assert.NoError(t, err)
  608. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
  609. server.router.ServeHTTP(rr, r)
  610. assert.Equal(t, http.StatusFound, rr.Code)
  611. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  612. oidcMgr.removeToken(token.Cookie)
  613. assert.Len(t, oidcMgr.tokens, 0)
  614. token = oidcToken{
  615. Cookie: xid.New().String(),
  616. AccessToken: xid.New().String(),
  617. Role: "admin",
  618. }
  619. oidcMgr.addToken(token)
  620. rr = httptest.NewRecorder()
  621. r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
  622. assert.NoError(t, err)
  623. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
  624. server.router.ServeHTTP(rr, r)
  625. assert.Equal(t, http.StatusFound, rr.Code)
  626. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  627. oidcMgr.removeToken(token.Cookie)
  628. assert.Len(t, oidcMgr.tokens, 0)
  629. }
  630. func TestSkipOIDCAuth(t *testing.T) {
  631. server := getTestOIDCServer()
  632. err := server.binding.OIDC.initialize()
  633. assert.NoError(t, err)
  634. server.initializeRouter()
  635. jwtTokenClaims := jwtTokenClaims{
  636. Username: "user",
  637. }
  638. _, tokenString, err := jwtTokenClaims.createToken(server.tokenAuth, tokenAudienceWebClient)
  639. assert.NoError(t, err)
  640. rr := httptest.NewRecorder()
  641. r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  642. assert.NoError(t, err)
  643. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", jwtCookieKey, tokenString))
  644. server.router.ServeHTTP(rr, r)
  645. assert.Equal(t, http.StatusFound, rr.Code)
  646. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  647. }
  648. func TestOIDCLogoutErrors(t *testing.T) {
  649. server := getTestOIDCServer()
  650. assert.Empty(t, server.binding.OIDC.providerLogoutURL)
  651. server.logoutFromOIDCOP("")
  652. server.binding.OIDC.providerLogoutURL = "http://foo\x7f.com/"
  653. server.doOIDCFromLogout("")
  654. server.binding.OIDC.providerLogoutURL = "http://127.0.0.1:11234"
  655. server.doOIDCFromLogout("")
  656. }
  657. func TestOIDCToken(t *testing.T) {
  658. admin := dataprovider.Admin{
  659. Username: "test_oidc_admin",
  660. Password: "p",
  661. Permissions: []string{dataprovider.PermAdminAny},
  662. Status: 0,
  663. }
  664. err := dataprovider.AddAdmin(&admin, "", "")
  665. assert.NoError(t, err)
  666. token := oidcToken{
  667. Username: admin.Username,
  668. Role: "admin",
  669. }
  670. req, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
  671. assert.NoError(t, err)
  672. err = token.getUser(req)
  673. if assert.Error(t, err) {
  674. assert.Contains(t, err.Error(), "is disabled")
  675. }
  676. err = dataprovider.DeleteAdmin(admin.Username, "", "")
  677. assert.NoError(t, err)
  678. username := "test_oidc_user"
  679. token.Username = username
  680. token.Role = ""
  681. err = token.getUser(req)
  682. if assert.Error(t, err) {
  683. _, ok := err.(*util.RecordNotFoundError)
  684. assert.True(t, ok)
  685. }
  686. user := dataprovider.User{
  687. BaseUser: sdk.BaseUser{
  688. Username: username,
  689. Password: "p",
  690. HomeDir: filepath.Join(os.TempDir(), username),
  691. Status: 0,
  692. Permissions: map[string][]string{
  693. "/": {dataprovider.PermAny},
  694. },
  695. },
  696. Filters: dataprovider.UserFilters{
  697. BaseUserFilters: sdk.BaseUserFilters{
  698. DeniedProtocols: []string{common.ProtocolHTTP},
  699. },
  700. },
  701. }
  702. err = dataprovider.AddUser(&user, "", "")
  703. assert.NoError(t, err)
  704. err = token.getUser(req)
  705. if assert.Error(t, err) {
  706. assert.Contains(t, err.Error(), "is disabled")
  707. }
  708. user, err = dataprovider.UserExists(username)
  709. assert.NoError(t, err)
  710. user.Status = 1
  711. user.Password = "np"
  712. err = dataprovider.UpdateUser(&user, "", "")
  713. assert.NoError(t, err)
  714. err = token.getUser(req)
  715. if assert.Error(t, err) {
  716. assert.Contains(t, err.Error(), "protocol HTTP is not allowed")
  717. }
  718. user.Filters.DeniedProtocols = nil
  719. user.FsConfig.Provider = sdk.SFTPFilesystemProvider
  720. user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
  721. BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
  722. Endpoint: "127.0.0.1:8022",
  723. Username: username,
  724. },
  725. Password: kms.NewPlainSecret("np"),
  726. }
  727. err = dataprovider.UpdateUser(&user, "", "")
  728. assert.NoError(t, err)
  729. err = token.getUser(req)
  730. if assert.Error(t, err) {
  731. assert.Contains(t, err.Error(), "SFTP loop")
  732. }
  733. common.Config.PostConnectHook = fmt.Sprintf("http://%v/404", oidcMockAddr)
  734. err = token.getUser(req)
  735. if assert.Error(t, err) {
  736. assert.Contains(t, err.Error(), "access denied by post connect hook")
  737. }
  738. common.Config.PostConnectHook = ""
  739. err = os.RemoveAll(user.GetHomeDir())
  740. assert.NoError(t, err)
  741. err = dataprovider.DeleteUser(username, "", "")
  742. assert.NoError(t, err)
  743. }
  744. func getTestOIDCServer() *httpdServer {
  745. return &httpdServer{
  746. binding: Binding{
  747. OIDC: OIDC{
  748. ClientID: "sftpgo-client",
  749. ClientSecret: "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c",
  750. ConfigURL: fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr),
  751. RedirectBaseURL: "http://127.0.0.1:8081/",
  752. UsernameField: "preferred_username",
  753. RoleField: "sftpgo_role",
  754. },
  755. },
  756. enableWebAdmin: true,
  757. enableWebClient: true,
  758. }
  759. }
  760. func TestOIDCManager(t *testing.T) {
  761. require.Len(t, oidcMgr.pendingAuths, 0)
  762. authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
  763. oidcMgr.addPendingAuth(authReq)
  764. require.Len(t, oidcMgr.pendingAuths, 1)
  765. _, err := oidcMgr.getPendingAuth(authReq.State)
  766. assert.NoError(t, err)
  767. oidcMgr.removePendingAuth(authReq.State)
  768. require.Len(t, oidcMgr.pendingAuths, 0)
  769. authReq.IssueAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-61 * time.Second))
  770. oidcMgr.addPendingAuth(authReq)
  771. require.Len(t, oidcMgr.pendingAuths, 1)
  772. _, err = oidcMgr.getPendingAuth(authReq.State)
  773. if assert.Error(t, err) {
  774. assert.Contains(t, err.Error(), "too old")
  775. }
  776. oidcMgr.checkCleanup()
  777. require.Len(t, oidcMgr.pendingAuths, 1)
  778. oidcMgr.lastCleanup = time.Now().Add(-1 * time.Hour)
  779. oidcMgr.checkCleanup()
  780. require.Len(t, oidcMgr.pendingAuths, 0)
  781. assert.True(t, oidcMgr.lastCleanup.After(time.Now().Add(-10*time.Second)))
  782. token := oidcToken{
  783. AccessToken: xid.New().String(),
  784. Nonce: xid.New().String(),
  785. SessionID: xid.New().String(),
  786. Cookie: xid.New().String(),
  787. Username: xid.New().String(),
  788. Role: "admin",
  789. Permissions: []string{dataprovider.PermAdminAny},
  790. }
  791. require.Len(t, oidcMgr.tokens, 0)
  792. oidcMgr.addToken(token)
  793. require.Len(t, oidcMgr.tokens, 1)
  794. _, err = oidcMgr.getToken(xid.New().String())
  795. assert.Error(t, err)
  796. storedToken, err := oidcMgr.getToken(token.Cookie)
  797. assert.NoError(t, err)
  798. assert.Greater(t, storedToken.UsedAt, int64(0))
  799. token.UsedAt = storedToken.UsedAt
  800. assert.Equal(t, token, storedToken)
  801. // the usage will not be updated, it is recent
  802. oidcMgr.updateTokenUsage(storedToken)
  803. storedToken, err = oidcMgr.getToken(token.Cookie)
  804. assert.NoError(t, err)
  805. assert.Equal(t, token, storedToken)
  806. usedAt := util.GetTimeAsMsSinceEpoch(time.Now().Add(-5 * time.Minute))
  807. storedToken.UsedAt = usedAt
  808. oidcMgr.tokens[token.Cookie] = storedToken
  809. storedToken, err = oidcMgr.getToken(token.Cookie)
  810. assert.NoError(t, err)
  811. assert.Equal(t, usedAt, storedToken.UsedAt)
  812. token.UsedAt = storedToken.UsedAt
  813. assert.Equal(t, token, storedToken)
  814. oidcMgr.updateTokenUsage(storedToken)
  815. storedToken, err = oidcMgr.getToken(token.Cookie)
  816. assert.NoError(t, err)
  817. assert.Greater(t, storedToken.UsedAt, usedAt)
  818. token.UsedAt = storedToken.UsedAt
  819. assert.Equal(t, token, storedToken)
  820. oidcMgr.removeToken(xid.New().String())
  821. require.Len(t, oidcMgr.tokens, 1)
  822. oidcMgr.removeToken(token.Cookie)
  823. require.Len(t, oidcMgr.tokens, 0)
  824. oidcMgr.addToken(token)
  825. usedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-6 * time.Hour))
  826. token.UsedAt = usedAt
  827. oidcMgr.tokens[token.Cookie] = token
  828. newToken := oidcToken{
  829. Cookie: xid.New().String(),
  830. }
  831. oidcMgr.lastCleanup = time.Now().Add(-1 * time.Hour)
  832. oidcMgr.addToken(newToken)
  833. require.Len(t, oidcMgr.tokens, 1)
  834. _, err = oidcMgr.getToken(token.Cookie)
  835. assert.Error(t, err)
  836. _, err = oidcMgr.getToken(newToken.Cookie)
  837. assert.NoError(t, err)
  838. oidcMgr.removeToken(newToken.Cookie)
  839. require.Len(t, oidcMgr.tokens, 0)
  840. }