|
@@ -34,7 +34,7 @@ import (
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
- sqlDatabaseVersion = 23
|
|
|
+ sqlDatabaseVersion = 24
|
|
|
defaultSQLQueryTimeout = 10 * time.Second
|
|
|
longSQLQueryTimeout = 60 * time.Second
|
|
|
)
|
|
@@ -78,6 +78,7 @@ func sqlReplaceAll(sql string) string {
|
|
|
sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
|
|
|
sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
|
|
|
sql = strings.ReplaceAll(sql, "{{nodes}}", sqlTableNodes)
|
|
|
+ sql = strings.ReplaceAll(sql, "{{roles}}", sqlTableRoles)
|
|
|
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
|
|
|
return sql
|
|
|
}
|
|
@@ -105,7 +106,7 @@ func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
- user, err := provider.userExists(share.Username)
|
|
|
+ user, err := provider.userExists(share.Username, "")
|
|
|
if err != nil {
|
|
|
return util.NewValidationError(fmt.Sprintf("unable to validate user %#v", share.Username))
|
|
|
}
|
|
@@ -165,7 +166,7 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- user, err := provider.userExists(share.Username)
|
|
|
+ user, err := provider.userExists(share.Username, "")
|
|
|
if err != nil {
|
|
|
return util.NewValidationError(fmt.Sprintf("unable to validate user %#v", share.Username))
|
|
|
}
|
|
@@ -431,10 +432,10 @@ func sqlCommonAddAdmin(admin *Admin, dbHandle *sql.DB) error {
|
|
|
defer cancel()
|
|
|
|
|
|
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
|
|
- q := getAddAdminQuery()
|
|
|
+ q := getAddAdminQuery(admin.Role)
|
|
|
_, err = tx.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, admin.Email, string(perms),
|
|
|
string(filters), admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
|
|
|
- util.GetTimeAsMsSinceEpoch(time.Now()))
|
|
|
+ util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
@@ -462,9 +463,9 @@ func sqlCommonUpdateAdmin(admin *Admin, dbHandle *sql.DB) error {
|
|
|
defer cancel()
|
|
|
|
|
|
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
|
|
- q := getUpdateAdminQuery()
|
|
|
+ q := getUpdateAdminQuery(admin.Role)
|
|
|
_, err = tx.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, string(perms), string(filters),
|
|
|
- admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), admin.Username)
|
|
|
+ admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role, admin.Username)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
@@ -537,6 +538,122 @@ func sqlCommonDumpAdmins(dbHandle sqlQuerier) ([]Admin, error) {
|
|
|
return getAdminsWithGroups(ctx, admins, dbHandle)
|
|
|
}
|
|
|
|
|
|
+func sqlCommonGetRoleByName(name string, dbHandle sqlQuerier) (Role, error) {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ q := getRoleByNameQuery()
|
|
|
+ row := dbHandle.QueryRowContext(ctx, q, name)
|
|
|
+ role, err := getRoleFromDbRow(row)
|
|
|
+ if err != nil {
|
|
|
+ return role, err
|
|
|
+ }
|
|
|
+ role, err = getRoleWithUsers(ctx, role, dbHandle)
|
|
|
+ if err != nil {
|
|
|
+ return role, err
|
|
|
+ }
|
|
|
+ return getRoleWithAdmins(ctx, role, dbHandle)
|
|
|
+}
|
|
|
+
|
|
|
+func sqlCommonDumpRoles(dbHandle sqlQuerier) ([]Role, error) {
|
|
|
+ roles := make([]Role, 0, 10)
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ q := getDumpRolesQuery()
|
|
|
+
|
|
|
+ rows, err := dbHandle.QueryContext(ctx, q)
|
|
|
+ if err != nil {
|
|
|
+ return roles, err
|
|
|
+ }
|
|
|
+ defer rows.Close()
|
|
|
+
|
|
|
+ for rows.Next() {
|
|
|
+ role, err := getRoleFromDbRow(rows)
|
|
|
+ if err != nil {
|
|
|
+ return roles, err
|
|
|
+ }
|
|
|
+ roles = append(roles, role)
|
|
|
+ }
|
|
|
+ return roles, rows.Err()
|
|
|
+}
|
|
|
+
|
|
|
+func sqlCommonGetRoles(limit int, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]Role, error) {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ q := getRolesQuery(order, minimal)
|
|
|
+
|
|
|
+ roles := make([]Role, 0, limit)
|
|
|
+ rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
|
|
|
+ if err != nil {
|
|
|
+ return roles, err
|
|
|
+ }
|
|
|
+ defer rows.Close()
|
|
|
+
|
|
|
+ for rows.Next() {
|
|
|
+ var role Role
|
|
|
+ if minimal {
|
|
|
+ err = rows.Scan(&role.ID, &role.Name)
|
|
|
+ } else {
|
|
|
+ role, err = getRoleFromDbRow(rows)
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ return roles, err
|
|
|
+ }
|
|
|
+ roles = append(roles, role)
|
|
|
+ }
|
|
|
+ err = rows.Err()
|
|
|
+ if err != nil {
|
|
|
+ return roles, err
|
|
|
+ }
|
|
|
+ if minimal {
|
|
|
+ return roles, nil
|
|
|
+ }
|
|
|
+ roles, err = getRolesWithUsers(ctx, roles, dbHandle)
|
|
|
+ if err != nil {
|
|
|
+ return roles, err
|
|
|
+ }
|
|
|
+ return getRolesWithAdmins(ctx, roles, dbHandle)
|
|
|
+}
|
|
|
+
|
|
|
+func sqlCommonAddRole(role *Role, dbHandle *sql.DB) error {
|
|
|
+ if err := role.validate(); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ q := getAddRoleQuery()
|
|
|
+ _, err := dbHandle.ExecContext(ctx, q, role.Name, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
|
|
|
+ util.GetTimeAsMsSinceEpoch(time.Now()))
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+func sqlCommonUpdateRole(role *Role, dbHandle *sql.DB) error {
|
|
|
+ if err := role.validate(); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ q := getUpdateRoleQuery()
|
|
|
+ _, err := dbHandle.ExecContext(ctx, q, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), role.Name)
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+func sqlCommonDeleteRole(role Role, dbHandle *sql.DB) error {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ q := getDeleteRoleQuery()
|
|
|
+ res, err := dbHandle.ExecContext(ctx, q, role.Name)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ return sqlCommonRequireRowAffected(res)
|
|
|
+}
|
|
|
+
|
|
|
func sqlCommonGetGroupByName(name string, dbHandle sqlQuerier) (Group, error) {
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
|
|
defer cancel()
|
|
@@ -756,12 +873,16 @@ func sqlCommonDeleteGroup(group Group, dbHandle *sql.DB) error {
|
|
|
return sqlCommonRequireRowAffected(res)
|
|
|
}
|
|
|
|
|
|
-func sqlCommonGetUserByUsername(username string, dbHandle sqlQuerier) (User, error) {
|
|
|
+func sqlCommonGetUserByUsername(username, role string, dbHandle sqlQuerier) (User, error) {
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
|
|
defer cancel()
|
|
|
|
|
|
- q := getUserByUsernameQuery()
|
|
|
- row := dbHandle.QueryRowContext(ctx, q, username)
|
|
|
+ q := getUserByUsernameQuery(role)
|
|
|
+ args := []any{username}
|
|
|
+ if role != "" {
|
|
|
+ args = append(args, role)
|
|
|
+ }
|
|
|
+ row := dbHandle.QueryRowContext(ctx, q, args...)
|
|
|
user, err := getUserFromDbRow(row)
|
|
|
if err != nil {
|
|
|
return user, err
|
|
@@ -774,7 +895,7 @@ func sqlCommonGetUserByUsername(username string, dbHandle sqlQuerier) (User, err
|
|
|
}
|
|
|
|
|
|
func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) {
|
|
|
- user, err := sqlCommonGetUserByUsername(username, dbHandle)
|
|
|
+ user, err := sqlCommonGetUserByUsername(username, "", dbHandle)
|
|
|
if err != nil {
|
|
|
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
|
|
return user, err
|
|
@@ -787,7 +908,7 @@ func sqlCommonValidateUserAndTLSCertificate(username, protocol string, tlsCert *
|
|
|
if tlsCert == nil {
|
|
|
return user, errors.New("TLS certificate cannot be null or empty")
|
|
|
}
|
|
|
- user, err := sqlCommonGetUserByUsername(username, dbHandle)
|
|
|
+ user, err := sqlCommonGetUserByUsername(username, "", dbHandle)
|
|
|
if err != nil {
|
|
|
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
|
|
return user, err
|
|
@@ -800,7 +921,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, isSSHCert bo
|
|
|
if len(pubKey) == 0 {
|
|
|
return user, "", errors.New("credentials cannot be null or empty")
|
|
|
}
|
|
|
- user, err := sqlCommonGetUserByUsername(username, dbHandle)
|
|
|
+ user, err := sqlCommonGetUserByUsername(username, "", dbHandle)
|
|
|
if err != nil {
|
|
|
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
|
|
return user, "", err
|
|
@@ -993,12 +1114,12 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
|
|
|
return err
|
|
|
}
|
|
|
}
|
|
|
- q := getAddUserQuery()
|
|
|
+ q := getAddUserQuery(user.Role)
|
|
|
_, err := tx.ExecContext(ctx, q, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID,
|
|
|
user.MaxSessions, user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth,
|
|
|
user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo,
|
|
|
user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()),
|
|
|
- user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer)
|
|
|
+ user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer, user.Role)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
@@ -1044,12 +1165,12 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
|
|
|
defer cancel()
|
|
|
|
|
|
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
|
|
- q := getUpdateUserQuery()
|
|
|
+ q := getUpdateUserQuery(user.Role)
|
|
|
_, err := tx.ExecContext(ctx, q, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions,
|
|
|
user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status,
|
|
|
user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.Email,
|
|
|
util.GetTimeAsMsSinceEpoch(time.Now()), user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer,
|
|
|
- user.ID)
|
|
|
+ user.Role, user.ID)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
@@ -1264,19 +1385,17 @@ func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier
|
|
|
|
|
|
for rows.Next() {
|
|
|
var user User
|
|
|
- var filters sql.NullString
|
|
|
+ var filters []byte
|
|
|
err = rows.Scan(&user.ID, &user.Username, &user.QuotaSize, &user.UsedQuotaSize, &user.TotalDataTransfer,
|
|
|
&user.UploadDataTransfer, &user.DownloadDataTransfer, &user.UsedUploadDataTransfer,
|
|
|
&user.UsedDownloadDataTransfer, &filters)
|
|
|
if err != nil {
|
|
|
return users, err
|
|
|
}
|
|
|
- if filters.Valid {
|
|
|
- var userFilters UserFilters
|
|
|
- err = json.Unmarshal([]byte(filters.String), &userFilters)
|
|
|
- if err == nil {
|
|
|
- user.Filters = userFilters
|
|
|
- }
|
|
|
+ var userFilters UserFilters
|
|
|
+ err = json.Unmarshal(filters, &userFilters)
|
|
|
+ if err == nil {
|
|
|
+ user.Filters = userFilters
|
|
|
}
|
|
|
users = append(users, user)
|
|
|
}
|
|
@@ -1353,13 +1472,19 @@ func sqlCommonGetActiveTransfers(from time.Time, dbHandle sqlQuerier) ([]ActiveT
|
|
|
return transfers, rows.Err()
|
|
|
}
|
|
|
|
|
|
-func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier) ([]User, error) {
|
|
|
+func sqlCommonGetUsers(limit int, offset int, order, role string, dbHandle sqlQuerier) ([]User, error) {
|
|
|
users := make([]User, 0, limit)
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
|
|
defer cancel()
|
|
|
|
|
|
- q := getUsersQuery(order)
|
|
|
- rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
|
|
|
+ q := getUsersQuery(order, role)
|
|
|
+ var args []any
|
|
|
+ if role == "" {
|
|
|
+ args = append(args, limit, offset)
|
|
|
+ } else {
|
|
|
+ args = append(args, role, limit, offset)
|
|
|
+ }
|
|
|
+ rows, err := dbHandle.QueryContext(ctx, q, args...)
|
|
|
if err != nil {
|
|
|
return users, err
|
|
|
}
|
|
@@ -1593,7 +1718,8 @@ func sqlCommonCleanupDefenderEvents(from int64, dbHandle *sql.DB) error {
|
|
|
|
|
|
func getShareFromDbRow(row sqlScanner) (Share, error) {
|
|
|
var share Share
|
|
|
- var description, password, allowFrom, paths sql.NullString
|
|
|
+ var description, password sql.NullString
|
|
|
+ var allowFrom, paths []byte
|
|
|
|
|
|
err := row.Scan(&share.ShareID, &share.Name, &description, &share.Scope,
|
|
|
&paths, &share.Username, &share.CreatedAt, &share.UpdatedAt,
|
|
@@ -1605,28 +1731,22 @@ func getShareFromDbRow(row sqlScanner) (Share, error) {
|
|
|
}
|
|
|
return share, err
|
|
|
}
|
|
|
- if paths.Valid {
|
|
|
- var list []string
|
|
|
- err = json.Unmarshal([]byte(paths.String), &list)
|
|
|
- if err != nil {
|
|
|
- return share, err
|
|
|
- }
|
|
|
- share.Paths = list
|
|
|
- } else {
|
|
|
- return share, errors.New("unable to decode shared paths")
|
|
|
+ var list []string
|
|
|
+ err = json.Unmarshal(paths, &list)
|
|
|
+ if err != nil {
|
|
|
+ return share, err
|
|
|
}
|
|
|
+ share.Paths = list
|
|
|
if description.Valid {
|
|
|
share.Description = description.String
|
|
|
}
|
|
|
if password.Valid {
|
|
|
share.Password = password.String
|
|
|
}
|
|
|
- if allowFrom.Valid {
|
|
|
- var list []string
|
|
|
- err = json.Unmarshal([]byte(allowFrom.String), &list)
|
|
|
- if err == nil {
|
|
|
- share.AllowFrom = list
|
|
|
- }
|
|
|
+ list = nil
|
|
|
+ err = json.Unmarshal(allowFrom, &list)
|
|
|
+ if err == nil {
|
|
|
+ share.AllowFrom = list
|
|
|
}
|
|
|
return share, nil
|
|
|
}
|
|
@@ -1661,10 +1781,11 @@ func getAPIKeyFromDbRow(row sqlScanner) (APIKey, error) {
|
|
|
|
|
|
func getAdminFromDbRow(row sqlScanner) (Admin, error) {
|
|
|
var admin Admin
|
|
|
- var email, filters, additionalInfo, permissions, description sql.NullString
|
|
|
+ var email, additionalInfo, description, role sql.NullString
|
|
|
+ var permissions, filters []byte
|
|
|
|
|
|
err := row.Scan(&admin.ID, &admin.Username, &admin.Password, &admin.Status, &email, &permissions,
|
|
|
- &filters, &additionalInfo, &description, &admin.CreatedAt, &admin.UpdatedAt, &admin.LastLogin)
|
|
|
+ &filters, &additionalInfo, &description, &admin.CreatedAt, &admin.UpdatedAt, &admin.LastLogin, &role)
|
|
|
|
|
|
if err != nil {
|
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
@@ -1673,24 +1794,21 @@ func getAdminFromDbRow(row sqlScanner) (Admin, error) {
|
|
|
return admin, err
|
|
|
}
|
|
|
|
|
|
- if permissions.Valid {
|
|
|
- var perms []string
|
|
|
- err = json.Unmarshal([]byte(permissions.String), &perms)
|
|
|
- if err != nil {
|
|
|
- return admin, err
|
|
|
- }
|
|
|
- admin.Permissions = perms
|
|
|
+ var perms []string
|
|
|
+ err = json.Unmarshal(permissions, &perms)
|
|
|
+ if err != nil {
|
|
|
+ return admin, err
|
|
|
}
|
|
|
+ admin.Permissions = perms
|
|
|
|
|
|
if email.Valid {
|
|
|
admin.Email = email.String
|
|
|
}
|
|
|
- if filters.Valid {
|
|
|
- var adminFilters AdminFilters
|
|
|
- err = json.Unmarshal([]byte(filters.String), &adminFilters)
|
|
|
- if err == nil {
|
|
|
- admin.Filters = adminFilters
|
|
|
- }
|
|
|
+
|
|
|
+ var adminFilters AdminFilters
|
|
|
+ err = json.Unmarshal(filters, &adminFilters)
|
|
|
+ if err == nil {
|
|
|
+ admin.Filters = adminFilters
|
|
|
}
|
|
|
if additionalInfo.Valid {
|
|
|
admin.AdditionalInfo = additionalInfo.String
|
|
@@ -1698,6 +1816,9 @@ func getAdminFromDbRow(row sqlScanner) (Admin, error) {
|
|
|
if description.Valid {
|
|
|
admin.Description = description.String
|
|
|
}
|
|
|
+ if role.Valid {
|
|
|
+ admin.Role = role.String
|
|
|
+ }
|
|
|
|
|
|
admin.SetEmptySecretsIfNil()
|
|
|
return admin, nil
|
|
@@ -1718,11 +1839,10 @@ func getEventActionFromDbRow(row sqlScanner) (BaseEventAction, error) {
|
|
|
if description.Valid {
|
|
|
action.Description = description.String
|
|
|
}
|
|
|
- if len(options) > 0 {
|
|
|
- err = json.Unmarshal(options, &action.Options)
|
|
|
- if err != nil {
|
|
|
- return action, err
|
|
|
- }
|
|
|
+ var actionOptions BaseEventActionOptions
|
|
|
+ err = json.Unmarshal(options, &actionOptions)
|
|
|
+ if err == nil {
|
|
|
+ action.Options = actionOptions
|
|
|
}
|
|
|
return action, nil
|
|
|
}
|
|
@@ -1740,21 +1860,40 @@ func getEventRuleFromDbRow(row sqlScanner) (EventRule, error) {
|
|
|
}
|
|
|
return rule, err
|
|
|
}
|
|
|
- if len(conditions) > 0 {
|
|
|
- err = json.Unmarshal(conditions, &rule.Conditions)
|
|
|
- if err != nil {
|
|
|
- return rule, err
|
|
|
- }
|
|
|
+ var ruleConditions EventConditions
|
|
|
+ err = json.Unmarshal(conditions, &ruleConditions)
|
|
|
+ if err == nil {
|
|
|
+ rule.Conditions = ruleConditions
|
|
|
}
|
|
|
+
|
|
|
if description.Valid {
|
|
|
rule.Description = description.String
|
|
|
}
|
|
|
return rule, nil
|
|
|
}
|
|
|
|
|
|
+func getRoleFromDbRow(row sqlScanner) (Role, error) {
|
|
|
+ var role Role
|
|
|
+ var description sql.NullString
|
|
|
+
|
|
|
+ err := row.Scan(&role.ID, &role.Name, &description, &role.CreatedAt, &role.UpdatedAt)
|
|
|
+ if err != nil {
|
|
|
+ if errors.Is(err, sql.ErrNoRows) {
|
|
|
+ return role, util.NewRecordNotFoundError(err.Error())
|
|
|
+ }
|
|
|
+ return role, err
|
|
|
+ }
|
|
|
+ if description.Valid {
|
|
|
+ role.Description = description.String
|
|
|
+ }
|
|
|
+
|
|
|
+ return role, nil
|
|
|
+}
|
|
|
+
|
|
|
func getGroupFromDbRow(row sqlScanner) (Group, error) {
|
|
|
var group Group
|
|
|
- var userSettings, description sql.NullString
|
|
|
+ var description sql.NullString
|
|
|
+ var userSettings []byte
|
|
|
|
|
|
err := row.Scan(&group.ID, &group.Name, &description, &group.CreatedAt, &group.UpdatedAt, &userSettings)
|
|
|
if err != nil {
|
|
@@ -1766,12 +1905,11 @@ func getGroupFromDbRow(row sqlScanner) (Group, error) {
|
|
|
if description.Valid {
|
|
|
group.Description = description.String
|
|
|
}
|
|
|
- if userSettings.Valid {
|
|
|
- var settings GroupUserSettings
|
|
|
- err = json.Unmarshal([]byte(userSettings.String), &settings)
|
|
|
- if err == nil {
|
|
|
- group.UserSettings = settings
|
|
|
- }
|
|
|
+
|
|
|
+ var settings GroupUserSettings
|
|
|
+ err = json.Unmarshal(userSettings, &settings)
|
|
|
+ if err == nil {
|
|
|
+ group.UserSettings = settings
|
|
|
}
|
|
|
|
|
|
return group, nil
|
|
@@ -1779,19 +1917,16 @@ func getGroupFromDbRow(row sqlScanner) (Group, error) {
|
|
|
|
|
|
func getUserFromDbRow(row sqlScanner) (User, error) {
|
|
|
var user User
|
|
|
- var permissions sql.NullString
|
|
|
var password sql.NullString
|
|
|
- var publicKey sql.NullString
|
|
|
- var filters sql.NullString
|
|
|
- var fsConfig sql.NullString
|
|
|
- var additionalInfo, description, email sql.NullString
|
|
|
+ var permissions, publicKey, filters, fsConfig []byte
|
|
|
+ var additionalInfo, description, email, role sql.NullString
|
|
|
|
|
|
err := row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
|
|
|
&user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
|
|
|
&user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig,
|
|
|
&additionalInfo, &description, &email, &user.CreatedAt, &user.UpdatedAt, &user.UploadDataTransfer, &user.DownloadDataTransfer,
|
|
|
&user.TotalDataTransfer, &user.UsedUploadDataTransfer, &user.UsedDownloadDataTransfer, &user.DeletedAt, &user.FirstDownload,
|
|
|
- &user.FirstUpload)
|
|
|
+ &user.FirstUpload, &role)
|
|
|
if err != nil {
|
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
|
return user, util.NewRecordNotFoundError(err.Error())
|
|
@@ -1801,38 +1936,30 @@ func getUserFromDbRow(row sqlScanner) (User, error) {
|
|
|
if password.Valid {
|
|
|
user.Password = password.String
|
|
|
}
|
|
|
+ perms := make(map[string][]string)
|
|
|
+ err = json.Unmarshal(permissions, &perms)
|
|
|
+ if err != nil {
|
|
|
+ providerLog(logger.LevelError, "unable to deserialize permissions for user %#v: %v", user.Username, err)
|
|
|
+ return user, fmt.Errorf("unable to deserialize permissions for user %#v: %v", user.Username, err)
|
|
|
+ }
|
|
|
+ user.Permissions = perms
|
|
|
// we can have a empty string or an invalid json in null string
|
|
|
// so we do a relaxed test if the field is optional, for example we
|
|
|
// populate public keys only if unmarshal does not return an error
|
|
|
- if publicKey.Valid {
|
|
|
- var list []string
|
|
|
- err = json.Unmarshal([]byte(publicKey.String), &list)
|
|
|
- if err == nil {
|
|
|
- user.PublicKeys = list
|
|
|
- }
|
|
|
- }
|
|
|
- if permissions.Valid {
|
|
|
- perms := make(map[string][]string)
|
|
|
- err = json.Unmarshal([]byte(permissions.String), &perms)
|
|
|
- if err != nil {
|
|
|
- providerLog(logger.LevelError, "unable to deserialize permissions for user %#v: %v", user.Username, err)
|
|
|
- return user, fmt.Errorf("unable to deserialize permissions for user %#v: %v", user.Username, err)
|
|
|
- }
|
|
|
- user.Permissions = perms
|
|
|
+ var pKeys []string
|
|
|
+ err = json.Unmarshal(publicKey, &pKeys)
|
|
|
+ if err == nil {
|
|
|
+ user.PublicKeys = pKeys
|
|
|
}
|
|
|
- if filters.Valid {
|
|
|
- var userFilters UserFilters
|
|
|
- err = json.Unmarshal([]byte(filters.String), &userFilters)
|
|
|
- if err == nil {
|
|
|
- user.Filters = userFilters
|
|
|
- }
|
|
|
+ var userFilters UserFilters
|
|
|
+ err = json.Unmarshal(filters, &userFilters)
|
|
|
+ if err == nil {
|
|
|
+ user.Filters = userFilters
|
|
|
}
|
|
|
- if fsConfig.Valid {
|
|
|
- var fs vfs.Filesystem
|
|
|
- err = json.Unmarshal([]byte(fsConfig.String), &fs)
|
|
|
- if err == nil {
|
|
|
- user.FsConfig = fs
|
|
|
- }
|
|
|
+ var fs vfs.Filesystem
|
|
|
+ err = json.Unmarshal(fsConfig, &fs)
|
|
|
+ if err == nil {
|
|
|
+ user.FsConfig = fs
|
|
|
}
|
|
|
if additionalInfo.Valid {
|
|
|
user.AdditionalInfo = additionalInfo.String
|
|
@@ -1843,6 +1970,9 @@ func getUserFromDbRow(row sqlScanner) (User, error) {
|
|
|
if email.Valid {
|
|
|
user.Email = email.String
|
|
|
}
|
|
|
+ if role.Valid {
|
|
|
+ user.Role = role.String
|
|
|
+ }
|
|
|
user.SetEmptySecretsIfNil()
|
|
|
return user, nil
|
|
|
}
|
|
@@ -1851,7 +1981,8 @@ func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (
|
|
|
var folder vfs.BaseVirtualFolder
|
|
|
q := getFolderByNameQuery()
|
|
|
row := dbHandle.QueryRowContext(ctx, q, name)
|
|
|
- var mappedPath, description, fsConfig sql.NullString
|
|
|
+ var mappedPath, description sql.NullString
|
|
|
+ var fsConfig []byte
|
|
|
err := row.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate,
|
|
|
&folder.Name, &description, &fsConfig)
|
|
|
if err != nil {
|
|
@@ -1866,12 +1997,10 @@ func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (
|
|
|
if description.Valid {
|
|
|
folder.Description = description.String
|
|
|
}
|
|
|
- if fsConfig.Valid {
|
|
|
- var fs vfs.Filesystem
|
|
|
- err = json.Unmarshal([]byte(fsConfig.String), &fs)
|
|
|
- if err == nil {
|
|
|
- folder.FsConfig = fs
|
|
|
- }
|
|
|
+ var fs vfs.Filesystem
|
|
|
+ err = json.Unmarshal(fsConfig, &fs)
|
|
|
+ if err == nil {
|
|
|
+ folder.FsConfig = fs
|
|
|
}
|
|
|
return folder, err
|
|
|
}
|
|
@@ -1971,7 +2100,8 @@ func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error)
|
|
|
defer rows.Close()
|
|
|
for rows.Next() {
|
|
|
var folder vfs.BaseVirtualFolder
|
|
|
- var mappedPath, description, fsConfig sql.NullString
|
|
|
+ var mappedPath, description sql.NullString
|
|
|
+ var fsConfig []byte
|
|
|
err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
|
|
|
&folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
|
|
|
if err != nil {
|
|
@@ -1983,12 +2113,10 @@ func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error)
|
|
|
if description.Valid {
|
|
|
folder.Description = description.String
|
|
|
}
|
|
|
- if fsConfig.Valid {
|
|
|
- var fs vfs.Filesystem
|
|
|
- err = json.Unmarshal([]byte(fsConfig.String), &fs)
|
|
|
- if err == nil {
|
|
|
- folder.FsConfig = fs
|
|
|
- }
|
|
|
+ var fs vfs.Filesystem
|
|
|
+ err = json.Unmarshal(fsConfig, &fs)
|
|
|
+ if err == nil {
|
|
|
+ folder.FsConfig = fs
|
|
|
}
|
|
|
folders = append(folders, folder)
|
|
|
}
|
|
@@ -2014,7 +2142,8 @@ func sqlCommonGetFolders(limit, offset int, order string, minimal bool, dbHandle
|
|
|
return folders, err
|
|
|
}
|
|
|
} else {
|
|
|
- var mappedPath, description, fsConfig sql.NullString
|
|
|
+ var mappedPath, description sql.NullString
|
|
|
+ var fsConfig []byte
|
|
|
err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
|
|
|
&folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
|
|
|
if err != nil {
|
|
@@ -2026,12 +2155,10 @@ func sqlCommonGetFolders(limit, offset int, order string, minimal bool, dbHandle
|
|
|
if description.Valid {
|
|
|
folder.Description = description.String
|
|
|
}
|
|
|
- if fsConfig.Valid {
|
|
|
- var fs vfs.Filesystem
|
|
|
- err = json.Unmarshal([]byte(fsConfig.String), &fs)
|
|
|
- if err == nil {
|
|
|
- folder.FsConfig = fs
|
|
|
- }
|
|
|
+ var fs vfs.Filesystem
|
|
|
+ err = json.Unmarshal(fsConfig, &fs)
|
|
|
+ if err == nil {
|
|
|
+ folder.FsConfig = fs
|
|
|
}
|
|
|
}
|
|
|
folder.PrepareForRendering()
|
|
@@ -2297,7 +2424,8 @@ func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQ
|
|
|
for rows.Next() {
|
|
|
var folder vfs.VirtualFolder
|
|
|
var userID int64
|
|
|
- var mappedPath, fsConfig, description sql.NullString
|
|
|
+ var mappedPath, description sql.NullString
|
|
|
+ var fsConfig []byte
|
|
|
err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
|
|
|
&folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID, &fsConfig,
|
|
|
&description)
|
|
@@ -2310,12 +2438,10 @@ func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQ
|
|
|
if description.Valid {
|
|
|
folder.Description = description.String
|
|
|
}
|
|
|
- if fsConfig.Valid {
|
|
|
- var fs vfs.Filesystem
|
|
|
- err = json.Unmarshal([]byte(fsConfig.String), &fs)
|
|
|
- if err == nil {
|
|
|
- folder.FsConfig = fs
|
|
|
- }
|
|
|
+ var fs vfs.Filesystem
|
|
|
+ err = json.Unmarshal(fsConfig, &fs)
|
|
|
+ if err == nil {
|
|
|
+ folder.FsConfig = fs
|
|
|
}
|
|
|
usersVirtualFolders[userID] = append(usersVirtualFolders[userID], folder)
|
|
|
}
|
|
@@ -2390,6 +2516,28 @@ func getGroupWithUsers(ctx context.Context, group Group, dbHandle sqlQuerier) (G
|
|
|
return groups[0], err
|
|
|
}
|
|
|
|
|
|
+func getRoleWithUsers(ctx context.Context, role Role, dbHandle sqlQuerier) (Role, error) {
|
|
|
+ roles, err := getRolesWithUsers(ctx, []Role{role}, dbHandle)
|
|
|
+ if err != nil {
|
|
|
+ return role, err
|
|
|
+ }
|
|
|
+ if len(roles) == 0 {
|
|
|
+ return role, errors.New("unable to associate users with role")
|
|
|
+ }
|
|
|
+ return roles[0], err
|
|
|
+}
|
|
|
+
|
|
|
+func getRoleWithAdmins(ctx context.Context, role Role, dbHandle sqlQuerier) (Role, error) {
|
|
|
+ roles, err := getRolesWithAdmins(ctx, []Role{role}, dbHandle)
|
|
|
+ if err != nil {
|
|
|
+ return role, err
|
|
|
+ }
|
|
|
+ if len(roles) == 0 {
|
|
|
+ return role, errors.New("unable to associate admins with role")
|
|
|
+ }
|
|
|
+ return roles[0], err
|
|
|
+}
|
|
|
+
|
|
|
func getGroupWithAdmins(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
|
|
|
groups, err := getGroupsWithAdmins(ctx, []Group{group}, dbHandle)
|
|
|
if err != nil {
|
|
@@ -2427,7 +2575,8 @@ func getGroupsWithVirtualFolders(ctx context.Context, groups []Group, dbHandle s
|
|
|
for rows.Next() {
|
|
|
var groupID int64
|
|
|
var folder vfs.VirtualFolder
|
|
|
- var mappedPath, fsConfig, description sql.NullString
|
|
|
+ var mappedPath, description sql.NullString
|
|
|
+ var fsConfig []byte
|
|
|
err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
|
|
|
&folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &groupID, &fsConfig,
|
|
|
&description)
|
|
@@ -2440,12 +2589,10 @@ func getGroupsWithVirtualFolders(ctx context.Context, groups []Group, dbHandle s
|
|
|
if description.Valid {
|
|
|
folder.Description = description.String
|
|
|
}
|
|
|
- if fsConfig.Valid {
|
|
|
- var fs vfs.Filesystem
|
|
|
- err = json.Unmarshal([]byte(fsConfig.String), &fs)
|
|
|
- if err == nil {
|
|
|
- folder.FsConfig = fs
|
|
|
- }
|
|
|
+ var fs vfs.Filesystem
|
|
|
+ err = json.Unmarshal(fsConfig, &fs)
|
|
|
+ if err == nil {
|
|
|
+ folder.FsConfig = fs
|
|
|
}
|
|
|
groupsVirtualFolders[groupID] = append(groupsVirtualFolders[groupID], folder)
|
|
|
}
|
|
@@ -2498,6 +2645,71 @@ func getGroupsWithUsers(ctx context.Context, groups []Group, dbHandle sqlQuerier
|
|
|
return groups, err
|
|
|
}
|
|
|
|
|
|
+func getRolesWithUsers(ctx context.Context, roles []Role, dbHandle sqlQuerier) ([]Role, error) {
|
|
|
+ if len(roles) == 0 {
|
|
|
+ return roles, nil
|
|
|
+ }
|
|
|
+ rows, err := dbHandle.QueryContext(ctx, getUsersWithRolesQuery(roles))
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ defer rows.Close()
|
|
|
+
|
|
|
+ rolesUsers := make(map[int64][]string)
|
|
|
+ for rows.Next() {
|
|
|
+ var roleID int64
|
|
|
+ var username string
|
|
|
+ err = rows.Scan(&roleID, &username)
|
|
|
+ if err != nil {
|
|
|
+ return roles, err
|
|
|
+ }
|
|
|
+ rolesUsers[roleID] = append(rolesUsers[roleID], username)
|
|
|
+ }
|
|
|
+ err = rows.Err()
|
|
|
+ if err != nil {
|
|
|
+ return roles, err
|
|
|
+ }
|
|
|
+ if len(rolesUsers) > 0 {
|
|
|
+ for idx := range roles {
|
|
|
+ ref := &roles[idx]
|
|
|
+ ref.Users = rolesUsers[ref.ID]
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return roles, nil
|
|
|
+}
|
|
|
+
|
|
|
+func getRolesWithAdmins(ctx context.Context, roles []Role, dbHandle sqlQuerier) ([]Role, error) {
|
|
|
+ if len(roles) == 0 {
|
|
|
+ return roles, nil
|
|
|
+ }
|
|
|
+ rows, err := dbHandle.QueryContext(ctx, getAdminsWithRolesQuery(roles))
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ defer rows.Close()
|
|
|
+
|
|
|
+ rolesAdmins := make(map[int64][]string)
|
|
|
+ for rows.Next() {
|
|
|
+ var roleID int64
|
|
|
+ var username string
|
|
|
+ err = rows.Scan(&roleID, &username)
|
|
|
+ if err != nil {
|
|
|
+ return roles, err
|
|
|
+ }
|
|
|
+ rolesAdmins[roleID] = append(rolesAdmins[roleID], username)
|
|
|
+ }
|
|
|
+ if err = rows.Err(); err != nil {
|
|
|
+ return roles, err
|
|
|
+ }
|
|
|
+ if len(rolesAdmins) > 0 {
|
|
|
+ for idx := range roles {
|
|
|
+ ref := &roles[idx]
|
|
|
+ ref.Admins = rolesAdmins[ref.ID]
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return roles, nil
|
|
|
+}
|
|
|
+
|
|
|
func getGroupsWithAdmins(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
|
|
|
if len(groups) == 0 {
|
|
|
return groups, nil
|
|
@@ -2701,7 +2913,7 @@ func getRelatedValuesForAPIKeys(ctx context.Context, apiKeys []APIKey, dbHandle
|
|
|
func sqlCommonGetAPIKeyRelatedIDs(apiKey *APIKey) (sql.NullInt64, sql.NullInt64, error) {
|
|
|
var userID, adminID sql.NullInt64
|
|
|
if apiKey.User != "" {
|
|
|
- u, err := provider.userExists(apiKey.User)
|
|
|
+ u, err := provider.userExists(apiKey.User, "")
|
|
|
if err != nil {
|
|
|
return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate user %v", apiKey.User))
|
|
|
}
|