Просмотр исходного кода

EventManager: avoid copying user struct when updating parameters

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino 4 месяцев назад
Родитель
Сommit
5ca3522dc0
3 измененных файлов с 48 добавлено и 29 удалено
  1. 4 2
      internal/common/connection.go
  2. 17 25
      internal/common/eventmanager.go
  3. 27 2
      internal/common/protocol_test.go

+ 4 - 2
internal/common/connection.go

@@ -617,8 +617,10 @@ func (c *BaseConnection) checkCopy(srcInfo, dstInfo os.FileInfo, virtualSource,
 	if dstInfo != nil && dstInfo.IsDir() {
 		return fmt.Errorf("cannot overwrite file %q with dir %q: %w", virtualSource, virtualTarget, c.GetOpUnsupportedError())
 	}
-	if fsSourcePath == fsTargetPath {
-		return fmt.Errorf("the copy source and target cannot be the same: %w", c.GetOpUnsupportedError())
+	if c.IsSameResource(virtualSource, virtualTarget) {
+		if fsSourcePath == fsTargetPath {
+			return fmt.Errorf("the copy source and target cannot be the same: %w", c.GetOpUnsupportedError())
+		}
 	}
 	return nil
 }

+ 17 - 25
internal/common/eventmanager.go

@@ -1380,8 +1380,7 @@ func getHTTPRuleActionBody(c *dataprovider.EventActionHTTPConfig, replacer *stri
 		var conn *BaseConnection
 		if user.Username != "" {
 			var err error
-			user, err = getUserForEventAction(user)
-			if err != nil {
+			if err := getUserForEventAction(&user); err != nil {
 				return body, "", err
 			}
 			connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
@@ -1613,8 +1612,7 @@ func executeEmailRuleAction(c dataprovider.EventActionEmailConfig, params *Event
 		if err != nil {
 			return err
 		}
-		user, err = getUserForEventAction(user)
-		if err != nil {
+		if err := getUserForEventAction(&user); err != nil {
 			return err
 		}
 		connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
@@ -1641,11 +1639,11 @@ func executeEmailRuleAction(c dataprovider.EventActionEmailConfig, params *Event
 	return nil
 }
 
-func getUserForEventAction(user dataprovider.User) (dataprovider.User, error) {
+func getUserForEventAction(user *dataprovider.User) error {
 	err := user.LoadAndApplyGroupSettings()
 	if err != nil {
 		eventManagerLog(logger.LevelError, "unable to get group for user %q: %+v", user.Username, err)
-		return dataprovider.User{}, fmt.Errorf("unable to get groups for user %q", user.Username)
+		return fmt.Errorf("unable to get groups for user %q", user.Username)
 	}
 	user.UploadDataTransfer = 0
 	user.UploadBandwidth = 0
@@ -1656,7 +1654,7 @@ func getUserForEventAction(user dataprovider.User) (dataprovider.User, error) {
 	for k := range user.Permissions {
 		user.Permissions[k] = []string{dataprovider.PermAny}
 	}
-	return user, nil
+	return nil
 }
 
 func replacePathsPlaceholders(paths []string, replacer *strings.Replacer) []string {
@@ -1676,12 +1674,11 @@ func executeDeleteFileFsAction(conn *BaseConnection, item string, info os.FileIn
 }
 
 func executeDeleteFsActionForUser(deletes []string, replacer *strings.Replacer, user dataprovider.User) error {
-	user, err := getUserForEventAction(user)
-	if err != nil {
+	if err := getUserForEventAction(&user); err != nil {
 		return err
 	}
 	connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
-	err = user.CheckFsRoot(connectionID)
+	err := user.CheckFsRoot(connectionID)
 	defer user.CloseFs() //nolint:errcheck
 	if err != nil {
 		return fmt.Errorf("delete error, unable to check root fs for user %q: %w", user.Username, err)
@@ -1746,12 +1743,11 @@ func executeDeleteFsRuleAction(deletes []string, replacer *strings.Replacer,
 }
 
 func executeMkDirsFsActionForUser(dirs []string, replacer *strings.Replacer, user dataprovider.User) error {
-	user, err := getUserForEventAction(user)
-	if err != nil {
+	if err := getUserForEventAction(&user); err != nil {
 		return err
 	}
 	connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
-	err = user.CheckFsRoot(connectionID)
+	err := user.CheckFsRoot(connectionID)
 	defer user.CloseFs() //nolint:errcheck
 	if err != nil {
 		return fmt.Errorf("mkdir error, unable to check root fs for user %q: %w", user.Username, err)
@@ -1807,12 +1803,11 @@ func executeMkdirFsRuleAction(dirs []string, replacer *strings.Replacer,
 func executeRenameFsActionForUser(renames []dataprovider.RenameConfig, replacer *strings.Replacer,
 	user dataprovider.User,
 ) error {
-	user, err := getUserForEventAction(user)
-	if err != nil {
+	if err := getUserForEventAction(&user); err != nil {
 		return err
 	}
 	connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
-	err = user.CheckFsRoot(connectionID)
+	err := user.CheckFsRoot(connectionID)
 	defer user.CloseFs() //nolint:errcheck
 	if err != nil {
 		return fmt.Errorf("rename error, unable to check root fs for user %q: %w", user.Username, err)
@@ -1838,12 +1833,11 @@ func executeRenameFsActionForUser(renames []dataprovider.RenameConfig, replacer
 func executeCopyFsActionForUser(keyVals []dataprovider.KeyValue, replacer *strings.Replacer,
 	user dataprovider.User,
 ) error {
-	user, err := getUserForEventAction(user)
-	if err != nil {
+	if err := getUserForEventAction(&user); err != nil {
 		return err
 	}
 	connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
-	err = user.CheckFsRoot(connectionID)
+	err := user.CheckFsRoot(connectionID)
 	defer user.CloseFs() //nolint:errcheck
 	if err != nil {
 		return fmt.Errorf("copy error, unable to check root fs for user %q: %w", user.Username, err)
@@ -1871,12 +1865,11 @@ func executeCopyFsActionForUser(keyVals []dataprovider.KeyValue, replacer *strin
 func executeExistFsActionForUser(exist []string, replacer *strings.Replacer,
 	user dataprovider.User,
 ) error {
-	user, err := getUserForEventAction(user)
-	if err != nil {
+	if err := getUserForEventAction(&user); err != nil {
 		return err
 	}
 	connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
-	err = user.CheckFsRoot(connectionID)
+	err := user.CheckFsRoot(connectionID)
 	defer user.CloseFs() //nolint:errcheck
 	if err != nil {
 		return fmt.Errorf("existence check error, unable to check root fs for user %q: %w", user.Username, err)
@@ -2031,12 +2024,11 @@ func estimateZipSize(conn *BaseConnection, zipPath string, paths []string) (int6
 func executeCompressFsActionForUser(c dataprovider.EventActionFsCompress, replacer *strings.Replacer,
 	user dataprovider.User,
 ) error {
-	user, err := getUserForEventAction(user)
-	if err != nil {
+	if err := getUserForEventAction(&user); err != nil {
 		return err
 	}
 	connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
-	err = user.CheckFsRoot(connectionID)
+	err := user.CheckFsRoot(connectionID)
 	defer user.CloseFs() //nolint:errcheck
 	if err != nil {
 		return fmt.Errorf("compress error, unable to check root fs for user %q: %w", user.Username, err)

+ 27 - 2
internal/common/protocol_test.go

@@ -5402,6 +5402,7 @@ func TestEventActionCommandEnvVars(t *testing.T) {
 }
 
 func TestFsActionCopy(t *testing.T) {
+	dirCopy := "/dircopy"
 	a1 := dataprovider.BaseEventAction{
 		Name: "a1",
 		Type: dataprovider.ActionTypeFilesystem,
@@ -5411,7 +5412,7 @@ func TestFsActionCopy(t *testing.T) {
 				Copy: []dataprovider.KeyValue{
 					{
 						Key:   "/{{.VirtualPath}}/",
-						Value: "/dircopy/",
+						Value: dirCopy + "/",
 					},
 				},
 			},
@@ -5441,7 +5442,29 @@ func TestFsActionCopy(t *testing.T) {
 	}
 	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
 	assert.NoError(t, err)
+	g1 := dataprovider.Group{
+		BaseGroup: sdk.BaseGroup{
+			Name: "group1",
+		},
+		UserSettings: dataprovider.GroupUserSettings{
+			BaseGroupUserSettings: sdk.BaseGroupUserSettings{
+				Permissions: map[string][]string{
+					// Restrict permissions in copyPath to check that action
+					// will have full permissions anyway.
+					dirCopy: {dataprovider.PermListItems, dataprovider.PermDelete},
+				},
+			},
+		},
+	}
+	group1, resp, err := httpdtest.AddGroup(g1, http.StatusCreated)
+	assert.NoError(t, err, string(resp))
 	u := getTestUser()
+	u.Groups = []sdk.GroupMapping{
+		{
+			Name: group1.Name,
+			Type: sdk.GroupTypePrimary,
+		},
+	}
 	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
 	assert.NoError(t, err)
 	conn, client, err := getSftpClient(user)
@@ -5451,7 +5474,7 @@ func TestFsActionCopy(t *testing.T) {
 
 		err = writeSFTPFile(testFileName, 100, client)
 		assert.NoError(t, err)
-		_, err = client.Stat(path.Join("dircopy", testFileName))
+		_, err = client.Stat(path.Join(dirCopy, testFileName))
 		assert.NoError(t, err)
 
 		action1.Options.FsConfig.Copy = []dataprovider.KeyValue{
@@ -5474,6 +5497,8 @@ func TestFsActionCopy(t *testing.T) {
 	assert.NoError(t, err)
 	err = os.RemoveAll(user.GetHomeDir())
 	assert.NoError(t, err)
+	_, err = httpdtest.RemoveGroup(group1, http.StatusOK)
+	assert.NoError(t, err)
 }
 
 func TestEventFsActionsGroupFilters(t *testing.T) {