浏览代码

oidc/oauth2: use an opaque state

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino 11 月之前
父节点
当前提交
d7d08c3d2f
共有 2 个文件被更改,包括 11 次插入5 次删除
  1. 5 3
      internal/httpd/oauth2.go
  2. 6 2
      internal/httpd/oidc.go

+ 5 - 3
internal/httpd/oauth2.go

@@ -15,13 +15,13 @@
 package httpd
 package httpd
 
 
 import (
 import (
+	"crypto/sha256"
+	"encoding/hex"
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
-	"github.com/rs/xid"
-
 	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
 	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
 	"github.com/drakkan/sftpgo/v2/internal/kms"
 	"github.com/drakkan/sftpgo/v2/internal/kms"
 	"github.com/drakkan/sftpgo/v2/internal/logger"
 	"github.com/drakkan/sftpgo/v2/internal/logger"
@@ -53,8 +53,10 @@ type oauth2PendingAuth struct {
 }
 }
 
 
 func newOAuth2PendingAuth(provider int, redirectURL, clientID string, clientSecret *kms.Secret) oauth2PendingAuth {
 func newOAuth2PendingAuth(provider int, redirectURL, clientID string, clientSecret *kms.Secret) oauth2PendingAuth {
+	state := sha256.Sum256(util.GenerateRandomBytes(32))
+
 	return oauth2PendingAuth{
 	return oauth2PendingAuth{
-		State:        xid.New().String(),
+		State:        hex.EncodeToString(state[:]),
 		Provider:     provider,
 		Provider:     provider,
 		ClientID:     clientID,
 		ClientID:     clientID,
 		ClientSecret: clientSecret,
 		ClientSecret: clientSecret,

+ 6 - 2
internal/httpd/oidc.go

@@ -16,6 +16,7 @@ package httpd
 
 
 import (
 import (
 	"context"
 	"context"
+	"crypto/sha256"
 	"encoding/hex"
 	"encoding/hex"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
@@ -202,9 +203,12 @@ type oidcPendingAuth struct {
 }
 }
 
 
 func newOIDCPendingAuth(audience tokenAudience) oidcPendingAuth {
 func newOIDCPendingAuth(audience tokenAudience) oidcPendingAuth {
+	state := sha256.Sum256(util.GenerateRandomBytes(32))
+	nonce := util.GenerateUniqueID()
+
 	return oidcPendingAuth{
 	return oidcPendingAuth{
-		State:    xid.New().String(),
-		Nonce:    hex.EncodeToString(util.GenerateRandomBytes(20)),
+		State:    hex.EncodeToString(state[:]),
+		Nonce:    nonce,
 		Audience: audience,
 		Audience: audience,
 		IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
 		IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
 	}
 	}