|
@@ -7,8 +7,8 @@ import (
|
|
|
"os"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
- "time"
|
|
|
|
|
|
+ "github.com/drakkan/sftpgo/v2/logger"
|
|
|
"github.com/drakkan/sftpgo/v2/util"
|
|
|
)
|
|
|
|
|
@@ -26,8 +26,13 @@ type SecretProvider interface {
|
|
|
SetKey(string)
|
|
|
SetAdditionalData(string)
|
|
|
SetStatus(SecretStatus)
|
|
|
+ Clone() SecretProvider
|
|
|
}
|
|
|
|
|
|
+const (
|
|
|
+ logSender = "kms"
|
|
|
+)
|
|
|
+
|
|
|
// SecretStatus defines the statuses of a Secret object
|
|
|
type SecretStatus = string
|
|
|
|
|
@@ -51,12 +56,16 @@ const (
|
|
|
SecretStatusRedacted SecretStatus = "Redacted"
|
|
|
)
|
|
|
|
|
|
+// Scheme defines the supported URL scheme
|
|
|
+type Scheme = string
|
|
|
+
|
|
|
+// supported URL schemes
|
|
|
const (
|
|
|
- localProviderName = "Local"
|
|
|
- builtinProviderName = "Builtin"
|
|
|
- awsProviderName = "AWS"
|
|
|
- gcpProviderName = "GCP"
|
|
|
- vaultProviderName = "VaultTransit"
|
|
|
+ SchemeLocal Scheme = "local://"
|
|
|
+ SchemeBuiltin Scheme = "builtin://"
|
|
|
+ SchemeAWS Scheme = "awskms://"
|
|
|
+ SchemeGCP Scheme = "gcpkms://"
|
|
|
+ SchemeVaultTransit Scheme = "hashivault://"
|
|
|
)
|
|
|
|
|
|
// Configuration defines the KMS configuration
|
|
@@ -71,16 +80,32 @@ type Secrets struct {
|
|
|
masterKey string
|
|
|
}
|
|
|
|
|
|
+type registeredSecretProvider struct {
|
|
|
+ encryptedStatus SecretStatus
|
|
|
+ newFn func(base BaseSecret, url, masterKey string) SecretProvider
|
|
|
+}
|
|
|
+
|
|
|
var (
|
|
|
- errWrongSecretStatus = errors.New("wrong secret status")
|
|
|
+ // ErrWrongSecretStatus defines the error to return if the secret status is not appropriate
|
|
|
+ // for the request operation
|
|
|
+ ErrWrongSecretStatus = errors.New("wrong secret status")
|
|
|
+ // ErrInvalidSecret defines the error to return if a secret is not valid
|
|
|
+ ErrInvalidSecret = errors.New("invalid secret")
|
|
|
errMalformedCiphertext = errors.New("malformed ciphertext")
|
|
|
- errInvalidSecret = errors.New("invalid secret")
|
|
|
validSecretStatuses = []string{SecretStatusPlain, SecretStatusAES256GCM, SecretStatusSecretBox,
|
|
|
SecretStatusVaultTransit, SecretStatusAWS, SecretStatusGCP, SecretStatusRedacted}
|
|
|
- config Configuration
|
|
|
- defaultTimeout = 10 * time.Second
|
|
|
+ config Configuration
|
|
|
+ secretProviders = make(map[string]registeredSecretProvider)
|
|
|
)
|
|
|
|
|
|
+// RegisterSecretProvider register a new secret provider
|
|
|
+func RegisterSecretProvider(scheme string, encryptedStatus SecretStatus, fn func(base BaseSecret, url, masterKey string) SecretProvider) {
|
|
|
+ secretProviders[scheme] = registeredSecretProvider{
|
|
|
+ encryptedStatus: encryptedStatus,
|
|
|
+ newFn: fn,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// NewSecret builds a new Secret using the provided arguments
|
|
|
func NewSecret(status SecretStatus, payload, key, data string) *Secret {
|
|
|
return config.newSecret(status, payload, key, data)
|
|
@@ -115,11 +140,18 @@ func (c *Configuration) Initialize() error {
|
|
|
c.Secrets.masterKey = strings.TrimSpace(string(mKey))
|
|
|
}
|
|
|
config = *c
|
|
|
+ if config.Secrets.URL == "" {
|
|
|
+ config.Secrets.URL = "local://"
|
|
|
+ }
|
|
|
+ for k, v := range secretProviders {
|
|
|
+ logger.Debug(logSender, "", "secret provider registered for scheme: %#v, encrypted status: %#v",
|
|
|
+ k, v.encryptedStatus)
|
|
|
+ }
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
func (c *Configuration) newSecret(status SecretStatus, payload, key, data string) *Secret {
|
|
|
- base := baseSecret{
|
|
|
+ base := BaseSecret{
|
|
|
Status: status,
|
|
|
Key: key,
|
|
|
Payload: payload,
|
|
@@ -130,17 +162,13 @@ func (c *Configuration) newSecret(status SecretStatus, payload, key, data string
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (c *Configuration) getSecretProvider(base baseSecret) SecretProvider {
|
|
|
- if strings.HasPrefix(c.Secrets.URL, "hashivault://") {
|
|
|
- return newVaultSecret(base, c.Secrets.URL, c.Secrets.masterKey)
|
|
|
- }
|
|
|
- if strings.HasPrefix(c.Secrets.URL, "awskms://") {
|
|
|
- return newAWSSecret(base, c.Secrets.URL, c.Secrets.masterKey)
|
|
|
- }
|
|
|
- if strings.HasPrefix(c.Secrets.URL, "gcpkms://") {
|
|
|
- return newGCPSecret(base, c.Secrets.URL, c.Secrets.masterKey)
|
|
|
+func (c *Configuration) getSecretProvider(base BaseSecret) SecretProvider {
|
|
|
+ for k, v := range secretProviders {
|
|
|
+ if strings.HasPrefix(c.Secrets.URL, k) {
|
|
|
+ return v.newFn(base, c.Secrets.URL, c.Secrets.masterKey)
|
|
|
+ }
|
|
|
}
|
|
|
- return newLocalSecret(base, c.Secrets.masterKey)
|
|
|
+ return NewLocalSecret(base, c.Secrets.URL, c.Secrets.masterKey)
|
|
|
}
|
|
|
|
|
|
// Secret defines the struct used to store confidential data
|
|
@@ -154,7 +182,7 @@ func (s *Secret) MarshalJSON() ([]byte, error) {
|
|
|
s.RLock()
|
|
|
defer s.RUnlock()
|
|
|
|
|
|
- return json.Marshal(&baseSecret{
|
|
|
+ return json.Marshal(&BaseSecret{
|
|
|
Status: s.provider.GetStatus(),
|
|
|
Payload: s.provider.GetPayload(),
|
|
|
Key: s.provider.GetKey(),
|
|
@@ -169,7 +197,7 @@ func (s *Secret) UnmarshalJSON(data []byte) error {
|
|
|
s.Lock()
|
|
|
defer s.Unlock()
|
|
|
|
|
|
- baseSecret := baseSecret{}
|
|
|
+ baseSecret := BaseSecret{}
|
|
|
err := json.Unmarshal(data, &baseSecret)
|
|
|
if err != nil {
|
|
|
return err
|
|
@@ -178,23 +206,21 @@ func (s *Secret) UnmarshalJSON(data []byte) error {
|
|
|
s.provider = config.getSecretProvider(baseSecret)
|
|
|
return nil
|
|
|
}
|
|
|
- switch baseSecret.Status {
|
|
|
- case SecretStatusAES256GCM:
|
|
|
- s.provider = newBuiltinSecret(baseSecret)
|
|
|
- case SecretStatusSecretBox:
|
|
|
- s.provider = newLocalSecret(baseSecret, config.Secrets.masterKey)
|
|
|
- case SecretStatusVaultTransit:
|
|
|
- s.provider = newVaultSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey)
|
|
|
- case SecretStatusAWS:
|
|
|
- s.provider = newAWSSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey)
|
|
|
- case SecretStatusGCP:
|
|
|
- s.provider = newGCPSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey)
|
|
|
- case SecretStatusPlain, SecretStatusRedacted:
|
|
|
+
|
|
|
+ if baseSecret.Status == SecretStatusPlain || baseSecret.Status == SecretStatusRedacted {
|
|
|
s.provider = config.getSecretProvider(baseSecret)
|
|
|
- default:
|
|
|
- return errInvalidSecret
|
|
|
+ return nil
|
|
|
}
|
|
|
- return nil
|
|
|
+
|
|
|
+ for _, v := range secretProviders {
|
|
|
+ if v.encryptedStatus == baseSecret.Status {
|
|
|
+ s.provider = v.newFn(baseSecret, config.Secrets.URL, config.Secrets.masterKey)
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+ logger.Debug(logSender, "", "no provider registered for status %#v", baseSecret.Status)
|
|
|
+
|
|
|
+ return ErrInvalidSecret
|
|
|
}
|
|
|
|
|
|
// IsEqual returns true if all the secrets fields are equal
|
|
@@ -222,36 +248,9 @@ func (s *Secret) Clone() *Secret {
|
|
|
s.RLock()
|
|
|
defer s.RUnlock()
|
|
|
|
|
|
- baseSecret := baseSecret{
|
|
|
- Status: s.provider.GetStatus(),
|
|
|
- Payload: s.provider.GetPayload(),
|
|
|
- Key: s.provider.GetKey(),
|
|
|
- AdditionalData: s.provider.GetAdditionalData(),
|
|
|
- Mode: s.provider.GetMode(),
|
|
|
- }
|
|
|
- switch s.provider.Name() {
|
|
|
- case builtinProviderName:
|
|
|
- return &Secret{
|
|
|
- provider: newBuiltinSecret(baseSecret),
|
|
|
- }
|
|
|
- case awsProviderName:
|
|
|
- return &Secret{
|
|
|
- provider: newAWSSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey),
|
|
|
- }
|
|
|
- case gcpProviderName:
|
|
|
- return &Secret{
|
|
|
- provider: newGCPSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey),
|
|
|
- }
|
|
|
- case localProviderName:
|
|
|
- return &Secret{
|
|
|
- provider: newLocalSecret(baseSecret, config.Secrets.masterKey),
|
|
|
- }
|
|
|
- case vaultProviderName:
|
|
|
- return &Secret{
|
|
|
- provider: newVaultSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey),
|
|
|
- }
|
|
|
+ return &Secret{
|
|
|
+ provider: s.provider.Clone(),
|
|
|
}
|
|
|
- return NewSecret(s.GetStatus(), s.GetPayload(), s.GetKey(), s.GetAdditionalData())
|
|
|
}
|
|
|
|
|
|
// IsEncrypted returns true if the secret is encrypted
|