瀏覽代碼

jwt: increase leeway and add some tests

also export a constant for the Cookie name

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino 1 周之前
父節點
當前提交
a768dac29d
共有 6 個文件被更改,包括 42 次插入9 次删除
  1. 3 4
      internal/httpd/auth_utils.go
  2. 1 1
      internal/httpd/oidc.go
  3. 1 1
      internal/httpd/oidc_test.go
  4. 1 1
      internal/httpd/server.go
  5. 6 2
      internal/jwt/jwt.go
  6. 30 0
      internal/jwt/jwt_test.go

+ 3 - 4
internal/httpd/auth_utils.go

@@ -49,8 +49,7 @@ const (
 )
 
 const (
-	basicRealm   = "Basic realm=\"SFTPGo\""
-	jwtCookieKey = "jwt"
+	basicRealm = "Basic realm=\"SFTPGo\""
 )
 
 var (
@@ -142,7 +141,7 @@ func createAndSetCookie(w http.ResponseWriter, r *http.Request, claims *jwt.Clai
 
 func setCookie(w http.ResponseWriter, r *http.Request, cookiePath, cookieValue string, duration time.Duration) {
 	http.SetCookie(w, &http.Cookie{
-		Name:     jwtCookieKey,
+		Name:     jwt.CookieKey,
 		Value:    cookieValue,
 		Path:     cookiePath,
 		Expires:  time.Now().Add(duration),
@@ -156,7 +155,7 @@ func setCookie(w http.ResponseWriter, r *http.Request, cookiePath, cookieValue s
 func removeCookie(w http.ResponseWriter, r *http.Request, cookiePath string) {
 	invalidateToken(r)
 	http.SetCookie(w, &http.Cookie{
-		Name:     jwtCookieKey,
+		Name:     jwt.CookieKey,
 		Value:    "",
 		Path:     cookiePath,
 		Expires:  time.Unix(0, 0),

+ 1 - 1
internal/httpd/oidc.go

@@ -809,7 +809,7 @@ func removeOIDCCookie(w http.ResponseWriter, r *http.Request) {
 func canSkipOIDCValidation(r *http.Request) bool {
 	_, err := r.Cookie(oidcCookieKey)
 	if err != nil {
-		_, err = r.Cookie(jwtCookieKey)
+		_, err = r.Cookie(jwt.CookieKey)
 		return err == nil
 	}
 	return false

+ 1 - 1
internal/httpd/oidc_test.go

@@ -845,7 +845,7 @@ func TestSkipOIDCAuth(t *testing.T) {
 	rr := httptest.NewRecorder()
 	r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
 	assert.NoError(t, err)
-	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", jwtCookieKey, tokenString))
+	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", jwt.CookieKey, tokenString))
 	server.router.ServeHTTP(rr, r)
 	assert.Equal(t, http.StatusFound, rr.Code)
 	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))

+ 1 - 1
internal/httpd/server.go

@@ -1071,7 +1071,7 @@ func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request,
 func (s *httpdServer) updateContextFromCookie(r *http.Request) *http.Request {
 	_, err := jwt.FromContext(r.Context())
 	if err != nil {
-		_, err = r.Cookie(jwtCookieKey)
+		_, err = r.Cookie(jwt.CookieKey)
 		if err != nil {
 			return r
 		}

+ 6 - 2
internal/jwt/jwt.go

@@ -30,6 +30,10 @@ import (
 	"github.com/rs/xid"
 )
 
+const (
+	CookieKey = "jwt"
+)
+
 var (
 	TokenCtxKey = &contextKey{"Token"}
 	ErrorCtxKey = &contextKey{"Error"}
@@ -235,7 +239,7 @@ func VerifyTokenWithKey(payload string, algo []jose.SignatureAlgorithm, key any)
 	if err != nil {
 		return nil, err
 	}
-	if err := claims.ValidateWithLeeway(jwt.Expected{Time: time.Now()}, 15*time.Second); err != nil {
+	if err := claims.ValidateWithLeeway(jwt.Expected{Time: time.Now()}, 30*time.Second); err != nil {
 		return nil, err
 	}
 	return &claims, nil
@@ -244,7 +248,7 @@ func VerifyTokenWithKey(payload string, algo []jose.SignatureAlgorithm, key any)
 // TokenFromCookie tries to retrieve the token string from a cookie named
 // "jwt".
 func TokenFromCookie(r *http.Request) string {
-	cookie, err := r.Cookie("jwt")
+	cookie, err := r.Cookie(CookieKey)
 	if err != nil {
 		return ""
 	}

+ 30 - 0
internal/jwt/jwt_test.go

@@ -223,3 +223,33 @@ func TestContext(t *testing.T) {
 
 	assert.Equal(t, "jwt context value Token", TokenCtxKey.String())
 }
+
+func TestValidationLeeway(t *testing.T) {
+	s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32))
+	require.NoError(t, err)
+	claims := &Claims{}
+	claims.Audience = []string{util.GenerateUniqueID()}
+	claims.SetIssuedAt(time.Now().Add(10 * time.Second)) // issued at in the future
+	claims.SetExpiry(time.Now().Add(10 * time.Second))
+	token, err := s.Sign(claims)
+	require.NoError(t, err)
+	_, err = VerifyToken(s, token)
+	assert.NoError(t, err)
+
+	claims = &Claims{}
+	claims.Audience = []string{util.GenerateUniqueID()}
+	claims.SetExpiry(time.Now().Add(-10 * time.Second)) // expired
+	token, err = s.Sign(claims)
+	require.NoError(t, err)
+	_, err = VerifyToken(s, token)
+	assert.NoError(t, err)
+
+	claims = &Claims{}
+	claims.Audience = []string{util.GenerateUniqueID()}
+	claims.SetExpiry(time.Now().Add(30 * time.Second))
+	claims.SetNotBefore(time.Now().Add(10 * time.Second)) // not before in the future
+	token, err = s.Sign(claims)
+	require.NoError(t, err)
+	_, err = VerifyToken(s, token)
+	assert.NoError(t, err)
+}