oidc_test.go 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036
  1. package httpd
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "net/http/httptest"
  8. "os"
  9. "path/filepath"
  10. "reflect"
  11. "runtime"
  12. "testing"
  13. "time"
  14. "unsafe"
  15. "github.com/coreos/go-oidc/v3/oidc"
  16. "github.com/go-chi/jwtauth/v5"
  17. "github.com/rs/xid"
  18. "github.com/sftpgo/sdk"
  19. "github.com/stretchr/testify/assert"
  20. "github.com/stretchr/testify/require"
  21. "golang.org/x/oauth2"
  22. "github.com/drakkan/sftpgo/v2/common"
  23. "github.com/drakkan/sftpgo/v2/dataprovider"
  24. "github.com/drakkan/sftpgo/v2/kms"
  25. "github.com/drakkan/sftpgo/v2/util"
  26. "github.com/drakkan/sftpgo/v2/vfs"
  27. )
  28. const (
  29. oidcMockAddr = "127.0.0.1:11111"
  30. )
  31. type mockTokenSource struct {
  32. token *oauth2.Token
  33. err error
  34. }
  35. func (t *mockTokenSource) Token() (*oauth2.Token, error) {
  36. return t.token, t.err
  37. }
  38. type mockOAuth2Config struct {
  39. tokenSource *mockTokenSource
  40. authCodeURL string
  41. token *oauth2.Token
  42. err error
  43. }
  44. func (c *mockOAuth2Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
  45. return c.authCodeURL
  46. }
  47. func (c *mockOAuth2Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
  48. return c.token, c.err
  49. }
  50. func (c *mockOAuth2Config) TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource {
  51. return c.tokenSource
  52. }
  53. type mockOIDCVerifier struct {
  54. token *oidc.IDToken
  55. err error
  56. }
  57. func (v *mockOIDCVerifier) Verify(ctx context.Context, rawIDToken string) (*oidc.IDToken, error) {
  58. return v.token, v.err
  59. }
  60. // hack because the field is unexported
  61. func setIDTokenClaims(idToken *oidc.IDToken, claims []byte) {
  62. pointerVal := reflect.ValueOf(idToken)
  63. val := reflect.Indirect(pointerVal)
  64. member := val.FieldByName("claims")
  65. ptr := unsafe.Pointer(member.UnsafeAddr())
  66. realPtr := (*[]byte)(ptr)
  67. *realPtr = claims
  68. }
  69. func TestOIDCInitialization(t *testing.T) {
  70. config := OIDC{}
  71. err := config.initialize()
  72. assert.NoError(t, err)
  73. config = OIDC{
  74. ClientID: "sftpgo-client",
  75. ClientSecret: "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c",
  76. ConfigURL: fmt.Sprintf("http://%v/", oidcMockAddr),
  77. RedirectBaseURL: "http://127.0.0.1:8081/",
  78. UsernameField: "preferred_username",
  79. RoleField: "sftpgo_role",
  80. }
  81. err = config.initialize()
  82. if assert.Error(t, err) {
  83. assert.Contains(t, err.Error(), "oidc: unable to initialize provider")
  84. }
  85. config.ConfigURL = fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr)
  86. err = config.initialize()
  87. assert.NoError(t, err)
  88. assert.Equal(t, "http://127.0.0.1:8081"+webOIDCRedirectPath, config.getRedirectURL())
  89. }
  90. func TestOIDCLoginLogout(t *testing.T) {
  91. server := getTestOIDCServer()
  92. err := server.binding.OIDC.initialize()
  93. assert.NoError(t, err)
  94. server.initializeRouter()
  95. rr := httptest.NewRecorder()
  96. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath, nil)
  97. assert.NoError(t, err)
  98. server.router.ServeHTTP(rr, r)
  99. assert.Equal(t, http.StatusBadRequest, rr.Code)
  100. assert.Contains(t, rr.Body.String(), "Authentication state did not match")
  101. expiredAuthReq := oidcPendingAuth{
  102. State: xid.New().String(),
  103. Nonce: xid.New().String(),
  104. Audience: tokenAudienceWebClient,
  105. IssueAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
  106. }
  107. oidcMgr.addPendingAuth(expiredAuthReq)
  108. rr = httptest.NewRecorder()
  109. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+expiredAuthReq.State, nil)
  110. assert.NoError(t, err)
  111. server.router.ServeHTTP(rr, r)
  112. assert.Equal(t, http.StatusBadRequest, rr.Code)
  113. assert.Contains(t, rr.Body.String(), "Authentication state did not match")
  114. oidcMgr.removePendingAuth(expiredAuthReq.State)
  115. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  116. tokenSource: &mockTokenSource{},
  117. authCodeURL: webOIDCRedirectPath,
  118. err: common.ErrGenericFailure,
  119. }
  120. server.binding.OIDC.verifier = &mockOIDCVerifier{
  121. err: common.ErrGenericFailure,
  122. }
  123. rr = httptest.NewRecorder()
  124. r, err = http.NewRequest(http.MethodGet, webAdminOIDCLoginPath, nil)
  125. assert.NoError(t, err)
  126. server.router.ServeHTTP(rr, r)
  127. assert.Equal(t, http.StatusFound, rr.Code)
  128. assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
  129. require.Len(t, oidcMgr.pendingAuths, 1)
  130. var state string
  131. for k := range oidcMgr.pendingAuths {
  132. state = k
  133. }
  134. rr = httptest.NewRecorder()
  135. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
  136. assert.NoError(t, err)
  137. server.router.ServeHTTP(rr, r)
  138. assert.Equal(t, http.StatusFound, rr.Code)
  139. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  140. require.Len(t, oidcMgr.pendingAuths, 0)
  141. rr = httptest.NewRecorder()
  142. r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
  143. assert.NoError(t, err)
  144. server.router.ServeHTTP(rr, r)
  145. assert.Equal(t, http.StatusOK, rr.Code)
  146. // now the same for the web client
  147. rr = httptest.NewRecorder()
  148. r, err = http.NewRequest(http.MethodGet, webClientOIDCLoginPath, nil)
  149. assert.NoError(t, err)
  150. server.router.ServeHTTP(rr, r)
  151. assert.Equal(t, http.StatusFound, rr.Code)
  152. assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
  153. require.Len(t, oidcMgr.pendingAuths, 1)
  154. for k := range oidcMgr.pendingAuths {
  155. state = k
  156. }
  157. rr = httptest.NewRecorder()
  158. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
  159. assert.NoError(t, err)
  160. server.router.ServeHTTP(rr, r)
  161. assert.Equal(t, http.StatusFound, rr.Code)
  162. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  163. require.Len(t, oidcMgr.pendingAuths, 0)
  164. rr = httptest.NewRecorder()
  165. r, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil)
  166. assert.NoError(t, err)
  167. server.router.ServeHTTP(rr, r)
  168. assert.Equal(t, http.StatusOK, rr.Code)
  169. // now return an OAuth2 token without the id_token
  170. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  171. tokenSource: &mockTokenSource{},
  172. authCodeURL: webOIDCRedirectPath,
  173. token: &oauth2.Token{
  174. AccessToken: "123",
  175. Expiry: time.Now().Add(5 * time.Minute),
  176. },
  177. err: nil,
  178. }
  179. authReq := newOIDCPendingAuth(tokenAudienceWebClient)
  180. oidcMgr.addPendingAuth(authReq)
  181. rr = httptest.NewRecorder()
  182. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  183. assert.NoError(t, err)
  184. server.router.ServeHTTP(rr, r)
  185. assert.Equal(t, http.StatusFound, rr.Code)
  186. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  187. require.Len(t, oidcMgr.pendingAuths, 0)
  188. // now fail to verify the id token
  189. token := &oauth2.Token{
  190. AccessToken: "123",
  191. Expiry: time.Now().Add(5 * time.Minute),
  192. }
  193. token = token.WithExtra(map[string]interface{}{
  194. "id_token": "id_token_val",
  195. })
  196. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  197. tokenSource: &mockTokenSource{},
  198. authCodeURL: webOIDCRedirectPath,
  199. token: token,
  200. err: nil,
  201. }
  202. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  203. oidcMgr.addPendingAuth(authReq)
  204. rr = httptest.NewRecorder()
  205. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  206. assert.NoError(t, err)
  207. server.router.ServeHTTP(rr, r)
  208. assert.Equal(t, http.StatusFound, rr.Code)
  209. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  210. require.Len(t, oidcMgr.pendingAuths, 0)
  211. // id token nonce does not match
  212. server.binding.OIDC.verifier = &mockOIDCVerifier{
  213. err: nil,
  214. token: &oidc.IDToken{},
  215. }
  216. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  217. oidcMgr.addPendingAuth(authReq)
  218. rr = httptest.NewRecorder()
  219. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  220. assert.NoError(t, err)
  221. server.router.ServeHTTP(rr, r)
  222. assert.Equal(t, http.StatusFound, rr.Code)
  223. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  224. require.Len(t, oidcMgr.pendingAuths, 0)
  225. // null id token claims
  226. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  227. oidcMgr.addPendingAuth(authReq)
  228. server.binding.OIDC.verifier = &mockOIDCVerifier{
  229. err: nil,
  230. token: &oidc.IDToken{
  231. Nonce: authReq.Nonce,
  232. },
  233. }
  234. rr = httptest.NewRecorder()
  235. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  236. assert.NoError(t, err)
  237. server.router.ServeHTTP(rr, r)
  238. assert.Equal(t, http.StatusFound, rr.Code)
  239. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  240. require.Len(t, oidcMgr.pendingAuths, 0)
  241. // invalid id token claims (no username)
  242. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  243. oidcMgr.addPendingAuth(authReq)
  244. idToken := &oidc.IDToken{
  245. Nonce: authReq.Nonce,
  246. Expiry: time.Now().Add(5 * time.Minute),
  247. }
  248. setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id"}`))
  249. server.binding.OIDC.verifier = &mockOIDCVerifier{
  250. err: nil,
  251. token: idToken,
  252. }
  253. rr = httptest.NewRecorder()
  254. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  255. assert.NoError(t, err)
  256. server.router.ServeHTTP(rr, r)
  257. assert.Equal(t, http.StatusFound, rr.Code)
  258. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  259. require.Len(t, oidcMgr.pendingAuths, 0)
  260. // invalid audience
  261. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  262. oidcMgr.addPendingAuth(authReq)
  263. idToken = &oidc.IDToken{
  264. Nonce: authReq.Nonce,
  265. Expiry: time.Now().Add(5 * time.Minute),
  266. }
  267. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
  268. server.binding.OIDC.verifier = &mockOIDCVerifier{
  269. err: nil,
  270. token: idToken,
  271. }
  272. rr = httptest.NewRecorder()
  273. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  274. assert.NoError(t, err)
  275. server.router.ServeHTTP(rr, r)
  276. assert.Equal(t, http.StatusFound, rr.Code)
  277. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  278. require.Len(t, oidcMgr.pendingAuths, 0)
  279. // invalid audience
  280. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  281. oidcMgr.addPendingAuth(authReq)
  282. idToken = &oidc.IDToken{
  283. Nonce: authReq.Nonce,
  284. Expiry: time.Now().Add(5 * time.Minute),
  285. }
  286. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test"}`))
  287. server.binding.OIDC.verifier = &mockOIDCVerifier{
  288. err: nil,
  289. token: idToken,
  290. }
  291. rr = httptest.NewRecorder()
  292. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  293. assert.NoError(t, err)
  294. server.router.ServeHTTP(rr, r)
  295. assert.Equal(t, http.StatusFound, rr.Code)
  296. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  297. require.Len(t, oidcMgr.pendingAuths, 0)
  298. // mapped user not found
  299. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  300. oidcMgr.addPendingAuth(authReq)
  301. idToken = &oidc.IDToken{
  302. Nonce: authReq.Nonce,
  303. Expiry: time.Now().Add(5 * time.Minute),
  304. }
  305. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
  306. server.binding.OIDC.verifier = &mockOIDCVerifier{
  307. err: nil,
  308. token: idToken,
  309. }
  310. rr = httptest.NewRecorder()
  311. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  312. assert.NoError(t, err)
  313. server.router.ServeHTTP(rr, r)
  314. assert.Equal(t, http.StatusFound, rr.Code)
  315. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  316. require.Len(t, oidcMgr.pendingAuths, 0)
  317. // admin login ok
  318. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  319. oidcMgr.addPendingAuth(authReq)
  320. idToken = &oidc.IDToken{
  321. Nonce: authReq.Nonce,
  322. Expiry: time.Now().Add(5 * time.Minute),
  323. }
  324. setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sftpgo_role":"admin","sid":"sid123"}`))
  325. server.binding.OIDC.verifier = &mockOIDCVerifier{
  326. err: nil,
  327. token: idToken,
  328. }
  329. rr = httptest.NewRecorder()
  330. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  331. assert.NoError(t, err)
  332. server.router.ServeHTTP(rr, r)
  333. assert.Equal(t, http.StatusFound, rr.Code)
  334. assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
  335. require.Len(t, oidcMgr.pendingAuths, 0)
  336. require.Len(t, oidcMgr.tokens, 1)
  337. // admin profile is not available
  338. var tokenCookie string
  339. for k := range oidcMgr.tokens {
  340. tokenCookie = k
  341. }
  342. oidcToken, err := oidcMgr.getToken(tokenCookie)
  343. assert.NoError(t, err)
  344. assert.Equal(t, "sid123", oidcToken.SessionID)
  345. assert.True(t, oidcToken.isAdmin())
  346. assert.False(t, oidcToken.isExpired())
  347. rr = httptest.NewRecorder()
  348. r, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil)
  349. assert.NoError(t, err)
  350. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  351. server.router.ServeHTTP(rr, r)
  352. assert.Equal(t, http.StatusForbidden, rr.Code)
  353. // the admin can access the allowed pages
  354. rr = httptest.NewRecorder()
  355. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  356. assert.NoError(t, err)
  357. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  358. server.router.ServeHTTP(rr, r)
  359. assert.Equal(t, http.StatusOK, rr.Code)
  360. // try with an invalid cookie
  361. rr = httptest.NewRecorder()
  362. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  363. assert.NoError(t, err)
  364. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
  365. server.router.ServeHTTP(rr, r)
  366. assert.Equal(t, http.StatusFound, rr.Code)
  367. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  368. // Web Client is not available with an admin token
  369. rr = httptest.NewRecorder()
  370. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  371. assert.NoError(t, err)
  372. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  373. server.router.ServeHTTP(rr, r)
  374. assert.Equal(t, http.StatusFound, rr.Code)
  375. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  376. // logout the admin user
  377. rr = httptest.NewRecorder()
  378. r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
  379. assert.NoError(t, err)
  380. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  381. server.router.ServeHTTP(rr, r)
  382. assert.Equal(t, http.StatusFound, rr.Code)
  383. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  384. require.Len(t, oidcMgr.pendingAuths, 0)
  385. require.Len(t, oidcMgr.tokens, 0)
  386. // now login and logout a user
  387. username := "test_oidc_user"
  388. user := dataprovider.User{
  389. BaseUser: sdk.BaseUser{
  390. Username: username,
  391. Password: "pwd",
  392. HomeDir: filepath.Join(os.TempDir(), username),
  393. Status: 1,
  394. Permissions: map[string][]string{
  395. "/": {dataprovider.PermAny},
  396. },
  397. },
  398. Filters: dataprovider.UserFilters{
  399. BaseUserFilters: sdk.BaseUserFilters{
  400. WebClient: []string{sdk.WebClientSharesDisabled},
  401. },
  402. },
  403. }
  404. err = dataprovider.AddUser(&user, "", "")
  405. assert.NoError(t, err)
  406. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  407. oidcMgr.addPendingAuth(authReq)
  408. idToken = &oidc.IDToken{
  409. Nonce: authReq.Nonce,
  410. Expiry: time.Now().Add(5 * time.Minute),
  411. }
  412. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test_oidc_user"}`))
  413. server.binding.OIDC.verifier = &mockOIDCVerifier{
  414. err: nil,
  415. token: idToken,
  416. }
  417. rr = httptest.NewRecorder()
  418. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  419. assert.NoError(t, err)
  420. server.router.ServeHTTP(rr, r)
  421. assert.Equal(t, http.StatusFound, rr.Code)
  422. assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
  423. require.Len(t, oidcMgr.pendingAuths, 0)
  424. require.Len(t, oidcMgr.tokens, 1)
  425. // user profile is not available
  426. for k := range oidcMgr.tokens {
  427. tokenCookie = k
  428. }
  429. oidcToken, err = oidcMgr.getToken(tokenCookie)
  430. assert.NoError(t, err)
  431. assert.Empty(t, oidcToken.SessionID)
  432. assert.False(t, oidcToken.isAdmin())
  433. assert.False(t, oidcToken.isExpired())
  434. if assert.Len(t, oidcToken.Permissions, 1) {
  435. assert.Equal(t, sdk.WebClientSharesDisabled, oidcToken.Permissions[0])
  436. }
  437. rr = httptest.NewRecorder()
  438. r, err = http.NewRequest(http.MethodGet, webClientProfilePath, nil)
  439. assert.NoError(t, err)
  440. r.RequestURI = webClientProfilePath
  441. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  442. server.router.ServeHTTP(rr, r)
  443. assert.Equal(t, http.StatusForbidden, rr.Code)
  444. // the user can access the allowed pages
  445. rr = httptest.NewRecorder()
  446. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  447. assert.NoError(t, err)
  448. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  449. server.router.ServeHTTP(rr, r)
  450. assert.Equal(t, http.StatusOK, rr.Code)
  451. // try with an invalid cookie
  452. rr = httptest.NewRecorder()
  453. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  454. assert.NoError(t, err)
  455. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
  456. server.router.ServeHTTP(rr, r)
  457. assert.Equal(t, http.StatusFound, rr.Code)
  458. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  459. // Web Admin is not available with a client cookie
  460. rr = httptest.NewRecorder()
  461. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  462. assert.NoError(t, err)
  463. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  464. server.router.ServeHTTP(rr, r)
  465. assert.Equal(t, http.StatusFound, rr.Code)
  466. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  467. // logout the user
  468. rr = httptest.NewRecorder()
  469. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  470. assert.NoError(t, err)
  471. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  472. server.router.ServeHTTP(rr, r)
  473. assert.Equal(t, http.StatusFound, rr.Code)
  474. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  475. require.Len(t, oidcMgr.pendingAuths, 0)
  476. require.Len(t, oidcMgr.tokens, 0)
  477. err = os.RemoveAll(user.GetHomeDir())
  478. assert.NoError(t, err)
  479. err = dataprovider.DeleteUser(username, "", "")
  480. assert.NoError(t, err)
  481. }
  482. func TestOIDCRefreshToken(t *testing.T) {
  483. token := oidcToken{
  484. Cookie: xid.New().String(),
  485. AccessToken: xid.New().String(),
  486. TokenType: "Bearer",
  487. ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-1 * time.Minute)),
  488. Nonce: xid.New().String(),
  489. }
  490. config := mockOAuth2Config{
  491. tokenSource: &mockTokenSource{
  492. err: common.ErrGenericFailure,
  493. },
  494. }
  495. verifier := mockOIDCVerifier{
  496. err: common.ErrGenericFailure,
  497. }
  498. err := token.refresh(&config, &verifier)
  499. if assert.Error(t, err) {
  500. assert.Contains(t, err.Error(), "refresh token not set")
  501. }
  502. token.RefreshToken = xid.New().String()
  503. err = token.refresh(&config, &verifier)
  504. assert.ErrorIs(t, err, common.ErrGenericFailure)
  505. newToken := &oauth2.Token{
  506. AccessToken: xid.New().String(),
  507. RefreshToken: xid.New().String(),
  508. Expiry: time.Now().Add(5 * time.Minute),
  509. }
  510. config = mockOAuth2Config{
  511. tokenSource: &mockTokenSource{
  512. token: newToken,
  513. },
  514. }
  515. verifier = mockOIDCVerifier{
  516. token: &oidc.IDToken{},
  517. }
  518. err = token.refresh(&config, &verifier)
  519. if assert.Error(t, err) {
  520. assert.Contains(t, err.Error(), "the refreshed token has no id token")
  521. }
  522. newToken = newToken.WithExtra(map[string]interface{}{
  523. "id_token": "id_token_val",
  524. })
  525. newToken.Expiry = time.Time{}
  526. config = mockOAuth2Config{
  527. tokenSource: &mockTokenSource{
  528. token: newToken,
  529. },
  530. }
  531. verifier = mockOIDCVerifier{
  532. err: common.ErrGenericFailure,
  533. }
  534. err = token.refresh(&config, &verifier)
  535. assert.ErrorIs(t, err, common.ErrGenericFailure)
  536. newToken = newToken.WithExtra(map[string]interface{}{
  537. "id_token": "id_token_val",
  538. })
  539. newToken.Expiry = time.Now().Add(5 * time.Minute)
  540. config = mockOAuth2Config{
  541. tokenSource: &mockTokenSource{
  542. token: newToken,
  543. },
  544. }
  545. verifier = mockOIDCVerifier{
  546. token: &oidc.IDToken{},
  547. }
  548. err = token.refresh(&config, &verifier)
  549. if assert.Error(t, err) {
  550. assert.Contains(t, err.Error(), "the refreshed token nonce mismatch")
  551. }
  552. verifier = mockOIDCVerifier{
  553. token: &oidc.IDToken{
  554. Nonce: token.Nonce,
  555. },
  556. }
  557. err = token.refresh(&config, &verifier)
  558. if assert.Error(t, err) {
  559. assert.Contains(t, err.Error(), "oidc: claims not set")
  560. }
  561. idToken := &oidc.IDToken{
  562. Nonce: token.Nonce,
  563. }
  564. setIDTokenClaims(idToken, []byte(`{"sid":"id_token_sid"}`))
  565. verifier = mockOIDCVerifier{
  566. token: idToken,
  567. }
  568. err = token.refresh(&config, &verifier)
  569. assert.NoError(t, err)
  570. require.Len(t, oidcMgr.tokens, 1)
  571. oidcMgr.removeToken(token.Cookie)
  572. require.Len(t, oidcMgr.tokens, 0)
  573. }
  574. func TestValidateOIDCToken(t *testing.T) {
  575. server := getTestOIDCServer()
  576. err := server.binding.OIDC.initialize()
  577. assert.NoError(t, err)
  578. server.initializeRouter()
  579. rr := httptest.NewRecorder()
  580. r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  581. assert.NoError(t, err)
  582. _, err = server.validateOIDCToken(rr, r, false)
  583. assert.ErrorIs(t, err, errInvalidToken)
  584. // expired token and refresh error
  585. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  586. tokenSource: &mockTokenSource{
  587. err: common.ErrGenericFailure,
  588. },
  589. }
  590. token := oidcToken{
  591. Cookie: xid.New().String(),
  592. AccessToken: xid.New().String(),
  593. ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
  594. }
  595. oidcMgr.addToken(token)
  596. rr = httptest.NewRecorder()
  597. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  598. assert.NoError(t, err)
  599. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
  600. _, err = server.validateOIDCToken(rr, r, false)
  601. assert.ErrorIs(t, err, errInvalidToken)
  602. oidcMgr.removeToken(token.Cookie)
  603. assert.Len(t, oidcMgr.tokens, 0)
  604. server.tokenAuth = jwtauth.New("PS256", util.GenerateRandomBytes(32), nil)
  605. token = oidcToken{
  606. Cookie: xid.New().String(),
  607. AccessToken: xid.New().String(),
  608. }
  609. oidcMgr.addToken(token)
  610. rr = httptest.NewRecorder()
  611. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  612. assert.NoError(t, err)
  613. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
  614. server.router.ServeHTTP(rr, r)
  615. assert.Equal(t, http.StatusFound, rr.Code)
  616. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  617. oidcMgr.removeToken(token.Cookie)
  618. assert.Len(t, oidcMgr.tokens, 0)
  619. token = oidcToken{
  620. Cookie: xid.New().String(),
  621. AccessToken: xid.New().String(),
  622. Role: "admin",
  623. }
  624. oidcMgr.addToken(token)
  625. rr = httptest.NewRecorder()
  626. r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
  627. assert.NoError(t, err)
  628. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
  629. server.router.ServeHTTP(rr, r)
  630. assert.Equal(t, http.StatusFound, rr.Code)
  631. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  632. oidcMgr.removeToken(token.Cookie)
  633. assert.Len(t, oidcMgr.tokens, 0)
  634. }
  635. func TestSkipOIDCAuth(t *testing.T) {
  636. server := getTestOIDCServer()
  637. err := server.binding.OIDC.initialize()
  638. assert.NoError(t, err)
  639. server.initializeRouter()
  640. jwtTokenClaims := jwtTokenClaims{
  641. Username: "user",
  642. }
  643. _, tokenString, err := jwtTokenClaims.createToken(server.tokenAuth, tokenAudienceWebClient, "")
  644. assert.NoError(t, err)
  645. rr := httptest.NewRecorder()
  646. r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  647. assert.NoError(t, err)
  648. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", jwtCookieKey, tokenString))
  649. server.router.ServeHTTP(rr, r)
  650. assert.Equal(t, http.StatusFound, rr.Code)
  651. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  652. }
  653. func TestOIDCLogoutErrors(t *testing.T) {
  654. server := getTestOIDCServer()
  655. assert.Empty(t, server.binding.OIDC.providerLogoutURL)
  656. server.logoutFromOIDCOP("")
  657. server.binding.OIDC.providerLogoutURL = "http://foo\x7f.com/"
  658. server.doOIDCFromLogout("")
  659. server.binding.OIDC.providerLogoutURL = "http://127.0.0.1:11234"
  660. server.doOIDCFromLogout("")
  661. }
  662. func TestOIDCToken(t *testing.T) {
  663. admin := dataprovider.Admin{
  664. Username: "test_oidc_admin",
  665. Password: "p",
  666. Permissions: []string{dataprovider.PermAdminAny},
  667. Status: 0,
  668. }
  669. err := dataprovider.AddAdmin(&admin, "", "")
  670. assert.NoError(t, err)
  671. token := oidcToken{
  672. Username: admin.Username,
  673. Role: "admin",
  674. }
  675. req, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
  676. assert.NoError(t, err)
  677. err = token.getUser(req)
  678. if assert.Error(t, err) {
  679. assert.Contains(t, err.Error(), "is disabled")
  680. }
  681. err = dataprovider.DeleteAdmin(admin.Username, "", "")
  682. assert.NoError(t, err)
  683. username := "test_oidc_user"
  684. token.Username = username
  685. token.Role = ""
  686. err = token.getUser(req)
  687. if assert.Error(t, err) {
  688. _, ok := err.(*util.RecordNotFoundError)
  689. assert.True(t, ok)
  690. }
  691. user := dataprovider.User{
  692. BaseUser: sdk.BaseUser{
  693. Username: username,
  694. Password: "p",
  695. HomeDir: filepath.Join(os.TempDir(), username),
  696. Status: 0,
  697. Permissions: map[string][]string{
  698. "/": {dataprovider.PermAny},
  699. },
  700. },
  701. Filters: dataprovider.UserFilters{
  702. BaseUserFilters: sdk.BaseUserFilters{
  703. DeniedProtocols: []string{common.ProtocolHTTP},
  704. },
  705. },
  706. }
  707. err = dataprovider.AddUser(&user, "", "")
  708. assert.NoError(t, err)
  709. err = token.getUser(req)
  710. if assert.Error(t, err) {
  711. assert.Contains(t, err.Error(), "is disabled")
  712. }
  713. user, err = dataprovider.UserExists(username)
  714. assert.NoError(t, err)
  715. user.Status = 1
  716. user.Password = "np"
  717. err = dataprovider.UpdateUser(&user, "", "")
  718. assert.NoError(t, err)
  719. err = token.getUser(req)
  720. if assert.Error(t, err) {
  721. assert.Contains(t, err.Error(), "protocol HTTP is not allowed")
  722. }
  723. user.Filters.DeniedProtocols = nil
  724. user.FsConfig.Provider = sdk.SFTPFilesystemProvider
  725. user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
  726. BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
  727. Endpoint: "127.0.0.1:8022",
  728. Username: username,
  729. },
  730. Password: kms.NewPlainSecret("np"),
  731. }
  732. err = dataprovider.UpdateUser(&user, "", "")
  733. assert.NoError(t, err)
  734. err = token.getUser(req)
  735. if assert.Error(t, err) {
  736. assert.Contains(t, err.Error(), "SFTP loop")
  737. }
  738. common.Config.PostConnectHook = fmt.Sprintf("http://%v/404", oidcMockAddr)
  739. err = token.getUser(req)
  740. if assert.Error(t, err) {
  741. assert.Contains(t, err.Error(), "access denied by post connect hook")
  742. }
  743. common.Config.PostConnectHook = ""
  744. err = os.RemoveAll(user.GetHomeDir())
  745. assert.NoError(t, err)
  746. err = dataprovider.DeleteUser(username, "", "")
  747. assert.NoError(t, err)
  748. }
// TestOIDCManager exercises the package-level oidcMgr directly: bookkeeping of
// pending authentications, rejection of stale pending authentications, token
// storage and lookup by cookie, token usage updates and the periodic cleanup
// of expired entries. Each step relies on the state left by the previous ones,
// so the statements must run in this order. The test expects oidcMgr to hold
// no pending authentications when it starts and no tokens when the token
// checks begin.
func TestOIDCManager(t *testing.T) {
	require.Len(t, oidcMgr.pendingAuths, 0)
	authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
	oidcMgr.addPendingAuth(authReq)
	require.Len(t, oidcMgr.pendingAuths, 1)
	_, err := oidcMgr.getPendingAuth(authReq.State)
	assert.NoError(t, err)
	oidcMgr.removePendingAuth(authReq.State)
	require.Len(t, oidcMgr.pendingAuths, 0)
	// backdate the pending authentication by 61 seconds: once stored again,
	// getPendingAuth must reject it as too old
	authReq.IssueAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-61 * time.Second))
	oidcMgr.addPendingAuth(authReq)
	require.Len(t, oidcMgr.pendingAuths, 1)
	_, err = oidcMgr.getPendingAuth(authReq.State)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "too old")
	}
	// the stale entry is still stored after this checkCleanup call, presumably
	// because the last cleanup is considered recent
	oidcMgr.checkCleanup()
	require.Len(t, oidcMgr.pendingAuths, 1)
	// pretend the last cleanup happened an hour ago: checkCleanup must now drop
	// the stale entry and move lastCleanup forward to about now
	oidcMgr.lastCleanup = time.Now().Add(-1 * time.Hour)
	oidcMgr.checkCleanup()
	require.Len(t, oidcMgr.pendingAuths, 0)
	assert.True(t, oidcMgr.lastCleanup.After(time.Now().Add(-10*time.Second)))
	token := oidcToken{
		AccessToken: xid.New().String(),
		Nonce:       xid.New().String(),
		SessionID:   xid.New().String(),
		Cookie:      xid.New().String(),
		Username:    xid.New().String(),
		Role:        "admin",
		Permissions: []string{dataprovider.PermAdminAny},
	}
	require.Len(t, oidcMgr.tokens, 0)
	oidcMgr.addToken(token)
	require.Len(t, oidcMgr.tokens, 1)
	// tokens are looked up by cookie: an unknown cookie must not be found
	_, err = oidcMgr.getToken(xid.New().String())
	assert.Error(t, err)
	storedToken, err := oidcMgr.getToken(token.Cookie)
	assert.NoError(t, err)
	// the stored token carries a positive usage timestamp; apart from that it
	// must be equal to the token that was added
	assert.Greater(t, storedToken.UsedAt, int64(0))
	token.UsedAt = storedToken.UsedAt
	assert.Equal(t, token, storedToken)
	// the usage will not be updated, it is recent
	oidcMgr.updateTokenUsage(storedToken)
	storedToken, err = oidcMgr.getToken(token.Cookie)
	assert.NoError(t, err)
	assert.Equal(t, token, storedToken)
	// backdate the usage by 5 minutes, writing directly into the map:
	// getToken must return the backdated timestamp unchanged
	usedAt := util.GetTimeAsMsSinceEpoch(time.Now().Add(-5 * time.Minute))
	storedToken.UsedAt = usedAt
	oidcMgr.tokens[token.Cookie] = storedToken
	storedToken, err = oidcMgr.getToken(token.Cookie)
	assert.NoError(t, err)
	assert.Equal(t, usedAt, storedToken.UsedAt)
	token.UsedAt = storedToken.UsedAt
	assert.Equal(t, token, storedToken)
	// the usage is no longer recent, so updateTokenUsage must now refresh it
	oidcMgr.updateTokenUsage(storedToken)
	storedToken, err = oidcMgr.getToken(token.Cookie)
	assert.NoError(t, err)
	assert.Greater(t, storedToken.UsedAt, usedAt)
	token.UsedAt = storedToken.UsedAt
	assert.Equal(t, token, storedToken)
	// removing an unknown cookie leaves the stored token in place
	oidcMgr.removeToken(xid.New().String())
	require.Len(t, oidcMgr.tokens, 1)
	oidcMgr.removeToken(token.Cookie)
	require.Len(t, oidcMgr.tokens, 0)
	// store a token last used 6 hours ago and backdate lastCleanup; adding
	// another token apparently triggers a cleanup, after which only the new
	// token must remain
	oidcMgr.addToken(token)
	usedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-6 * time.Hour))
	token.UsedAt = usedAt
	oidcMgr.tokens[token.Cookie] = token
	newToken := oidcToken{
		Cookie: xid.New().String(),
	}
	oidcMgr.lastCleanup = time.Now().Add(-1 * time.Hour)
	oidcMgr.addToken(newToken)
	require.Len(t, oidcMgr.tokens, 1)
	_, err = oidcMgr.getToken(token.Cookie)
	assert.Error(t, err)
	_, err = oidcMgr.getToken(newToken.Cookie)
	assert.NoError(t, err)
	oidcMgr.removeToken(newToken.Cookie)
	require.Len(t, oidcMgr.tokens, 0)
}
  830. func TestOIDCPreLoginHook(t *testing.T) {
  831. if runtime.GOOS == osWindows {
  832. t.Skip("this test is not available on Windows")
  833. }
  834. username := "test_oidc_user_prelogin"
  835. u := dataprovider.User{
  836. BaseUser: sdk.BaseUser{
  837. Username: username,
  838. Password: "unused",
  839. HomeDir: filepath.Join(os.TempDir(), username),
  840. Status: 1,
  841. Permissions: map[string][]string{
  842. "/": {dataprovider.PermAny},
  843. },
  844. },
  845. }
  846. preLoginPath := filepath.Join(os.TempDir(), "prelogin.sh")
  847. providerConf := dataprovider.GetProviderConfig()
  848. err := dataprovider.Close()
  849. assert.NoError(t, err)
  850. err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm)
  851. assert.NoError(t, err)
  852. newProviderConf := providerConf
  853. newProviderConf.PreLoginHook = preLoginPath
  854. err = dataprovider.Initialize(newProviderConf, "..", true)
  855. assert.NoError(t, err)
  856. server := getTestOIDCServer()
  857. err = server.binding.OIDC.initialize()
  858. assert.NoError(t, err)
  859. server.initializeRouter()
  860. _, err = dataprovider.UserExists(username)
  861. _, ok := err.(*util.RecordNotFoundError)
  862. assert.True(t, ok)
  863. // now login with OIDC
  864. authReq := newOIDCPendingAuth(tokenAudienceWebClient)
  865. oidcMgr.addPendingAuth(authReq)
  866. token := &oauth2.Token{
  867. AccessToken: "1234",
  868. Expiry: time.Now().Add(5 * time.Minute),
  869. }
  870. token = token.WithExtra(map[string]interface{}{
  871. "id_token": "id_token_val",
  872. })
  873. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  874. tokenSource: &mockTokenSource{},
  875. authCodeURL: webOIDCRedirectPath,
  876. token: token,
  877. }
  878. idToken := &oidc.IDToken{
  879. Nonce: authReq.Nonce,
  880. Expiry: time.Now().Add(5 * time.Minute),
  881. }
  882. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`"}`))
  883. server.binding.OIDC.verifier = &mockOIDCVerifier{
  884. err: nil,
  885. token: idToken,
  886. }
  887. rr := httptest.NewRecorder()
  888. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  889. assert.NoError(t, err)
  890. server.router.ServeHTTP(rr, r)
  891. assert.Equal(t, http.StatusFound, rr.Code)
  892. assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
  893. _, err = dataprovider.UserExists(username)
  894. assert.NoError(t, err)
  895. err = dataprovider.DeleteUser(username, "", "")
  896. assert.NoError(t, err)
  897. err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, true), os.ModePerm)
  898. assert.NoError(t, err)
  899. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  900. oidcMgr.addPendingAuth(authReq)
  901. idToken = &oidc.IDToken{
  902. Nonce: authReq.Nonce,
  903. Expiry: time.Now().Add(5 * time.Minute),
  904. }
  905. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`"}`))
  906. server.binding.OIDC.verifier = &mockOIDCVerifier{
  907. err: nil,
  908. token: idToken,
  909. }
  910. rr = httptest.NewRecorder()
  911. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  912. assert.NoError(t, err)
  913. server.router.ServeHTTP(rr, r)
  914. assert.Equal(t, http.StatusFound, rr.Code)
  915. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  916. _, err = dataprovider.UserExists(username)
  917. _, ok = err.(*util.RecordNotFoundError)
  918. assert.True(t, ok)
  919. if assert.Len(t, oidcMgr.tokens, 1) {
  920. for k := range oidcMgr.tokens {
  921. oidcMgr.removeToken(k)
  922. }
  923. }
  924. require.Len(t, oidcMgr.pendingAuths, 0)
  925. require.Len(t, oidcMgr.tokens, 0)
  926. err = dataprovider.Close()
  927. assert.NoError(t, err)
  928. err = dataprovider.Initialize(providerConf, "..", true)
  929. assert.NoError(t, err)
  930. err = os.Remove(preLoginPath)
  931. assert.NoError(t, err)
  932. }
  933. func getTestOIDCServer() *httpdServer {
  934. return &httpdServer{
  935. binding: Binding{
  936. OIDC: OIDC{
  937. ClientID: "sftpgo-client",
  938. ClientSecret: "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c",
  939. ConfigURL: fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr),
  940. RedirectBaseURL: "http://127.0.0.1:8081/",
  941. UsernameField: "preferred_username",
  942. RoleField: "sftpgo_role",
  943. },
  944. },
  945. enableWebAdmin: true,
  946. enableWebClient: true,
  947. }
  948. }
  949. func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte {
  950. content := []byte("#!/bin/sh\n\n")
  951. if nonJSONResponse {
  952. content = append(content, []byte("echo 'text response'\n")...)
  953. return content
  954. }
  955. if len(user.Username) > 0 {
  956. u, _ := json.Marshal(user)
  957. content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...)
  958. }
  959. return content
  960. }
  961. func TestOIDCIsAdmin(t *testing.T) {
  962. type test struct {
  963. input interface{}
  964. want bool
  965. }
  966. emptySlice := make([]interface{}, 0)
  967. tests := []test{
  968. {input: "admin", want: true},
  969. {input: append(emptySlice, "admin"), want: true},
  970. {input: append(emptySlice, "user", "admin"), want: true},
  971. {input: "user", want: false},
  972. {input: emptySlice, want: false},
  973. {input: append(emptySlice, 1), want: false},
  974. {input: 1, want: false},
  975. {input: nil, want: false},
  976. }
  977. for _, tc := range tests {
  978. token := oidcToken{
  979. Role: tc.input,
  980. }
  981. assert.Equal(t, tc.want, token.isAdmin(), "%v should return %t", tc.input, tc.want)
  982. }
  983. }