Browse Source

ssh/tailssh: handle not-authenticated-yet connections in matchRule

Also make more fields in conn.info thread safe, there was previously a
data race here.

Fixes #5110

Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 3 years ago
parent
commit
480fd6c797
3 changed files with 75 additions and 14 deletions
  1. 4 1
      ssh/tailssh/incubator.go
  2. 57 13
      ssh/tailssh/tailssh.go
  3. 14 0
      ssh/tailssh/tailssh_test.go

+ 4 - 1
ssh/tailssh/incubator.go

@@ -86,8 +86,11 @@ func (ss *sshSession) newIncubatorCommand() *exec.Cmd {
 		// TODO(maisem): this doesn't work with sftp
 		return exec.CommandContext(ss.ctx, name, args...)
 	}
+	ss.conn.mu.Lock()
 	lu := ss.conn.localUser
 	ci := ss.conn.info
+	gids := strings.Join(ss.conn.userGroupIDs, ",")
+	ss.conn.mu.Unlock()
 	remoteUser := ci.uprof.LoginName
 	if len(ci.node.Tags) > 0 {
 		remoteUser = strings.Join(ci.node.Tags, ",")
@@ -98,7 +101,7 @@ func (ss *sshSession) newIncubatorCommand() *exec.Cmd {
 		"ssh",
 		"--uid=" + lu.Uid,
 		"--gid=" + lu.Gid,
-		"--groups=" + strings.Join(ss.conn.userGroupIDs, ","),
+		"--groups=" + gids,
 		"--local-user=" + lu.Username,
 		"--remote-user=" + remoteUser,
 		"--remote-ip=" + ci.src.IP().String(),

+ 57 - 13
ssh/tailssh/tailssh.go

@@ -141,6 +141,14 @@ func (srv *server) OnPolicyChange() {
 	srv.mu.Lock()
 	defer srv.mu.Unlock()
 	for c := range srv.activeConns {
+		c.mu.Lock()
+		ci := c.info
+		c.mu.Unlock()
+		if ci == nil {
+			// c.info is nil when the connection hasn't been authenticated yet.
+			// In that case, the connection will be terminated when it is.
+			continue
+		}
 		go c.checkStillValid()
 	}
 }
@@ -152,14 +160,14 @@ type conn struct {
 
 	insecureSkipTailscaleAuth bool // used by tests.
 
-	connID       string             // ID that's shared with control
-	action0      *tailcfg.SSHAction // first matching action
-	srv          *server
-	info         *sshConnInfo // set by setInfo
+	connID  string             // ID that's shared with control
+	action0 *tailcfg.SSHAction // first matching action
+	srv     *server
+
+	mu           sync.Mutex   // protects the following
 	localUser    *user.User   // set by checkAuth
 	userGroupIDs []string     // set by checkAuth
-
-	mu sync.Mutex // protects the following
+	info         *sshConnInfo // set by setInfo
 	// idH is the RFC4253 sec8 hash H. It is used to identify the connection,
 	// and is shared among all sessions. It should not be shared outside
 	// process. It is confusingly referred to as SessionID by the gliderlabs/ssh
@@ -179,9 +187,13 @@ func (c *conn) logf(format string, args ...any) {
 // PublicKeyHandler implements ssh.PublicKeyHandler is called by the the
 // ssh.Server when the client presents a public key.
 func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error {
-	if c.info == nil {
+	c.mu.Lock()
+	ci := c.info
+	c.mu.Unlock()
+	if ci == nil {
 		return gossh.ErrDenied
 	}
+
 	if err := c.checkAuth(pubKey); err != nil {
 		// TODO(maisem/bradfitz): surface the error here.
 		c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err)
@@ -217,7 +229,7 @@ func (c *conn) NoClientAuthCallback(cm gossh.ConnMetadata) (*gossh.Permissions,
 func (c *conn) checkAuth(pubKey ssh.PublicKey) error {
 	a, localUser, err := c.evaluatePolicy(pubKey)
 	if err != nil {
-		if pubKey == nil && c.havePubKeyPolicy(c.info) {
+		if pubKey == nil && c.havePubKeyPolicy() {
 			return errPubKeyRequired
 		}
 		return fmt.Errorf("%w: %v", gossh.ErrDenied, err)
@@ -236,6 +248,8 @@ func (c *conn) checkAuth(pubKey ssh.PublicKey) error {
 		if err != nil {
 			return err
 		}
+		c.mu.Lock()
+		defer c.mu.Unlock()
 		c.userGroupIDs = gids
 		c.localUser = lu
 		return nil
@@ -329,7 +343,13 @@ func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, de
 
 // havePubKeyPolicy reports whether any policy rule may provide access by means
 // of a ssh.PublicKey.
-func (c *conn) havePubKeyPolicy(ci *sshConnInfo) bool {
+func (c *conn) havePubKeyPolicy() bool {
+	c.mu.Lock()
+	ci := c.info
+	c.mu.Unlock()
+	if ci == nil {
+		panic("havePubKeyPolicy called before setInfo")
+	}
 	// Is there any rule that looks like it'd require a public key for this
 	// sshUser?
 	pol, ok := c.sshPolicy()
@@ -414,6 +434,8 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error {
 	if !ok {
 		return fmt.Errorf("unknown Tailscale identity from src %v", ci.src)
 	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
 	ci.node = node
 	ci.uprof = &uprof
 
@@ -589,8 +611,10 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
 	}
 
 	ss := c.newSSHSession(s)
+	c.mu.Lock()
 	ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), c.localUser.Username)
 	ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username)
+	c.mu.Unlock()
 	ss.run()
 }
 
@@ -688,7 +712,10 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac
 
 func (c *conn) expandDelegateURL(actionURL string) string {
 	nm := c.srv.lb.NetMap()
+	c.mu.Lock()
 	ci := c.info
+	lu := c.localUser
+	c.mu.Unlock()
 	var dstNodeID string
 	if nm != nil {
 		dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID))
@@ -699,7 +726,7 @@ func (c *conn) expandDelegateURL(actionURL string) string {
 		"$DST_NODE_IP", url.QueryEscape(ci.dst.IP().String()),
 		"$DST_NODE_ID", dstNodeID,
 		"$SSH_USER", url.QueryEscape(ci.sshUser),
-		"$LOCAL_USER", url.QueryEscape(c.localUser.Username),
+		"$LOCAL_USER", url.QueryEscape(lu.Username),
 	).Replace(actionURL)
 }
 
@@ -709,10 +736,12 @@ func (c *conn) expandPublicKeyURL(pubKeyURL string) string {
 	}
 	var localPart string
 	var loginName string
+	c.mu.Lock()
 	if c.info.uprof != nil {
 		loginName = c.info.uprof.LoginName
 		localPart, _, _ = strings.Cut(loginName, "@")
 	}
+	c.mu.Unlock()
 	return strings.NewReplacer(
 		"$LOGINNAME_EMAIL", loginName,
 		"$LOGINNAME_LOCALPART", localPart,
@@ -768,6 +797,8 @@ func (c *conn) isStillValid() bool {
 	if !a.Accept && a.HoldAndDelegate == "" {
 		return false
 	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
 	return c.localUser.Username == localUser
 }
 
@@ -944,6 +975,8 @@ func (ss *sshSession) run() {
 		return
 	}
 	ss.conn.startSessionLocked(ss)
+	lu := ss.conn.localUser
+	localUser := lu.Username
 	srv.mu.Unlock()
 
 	defer ss.conn.endSession(ss)
@@ -959,8 +992,6 @@ func (ss *sshSession) run() {
 	}
 
 	logf := ss.logf
-	lu := ss.conn.localUser
-	localUser := lu.Username
 
 	if euid := os.Geteuid(); euid != 0 {
 		if lu.Uid != fmt.Sprint(euid) {
@@ -1110,9 +1141,20 @@ var (
 	errRuleExpired    = errors.New("rule expired")
 	errPrincipalMatch = errors.New("principal didn't match")
 	errUserMatch      = errors.New("user didn't match")
+	errInvalidConn    = errors.New("invalid connection state")
 )
 
 func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, err error) {
+	if c == nil {
+		return nil, "", errInvalidConn
+	}
+	c.mu.Lock()
+	ci := c.info
+	c.mu.Unlock()
+	if ci == nil {
+		c.logf("invalid connection state")
+		return nil, "", errInvalidConn
+	}
 	if r == nil {
 		return nil, "", errNilRule
 	}
@@ -1126,7 +1168,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg
 		// For all but Reject rules, SSHUsers is required.
 		// If SSHUsers is nil or empty, mapLocalUser will return an
 		// empty string anyway.
-		localUser = mapLocalUser(r.SSHUsers, c.info.sshUser)
+		localUser = mapLocalUser(r.SSHUsers, ci.sshUser)
 		if localUser == "" {
 			return nil, "", errUserMatch
 		}
@@ -1175,7 +1217,9 @@ func (c *conn) principalMatches(p *tailcfg.SSHPrincipal, pubKey gossh.PublicKey)
 // that match the Tailscale identity match (Node, NodeIP, UserLogin, Any).
 // This function does not consider PubKeys.
 func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool {
+	c.mu.Lock()
 	ci := c.info
+	c.mu.Unlock()
 	if p.Any {
 		return true
 	}

+ 14 - 0
ssh/tailssh/tailssh_test.go

@@ -47,13 +47,26 @@ func TestMatchRule(t *testing.T) {
 		wantErr  error
 		wantUser string
 	}{
+		{
+			name: "invalid-conn",
+			rule: &tailcfg.SSHRule{
+				Action:     someAction,
+				Principals: []*tailcfg.SSHPrincipal{{Any: true}},
+				SSHUsers: map[string]string{
+					"*": "ubuntu",
+				},
+			},
+			wantErr: errInvalidConn,
+		},
 		{
 			name:    "nil-rule",
+			ci:      &sshConnInfo{},
 			rule:    nil,
 			wantErr: errNilRule,
 		},
 		{
 			name:    "nil-action",
+			ci:      &sshConnInfo{},
 			rule:    &tailcfg.SSHRule{},
 			wantErr: errNilAction,
 		},
@@ -180,6 +193,7 @@ func TestMatchRule(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			c := &conn{
 				info: tt.ci,
+				srv:  &server{logf: t.Logf},
 			}
 			got, gotUser, err := c.matchRule(tt.rule, nil)
 			if err != tt.wantErr {