Pārlūkot izejas kodu

tailcfg, ssh/tailssh: optionally support SSH public keys in wire policy

And clean up logging.

Updates #3802

Change-Id: I756dc2d579a16757537142283d791f1d0319f4f0
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 4 gadi atpakaļ
vecāks
revīzija
da14e024a8

+ 2 - 2
ssh/tailssh/incubator.go

@@ -330,7 +330,7 @@ func (ss *sshSession) startWithPTY() (ptyFile *os.File, err error) {
 			}
 			k, ok := opcodeShortName[c]
 			if !ok {
-				ss.logf("unknown opcode: %d", c)
+				ss.vlogf("unknown opcode: %d", c)
 				continue
 			}
 			if _, ok := tios.CC[k]; ok {
@@ -341,7 +341,7 @@ func (ss *sshSession) startWithPTY() (ptyFile *os.File, err error) {
 				tios.Opts[k] = v > 0
 				continue
 			}
-			ss.logf("unsupported opcode: %v(%d)=%v", k, c, v)
+			ss.vlogf("unsupported opcode: %v(%d)=%v", k, c, v)
 		}
 
 		// Save PTY settings.

+ 224 - 51
ssh/tailssh/tailssh.go

@@ -9,8 +9,10 @@
 package tailssh
 
 import (
+	"bytes"
 	"context"
 	"crypto/rand"
+	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -40,6 +42,8 @@ import (
 	"tailscale.com/types/logger"
 )
 
+var sshVerboseLogging = envknob.Bool("TS_DEBUG_SSH_VLOG")
+
 // TODO(bradfitz): this is all very temporary as code is temporarily
 // being moved around; it will be restructured and documented in
 // following commits.
@@ -50,6 +54,9 @@ func Handle(logf logger.Logf, lb *ipnlocal.LocalBackend, c net.Conn) error {
 	if err != nil {
 		return err
 	}
+	// TODO(bradfitz): make just one server for the whole process. rearrange
+	// netstack's hooks to be a constructor given a lb instead. Then the *server
+	// will have a HandleTailscaleConn method.
 	srv := &server{
 		lb:             lb,
 		logf:           logf,
@@ -60,6 +67,10 @@ func Handle(logf logger.Logf, lb *ipnlocal.LocalBackend, c net.Conn) error {
 		return err
 	}
 	ss.HandleConn(c)
+
+	// Return nil to signal to netstack's interception that it doesn't need to
+	// log. If ss.HandleConn had problems, it can log itself (ideally on an
+	// sshSession.logf).
 	return nil
 }
 
@@ -77,9 +88,19 @@ func (srv *server) newSSHServer() (*ssh.Server, error) {
 		Version:                     "SSH-2.0-Tailscale",
 		LocalPortForwardingCallback: srv.mayForwardLocalPortTo,
 		NoClientAuthCallback: func(m gossh.ConnMetadata) (*gossh.Permissions, error) {
-			srv.logf("SSH connection from %v for %q; client ver %q", m.RemoteAddr(), m.User(), m.ClientVersion())
+			if srv.requiresPubKey(m.User(), m.LocalAddr(), m.RemoteAddr()) {
+				return nil, errors.New("public key required") // any non-nil error will do
+			}
 			return nil, nil
 		},
+		PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
+			if srv.acceptPubKey(ctx.User(), ctx.LocalAddr(), ctx.RemoteAddr(), key) {
+				srv.logf("accepting SSH public key %s", bytes.TrimSpace(gossh.MarshalAuthorizedKey(key)))
+				return true
+			}
+			srv.logf("rejecting SSH public key %s", bytes.TrimSpace(gossh.MarshalAuthorizedKey(key)))
+			return false
+		},
 	}
 	for k, v := range ssh.DefaultRequestHandlers {
 		ss.RequestHandlers[k] = v
@@ -111,7 +132,10 @@ type server struct {
 	activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session
 }
 
-var debugPolicyFile = envknob.String("TS_DEBUG_SSH_POLICY_FILE")
+var (
+	debugPolicyFile             = envknob.String("TS_DEBUG_SSH_POLICY_FILE")
+	debugIgnoreTailnetSSHPolicy = envknob.Bool("TS_DEBUG_SSH_IGNORE_TAILNET_POLICY")
+)
 
 // mayForwardLocalPortTo reports whether the ctx should be allowed to port forward
 // to the specified host and port.
@@ -124,6 +148,52 @@ func (srv *server) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string
 	return ss.action.AllowLocalPortForwarding
 }
 
+// requiresPubKey reports whether the SSH server, during the auth negotiation
+// phase, should requires that the client send an SSH public key. (or, more
+// specifically, that "none" auth isn't acceptable)
+func (srv *server) requiresPubKey(sshUser string, localAddr, remoteAddr net.Addr) bool {
+	pol, ok := srv.sshPolicy()
+	if !ok {
+		return false
+	}
+	a, ci, _, err := srv.evaluatePolicy(sshUser, localAddr, remoteAddr, nil)
+	if err == nil && (a.Accept || a.HoldAndDelegate != "") {
+		// Policy doesn't require a public key.
+		return false
+	}
+	if ci == nil {
+		// If we didn't get far enough along through evaluatePolicy to know the Tailscale
+		// identify of the remote side then it's going to fail quickly later anyway.
+		// Return false to accept "none" auth and reject the conn.
+		return false
+	}
+
+	// Is there any rule that looks like it'd require a public key for this
+	// sshUser?
+	for _, r := range pol.Rules {
+		if ci.ruleExpired(r) {
+			continue
+		}
+		if mapLocalUser(r.SSHUsers, sshUser) == "" {
+			continue
+		}
+		for _, p := range r.Principals {
+			if principalMatchesTailscaleIdentity(p, ci) && len(p.PubKeys) > 0 {
+				return true
+			}
+		}
+	}
+	return false
+}
+
+func (srv *server) acceptPubKey(sshUser string, localAddr, remoteAddr net.Addr, pubKey ssh.PublicKey) bool {
+	a, _, _, err := srv.evaluatePolicy(sshUser, localAddr, remoteAddr, pubKey)
+	if err != nil {
+		return false
+	}
+	return a.Accept || a.HoldAndDelegate != ""
+}
+
 // sshPolicy returns the SSHPolicy for current node.
 // If there is no SSHPolicy in the netmap, it returns a debugPolicy
 // if one is defined.
@@ -133,7 +203,7 @@ func (srv *server) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) {
 	if nm == nil {
 		return nil, false
 	}
-	if pol := nm.SSHPolicy; pol != nil {
+	if pol := nm.SSHPolicy; pol != nil && !debugIgnoreTailnetSSHPolicy {
 		return pol, true
 	}
 	if debugPolicyFile != "" {
@@ -167,19 +237,17 @@ func asTailscaleIPPort(a net.Addr) (netaddr.IPPort, error) {
 	return netaddr.IPPortFrom(tanetaddr, uint16(ta.Port)), nil
 }
 
-// evaluatePolicy returns the SSHAction, sshConnInfo and localUser
-// after evaluating the sshUser and remoteAddr against the SSHPolicy.
-// The remoteAddr and localAddr params must be Tailscale IPs.
-func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr) (_ *tailcfg.SSHAction, _ *sshConnInfo, localUser string, _ error) {
-	logf := srv.logf
-	lb := srv.lb
-	logf("Handling SSH from %v for user %v", remoteAddr, sshUser)
-
+// evaluatePolicy returns the SSHAction, sshConnInfo and localUser after
+// evaluating the sshUser and remoteAddr against the SSHPolicy. The remoteAddr
+// and localAddr params must be Tailscale IPs.
+//
+// The return sshConnInfo will be non-nil, even on some errors, if the
+// evaluation made it far enough to resolve the remoteAddr to a Tailscale IP.
+func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr, pubKey ssh.PublicKey) (_ *tailcfg.SSHAction, _ *sshConnInfo, localUser string, _ error) {
 	pol, ok := srv.sshPolicy()
 	if !ok {
 		return nil, nil, "", fmt.Errorf("tsshd: rejecting connection; no SSH policy")
 	}
-
 	srcIPP, err := asTailscaleIPPort(remoteAddr)
 	if err != nil {
 		return nil, nil, "", fmt.Errorf("tsshd: rejecting: %w", err)
@@ -188,59 +256,86 @@ func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr
 	if err != nil {
 		return nil, nil, "", err
 	}
-	node, uprof, ok := lb.WhoIs(srcIPP)
+	node, uprof, ok := srv.lb.WhoIs(srcIPP)
 	if !ok {
-		return nil, nil, "", fmt.Errorf("Hello, %v. I don't know who you are.\n", srcIPP)
+		return nil, nil, "", fmt.Errorf("unknown Tailscale identity from src %v", srcIPP)
 	}
-
 	ci := &sshConnInfo{
-		now:     time.Now(),
-		sshUser: sshUser,
-		src:     srcIPP,
-		dst:     dstIPP,
-		node:    node,
-		uprof:   &uprof,
+		now:                time.Now(),
+		fetchPublicKeysURL: srv.fetchPublicKeysURL,
+		sshUser:            sshUser,
+		src:                srcIPP,
+		dst:                dstIPP,
+		node:               node,
+		uprof:              &uprof,
+		pubKey:             pubKey,
 	}
 	a, localUser, ok := evalSSHPolicy(pol, ci)
 	if !ok {
-		return nil, nil, "", fmt.Errorf("ssh: access denied for %q from %v", uprof.LoginName, ci.src.IP())
+		return nil, ci, "", fmt.Errorf("ssh: access denied for %q from %v", uprof.LoginName, ci.src.IP())
 	}
 	return a, ci, localUser, nil
 }
 
+func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
+	if !strings.HasPrefix(url, "https://") {
+		return nil, errors.New("invalid URL scheme")
+	}
+	// TODO(bradfitz): add caching
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+	if err != nil {
+		return nil, err
+	}
+	res, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+	if res.StatusCode != http.StatusOK {
+		return nil, errors.New(res.Status)
+	}
+	all, err := io.ReadAll(io.LimitReader(res.Body, 4<<10))
+	return strings.Split(string(all), "\n"), err
+}
+
 // handleSSH is invoked when a new SSH connection attempt is made.
 func (srv *server) handleSSH(s ssh.Session) {
 	logf := srv.logf
 
 	sshUser := s.User()
-	action, ci, localUser, err := srv.evaluatePolicy(sshUser, s.LocalAddr(), s.RemoteAddr())
+	action, ci, localUser, err := srv.evaluatePolicy(sshUser, s.LocalAddr(), s.RemoteAddr(), s.PublicKey())
 	if err != nil {
 		logf(err.Error())
 		s.Exit(1)
 		return
 	}
-
-	lu, err := user.Lookup(localUser)
-	if err != nil {
-		logf("ssh: user Lookup %q: %v", localUser, err)
-		s.Exit(1)
-		return
+	var lu *user.User
+	if localUser != "" {
+		lu, err = user.Lookup(localUser)
+		if err != nil {
+			logf("ssh: user Lookup %q: %v", localUser, err)
+			s.Exit(1)
+			return
+		}
 	}
-
 	ss := srv.newSSHSession(s, ci, lu)
+	ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", ci.uprof.LoginName, ci.src.IP(), sshUser)
 	action, err = ss.resolveTerminalAction(action)
 	if err != nil {
-		logf("ssh: resolveTerminalAction: %v", err)
+		ss.logf("resolveTerminalAction: %v", err)
 		io.WriteString(s.Stderr(), "Access denied: failed to resolve SSHAction.\n")
 		s.Exit(1)
 		return
 	}
 	if action.Reject || !action.Accept {
-		logf("ssh: access denied for %q from %v", ci.uprof.LoginName, ci.src.IP())
+		ss.logf("access denied for %v (%v)", ci.uprof.LoginName, ci.src.IP())
 		s.Exit(1)
 		return
 	}
-
+	ss.logf("access granted for %v (%v) to ssh-user %q", ci.uprof.LoginName, ci.src.IP(), sshUser)
 	ss.action = action
 	ss.run()
 }
@@ -320,6 +415,12 @@ type sshSession struct {
 	exitOnce sync.Once
 }
 
+func (ss *sshSession) vlogf(format string, args ...interface{}) {
+	if sshVerboseLogging {
+		ss.logf(format, args...)
+	}
+}
+
 func (srv *server) newSSHSession(s ssh.Session, ci *sshConnInfo, lu *user.User) *sshSession {
 	sharedID := fmt.Sprintf("%s-%02x", ci.now.UTC().Format("20060102T150405"), randBytes(5))
 	return &sshSession{
@@ -510,7 +611,7 @@ func (ss *sshSession) run() {
 
 	if euid := os.Geteuid(); euid != 0 {
 		if lu.Uid != fmt.Sprint(euid) {
-			logf("ssh: can't switch to user %q from process euid %v", localUser, euid)
+			ss.logf("can't switch to user %q from process euid %v", localUser, euid)
 			fmt.Fprintf(ss, "can't switch user\n")
 			ss.Exit(1)
 			return
@@ -522,7 +623,7 @@ func (ss *sshSession) run() {
 	ss.DisablePTYEmulation()
 
 	if err := ss.handleSSHAgentForwarding(ss, lu); err != nil {
-		logf("ssh: agent forwarding failed: %v", err)
+		ss.logf("agent forwarding failed: %v", err)
 	} else if ss.agentListener != nil {
 		// TODO(maisem/bradfitz): add a way to close all session resources
 		defer ss.agentListener.Close()
@@ -534,7 +635,7 @@ func (ss *sshSession) run() {
 		rec, err = ss.startNewRecording()
 		if err != nil {
 			fmt.Fprintf(ss, "can't start new recording\n")
-			logf("startNewRecording: %v", err)
+			ss.logf("startNewRecording: %v", err)
 			ss.Exit(1)
 			return
 		}
@@ -581,18 +682,18 @@ func (ss *sshSession) run() {
 	ss.exitOnce.Do(func() {})
 
 	if err == nil {
-		logf("ssh: Wait: ok")
+		ss.logf("Wait: ok")
 		ss.Exit(0)
 		return
 	}
 	if ee, ok := err.(*exec.ExitError); ok {
 		code := ee.ProcessState.ExitCode()
-		logf("ssh: Wait: code=%v", code)
+		ss.logf("Wait: code=%v", code)
 		ss.Exit(code)
 		return
 	}
 
-	logf("ssh: Wait: %v", err)
+	ss.logf("Wait: %v", err)
 	ss.Exit(1)
 	return
 }
@@ -609,6 +710,10 @@ type sshConnInfo struct {
 	// now is the time to consider the present moment for the
 	// purposes of rule evaluation.
 	now time.Time
+	// fetchPublicKeysURL, if non-nil, is a func to fetch the public
+	// keys of a URL. The strings are in the the typical public
+	// key "type base64-string [comment]" format seen at e.g. https://github.com/USER.keys
+	fetchPublicKeysURL func(url string) ([]string, error)
 
 	// sshUser is the requested local SSH username ("root", "alice", etc).
 	sshUser string
@@ -624,6 +729,18 @@ type sshConnInfo struct {
 
 	// uprof is node's UserProfile.
 	uprof *tailcfg.UserProfile
+
+	// pubKey is the public key presented by the client, or nil
+	// if they haven't yet sent one (as in the early "none" phase
+	// of authentication negotiation).
+	pubKey ssh.PublicKey
+}
+
+func (ci *sshConnInfo) ruleExpired(r *tailcfg.SSHRule) bool {
+	if r.RuleExpires == nil {
+		return false
+	}
+	return r.RuleExpires.Before(ci.now)
 }
 
 func evalSSHPolicy(pol *tailcfg.SSHPolicy, ci *sshConnInfo) (a *tailcfg.SSHAction, localUser string, ok bool) {
@@ -651,18 +768,18 @@ func matchRule(r *tailcfg.SSHRule, ci *sshConnInfo) (a *tailcfg.SSHAction, local
 	if r.Action == nil {
 		return nil, "", errNilAction
 	}
-	if r.RuleExpires != nil && ci.now.After(*r.RuleExpires) {
+	if ci.ruleExpired(r) {
 		return nil, "", errRuleExpired
 	}
-	if !matchesPrincipal(r.Principals, ci) {
-		return nil, "", errPrincipalMatch
-	}
 	if !r.Action.Reject || r.SSHUsers != nil {
 		localUser = mapLocalUser(r.SSHUsers, ci.sshUser)
 		if localUser == "" {
 			return nil, "", errUserMatch
 		}
 	}
+	if !anyPrincipalMatches(r.Principals, ci) {
+		return nil, "", errPrincipalMatch
+	}
 	return r.Action, localUser, nil
 }
 
@@ -677,29 +794,85 @@ func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) (localUser
 	return v
 }
 
-func matchesPrincipal(ps []*tailcfg.SSHPrincipal, ci *sshConnInfo) bool {
+func anyPrincipalMatches(ps []*tailcfg.SSHPrincipal, ci *sshConnInfo) bool {
 	for _, p := range ps {
 		if p == nil {
 			continue
 		}
-		if p.Any {
+		if principalMatches(p, ci) {
 			return true
 		}
-		if !p.Node.IsZero() && ci.node != nil && p.Node == ci.node.StableID {
+	}
+	return false
+}
+
+func principalMatches(p *tailcfg.SSHPrincipal, ci *sshConnInfo) bool {
+	return principalMatchesTailscaleIdentity(p, ci) &&
+		principalMatchesPubKey(p, ci)
+}
+
+// principalMatchesTailscaleIdentity reports whether one of p's four fields
+// that match the Tailscale identity match (Node, NodeIP, UserLogin, Any).
+// This function does not consider PubKeys.
+func principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal, ci *sshConnInfo) bool {
+	if p.Any {
+		return true
+	}
+	if !p.Node.IsZero() && ci.node != nil && p.Node == ci.node.StableID {
+		return true
+	}
+	if p.NodeIP != "" {
+		if ip, _ := netaddr.ParseIP(p.NodeIP); ip == ci.src.IP() {
 			return true
 		}
-		if p.NodeIP != "" {
-			if ip, _ := netaddr.ParseIP(p.NodeIP); ip == ci.src.IP() {
-				return true
-			}
+	}
+	if p.UserLogin != "" && ci.uprof != nil && ci.uprof.LoginName == p.UserLogin {
+		return true
+	}
+	return false
+}
+
+func principalMatchesPubKey(p *tailcfg.SSHPrincipal, ci *sshConnInfo) bool {
+	if len(p.PubKeys) == 0 {
+		return true
+	}
+	if ci.pubKey == nil {
+		return false
+	}
+	pubKeys := p.PubKeys
+	if len(pubKeys) == 1 && strings.HasPrefix(pubKeys[0], "https://") {
+		if ci.fetchPublicKeysURL == nil {
+			// TODO: log?
+			return false
+		}
+		var err error
+		pubKeys, err = ci.fetchPublicKeysURL(pubKeys[0])
+		if err != nil {
+			// TODO: log?
+			return false
 		}
-		if p.UserLogin != "" && ci.uprof != nil && ci.uprof.LoginName == p.UserLogin {
+	}
+	for _, pubKey := range pubKeys {
+		if pubKeyMatchesAuthorizedKey(ci.pubKey, pubKey) {
 			return true
 		}
 	}
 	return false
 }
 
+func pubKeyMatchesAuthorizedKey(pubKey ssh.PublicKey, wantKey string) bool {
+	wantKeyType, rest, ok := strings.Cut(wantKey, " ")
+	if !ok {
+		return false
+	}
+	if pubKey.Type() != wantKeyType {
+		return false
+	}
+	wantKeyB64, _, _ := strings.Cut(rest, " ")
+	wantKeyData, _ := base64.StdEncoding.DecodeString(wantKeyB64)
+	return len(wantKeyData) > 0 && bytes.Equal(pubKey.Marshal(), wantKeyData)
+}
+
 func randBytes(n int) []byte {
 	b := make([]byte, n)
 	if _, err := rand.Read(b); err != nil {

+ 4 - 1
ssh/tailssh/tailssh_test.go

@@ -63,7 +63,10 @@ func TestMatchRule(t *testing.T) {
 			name: "no-principal",
 			rule: &tailcfg.SSHRule{
 				Action: someAction,
-			},
+				SSHUsers: map[string]string{
+					"*": "ubuntu",
+				}},
+			ci:      &sshConnInfo{},
 			wantErr: errPrincipalMatch,
 		},
 		{

+ 11 - 5
tailcfg/tailcfg.go

@@ -1593,16 +1593,22 @@ type SSHRule struct {
 }
 
 // SSHPrincipal is either a particular node or a user on any node.
-// Any matching field causes a match.
 type SSHPrincipal struct {
+	// Matching any one of the following four field causes a match.
+	// It must also match Certs, if non-empty.
+
 	Node      StableNodeID `json:"node,omitempty"`
 	NodeIP    string       `json:"nodeIP,omitempty"`
 	UserLogin string       `json:"userLogin,omitempty"` // email-ish: [email protected], bar@github
-
-	// Any, if true, matches any user.
-	Any bool `json:"any,omitempty"`
-
+	Any       bool         `json:"any,omitempty"`       // if true, match any connection
 	// TODO(bradfitz): add StableUserID, once that exists
+
+	// PubKeys, if non-empty, means that this SSHPrincipal only
+	// matches if one of these public keys is presented by the user.
+	//
+	// As a special case, if len(PubKeys) == 1 and PubKeys[0] starts
+	// with "https://", then it's fetched (like https://github.com/username.keys).
+	PubKeys []string `json:"pubKeys,omitempty"`
 }
 
 // SSHAction is how to handle an incoming connection.

+ 0 - 2
wgengine/netstack/netstack.go

@@ -666,8 +666,6 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
 			ns.logf("handling SSH connection....")
 			if err := handleSSH(ns.logf, ns.lb, c); err != nil {
 				ns.logf("ssh error: %v", err)
-			} else {
-				ns.logf("ssh: ok")
 			}
 			return
 		}