|  | @@ -6,8 +6,12 @@ package models
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import (
 | 
	
		
			
				|  |  |  	"fmt"
 | 
	
		
			
				|  |  | +	"strings"
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	"github.com/Unknwon/com"
 | 
	
		
			
				|  |  |  	"github.com/gogits/git-module"
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	"github.com/gogits/gogs/modules/base"
 | 
	
		
			
				|  |  |  )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  type Branch struct {
 | 
	
	
		
			
				|  | @@ -58,6 +62,20 @@ func (br *Branch) GetCommit() (*git.Commit, error) {
 | 
	
		
			
				|  |  |  	return gitRepo.GetBranchCommit(br.Name)
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +type ProtectBranchWhitelist struct {
 | 
	
		
			
				|  |  | +	ID              int64
 | 
	
		
			
				|  |  | +	ProtectBranchID int64
 | 
	
		
			
				|  |  | +	RepoID          int64  `xorm:"UNIQUE(protect_branch_whitelist)"`
 | 
	
		
			
				|  |  | +	Name            string `xorm:"UNIQUE(protect_branch_whitelist)"`
 | 
	
		
			
				|  |  | +	UserID          int64  `xorm:"UNIQUE(protect_branch_whitelist)"`
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +// IsUserInProtectBranchWhitelist returns true if given user is in the whitelist of a branch in a repository.
 | 
	
		
			
				|  |  | +func IsUserInProtectBranchWhitelist(repoID, userID int64, branch string) bool {
 | 
	
		
			
				|  |  | +	has, err := x.Where("repo_id = ?", repoID).And("user_id = ?", userID).And("name = ?", branch).Get(new(ProtectBranchWhitelist))
 | 
	
		
			
				|  |  | +	return has && err == nil
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  // ProtectBranch contains options of a protected branch.
 | 
	
		
			
				|  |  |  type ProtectBranch struct {
 | 
	
		
			
				|  |  |  	ID                 int64
 | 
	
	
		
			
				|  | @@ -65,6 +83,9 @@ type ProtectBranch struct {
 | 
	
		
			
				|  |  |  	Name               string `xorm:"UNIQUE(protect_branch)"`
 | 
	
		
			
				|  |  |  	Protected          bool
 | 
	
		
			
				|  |  |  	RequirePullRequest bool
 | 
	
		
			
				|  |  | +	EnableWhitelist    bool
 | 
	
		
			
				|  |  | +	WhitelistUserIDs   string `xorm:"TEXT"`
 | 
	
		
			
				|  |  | +	WhitelistTeamIDs   string `xorm:"TEXT"`
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  // GetProtectBranchOfRepoByName returns *ProtectBranch by branch name in given repostiory.
 | 
	
	
		
			
				|  | @@ -94,15 +115,133 @@ func IsBranchOfRepoRequirePullRequest(repoID int64, name string) bool {
 | 
	
		
			
				|  |  |  // UpdateProtectBranch saves branch protection options.
 | 
	
		
			
				|  |  |  // If ID is 0, it creates a new record. Otherwise, updates existing record.
 | 
	
		
			
				|  |  |  func UpdateProtectBranch(protectBranch *ProtectBranch) (err error) {
 | 
	
		
			
				|  |  | +	sess := x.NewSession()
 | 
	
		
			
				|  |  | +	defer sessionRelease(sess)
 | 
	
		
			
				|  |  | +	if err = sess.Begin(); err != nil {
 | 
	
		
			
				|  |  | +		return err
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	if protectBranch.ID == 0 {
 | 
	
		
			
				|  |  | +		if _, err = sess.Insert(protectBranch); err != nil {
 | 
	
		
			
				|  |  | +			return fmt.Errorf("Insert: %v", err)
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		return
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	if _, err = sess.Id(protectBranch.ID).AllCols().Update(protectBranch); err != nil {
 | 
	
		
			
				|  |  | +		return fmt.Errorf("Update: %v", err)
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	return sess.Commit()
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +// UpdateOrgProtectBranch saves branch protection options of organizational repository.
 | 
	
		
			
				|  |  | +// If ID is 0, it creates a new record. Otherwise, updates existing record.
 | 
	
		
			
				|  |  | +// This function also performs check if whitelist user and team's IDs have been changed
 | 
	
		
			
				|  |  | +// to avoid unnecessary whitelist delete and regenerate.
 | 
	
		
			
				|  |  | +func UpdateOrgProtectBranch(repo *Repository, protectBranch *ProtectBranch, whitelistUserIDs, whitelistTeamIDs string) (err error) {
 | 
	
		
			
				|  |  | +	if err = repo.GetOwner(); err != nil {
 | 
	
		
			
				|  |  | +		return fmt.Errorf("GetOwner: %v", err)
 | 
	
		
			
				|  |  | +	} else if !repo.Owner.IsOrganization() {
 | 
	
		
			
				|  |  | +		return fmt.Errorf("expect repository owner to be an organization")
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	hasUsersChanged := false
 | 
	
		
			
				|  |  | +	validUserIDs := base.StringsToInt64s(strings.Split(protectBranch.WhitelistUserIDs, ","))
 | 
	
		
			
				|  |  | +	if protectBranch.WhitelistUserIDs != whitelistUserIDs {
 | 
	
		
			
				|  |  | +		hasUsersChanged = true
 | 
	
		
			
				|  |  | +		userIDs := base.StringsToInt64s(strings.Split(whitelistUserIDs, ","))
 | 
	
		
			
				|  |  | +		validUserIDs = make([]int64, 0, len(userIDs))
 | 
	
		
			
				|  |  | +		for _, userID := range userIDs {
 | 
	
		
			
				|  |  | +			has, err := HasAccess(userID, repo, ACCESS_MODE_WRITE)
 | 
	
		
			
				|  |  | +			if err != nil {
 | 
	
		
			
				|  |  | +				return fmt.Errorf("HasAccess [user_id: %d, repo_id: %d]: %v", userID, protectBranch.RepoID, err)
 | 
	
		
			
				|  |  | +			} else if !has {
 | 
	
		
			
				|  |  | +				continue // Drop invalid user ID
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +			validUserIDs = append(validUserIDs, userID)
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		protectBranch.WhitelistUserIDs = strings.Join(base.Int64sToStrings(validUserIDs), ",")
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	hasTeamsChanged := false
 | 
	
		
			
				|  |  | +	validTeamIDs := base.StringsToInt64s(strings.Split(protectBranch.WhitelistTeamIDs, ","))
 | 
	
		
			
				|  |  | +	if protectBranch.WhitelistTeamIDs != whitelistTeamIDs {
 | 
	
		
			
				|  |  | +		hasTeamsChanged = true
 | 
	
		
			
				|  |  | +		teamIDs := base.StringsToInt64s(strings.Split(whitelistTeamIDs, ","))
 | 
	
		
			
				|  |  | +		teams, err := GetTeamsByOrgID(repo.OwnerID)
 | 
	
		
			
				|  |  | +		if err != nil {
 | 
	
		
			
				|  |  | +			return fmt.Errorf("GetTeamsByOrgID [org_id: %d]: %v", repo.OwnerID, err)
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		validTeamIDs = make([]int64, 0, len(teams))
 | 
	
		
			
				|  |  | +		for i := range teams {
 | 
	
		
			
				|  |  | +			if teams[i].HasWriteAccess() && com.IsSliceContainsInt64(teamIDs, teams[i].ID) {
 | 
	
		
			
				|  |  | +				validTeamIDs = append(validTeamIDs, teams[i].ID)
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		protectBranch.WhitelistTeamIDs = strings.Join(base.Int64sToStrings(validTeamIDs), ",")
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	// Merge users and members of teams
 | 
	
		
			
				|  |  | +	var whitelists []*ProtectBranchWhitelist
 | 
	
		
			
				|  |  | +	if hasUsersChanged || hasTeamsChanged {
 | 
	
		
			
				|  |  | +		mergedUserIDs := make(map[int64]bool)
 | 
	
		
			
				|  |  | +		for _, userID := range validUserIDs {
 | 
	
		
			
				|  |  | +			mergedUserIDs[userID] = true
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		for _, teamID := range validTeamIDs {
 | 
	
		
			
				|  |  | +			members, err := GetTeamMembers(teamID)
 | 
	
		
			
				|  |  | +			if err != nil {
 | 
	
		
			
				|  |  | +				return fmt.Errorf("GetTeamMembers [team_id: %d]: %v", teamID, err)
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +			for i := range members {
 | 
	
		
			
				|  |  | +				mergedUserIDs[members[i].ID] = true
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		whitelists = make([]*ProtectBranchWhitelist, 0, len(mergedUserIDs))
 | 
	
		
			
				|  |  | +		for userID := range mergedUserIDs {
 | 
	
		
			
				|  |  | +			whitelists = append(whitelists, &ProtectBranchWhitelist{
 | 
	
		
			
				|  |  | +				ProtectBranchID: protectBranch.ID,
 | 
	
		
			
				|  |  | +				RepoID:          repo.ID,
 | 
	
		
			
				|  |  | +				Name:            protectBranch.Name,
 | 
	
		
			
				|  |  | +				UserID:          userID,
 | 
	
		
			
				|  |  | +			})
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	sess := x.NewSession()
 | 
	
		
			
				|  |  | +	defer sessionRelease(sess)
 | 
	
		
			
				|  |  | +	if err = sess.Begin(); err != nil {
 | 
	
		
			
				|  |  | +		return err
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	if protectBranch.ID == 0 {
 | 
	
		
			
				|  |  | -		if _, err = x.Insert(protectBranch); err != nil {
 | 
	
		
			
				|  |  | +		if _, err = sess.Insert(protectBranch); err != nil {
 | 
	
		
			
				|  |  |  			return fmt.Errorf("Insert: %v", err)
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  		return
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	_, err = x.Id(protectBranch.ID).AllCols().Update(protectBranch)
 | 
	
		
			
				|  |  | -	return err
 | 
	
		
			
				|  |  | +	if _, err = sess.Id(protectBranch.ID).AllCols().Update(protectBranch); err != nil {
 | 
	
		
			
				|  |  | +		return fmt.Errorf("Update: %v", err)
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	// Refresh whitelists
 | 
	
		
			
				|  |  | +	if hasUsersChanged || hasTeamsChanged {
 | 
	
		
			
				|  |  | +		if _, err = sess.Delete(&ProtectBranchWhitelist{ProtectBranchID: protectBranch.ID}); err != nil {
 | 
	
		
			
				|  |  | +			return fmt.Errorf("delete old protect branch whitelists: %v", err)
 | 
	
		
			
				|  |  | +		} else if _, err = sess.Insert(whitelists); err != nil {
 | 
	
		
			
				|  |  | +			return fmt.Errorf("insert new protect branch whitelists: %v", err)
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	return sess.Commit()
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  // GetProtectBranchesByRepoID returns a list of *ProtectBranch in given repostiory.
 |