فهرست منبع

ssh/tailssh: add TestSSHAuthFlow

Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 3 سال پیش
والد
کامیت
ecf6cdd830
3فایلهای تغییر یافته به همراه299 افزوده شده و 15 حذف شده
  1. 20 1
      net/nettest/conn.go
  2. 22 11
      ssh/tailssh/tailssh.go
  3. 257 3
      ssh/tailssh/tailssh_test.go

+ 20 - 1
net/nettest/conn.go

@@ -6,6 +6,7 @@ package nettest
 
 import (
 	"net"
+	"net/netip"
 	"time"
 )
 
@@ -32,20 +33,38 @@ func NewConn(name string, maxBuf int) (Conn, Conn) {
 	return &connHalf{r: r, w: w}, &connHalf{r: w, w: r}
 }
 
+// NewTCPConn creates a pair of Conns that are wired together by pipes.
+func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) {
+	r := NewPipe(src.String(), maxBuf)
+	w := NewPipe(dst.String(), maxBuf)
+
+	lAddr := net.TCPAddrFromAddrPort(src)
+	rAddr := net.TCPAddrFromAddrPort(dst)
+
+	return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr}
+}
+
 type connAddr string
 
 func (a connAddr) Network() string { return "mem" }
 func (a connAddr) String() string  { return string(a) }
 
 type connHalf struct {
-	r, w *Pipe
+	local, remote net.Addr
+	r, w          *Pipe
 }
 
 func (c *connHalf) LocalAddr() net.Addr {
+	if c.local != nil {
+		return c.local
+	}
 	return connAddr(c.r.name)
 }
 
 func (c *connHalf) RemoteAddr() net.Addr {
+	if c.remote != nil {
+		return c.remote
+	}
 	return connAddr(c.w.name)
 }
 

+ 22 - 11
ssh/tailssh/tailssh.go

@@ -39,6 +39,7 @@ import (
 	"tailscale.com/tailcfg"
 	"tailscale.com/tempfork/gliderlabs/ssh"
 	"tailscale.com/types/logger"
+	"tailscale.com/types/netmap"
 	"tailscale.com/util/clientmetric"
 	"tailscale.com/util/mak"
 )
@@ -47,8 +48,19 @@ var (
 	sshVerboseLogging = envknob.RegisterBool("TS_DEBUG_SSH_VLOG")
 )
 
+// ipnLocalBackend is the subset of ipnlocal.LocalBackend that we use.
+// It is used for testing.
+type ipnLocalBackend interface {
+	GetSSH_HostKeys() ([]gossh.Signer, error)
+	ShouldRunSSH() bool
+	NetMap() *netmap.NetworkMap
+	WhoIs(ipp netip.AddrPort) (n *tailcfg.Node, u tailcfg.UserProfile, ok bool)
+	DoNoiseRequest(req *http.Request) (*http.Response, error)
+	TailscaleVarRoot() string
+}
+
 type server struct {
-	lb             *ipnlocal.LocalBackend
+	lb             ipnLocalBackend
 	logf           logger.Logf
 	tailscaledPath string
 
@@ -212,7 +224,10 @@ func (c *conn) logf(format string, args ...any) {
 	c.srv.logf(format, args...)
 }
 
-// isAuthorized returns nil if the connection is authorized to proceed.
+// isAuthorized walks through the action chain and returns nil if the connection
+// is authorized. If the connection is not authorized, it returns
+// gossh.ErrDenied. If the action chain resolution fails, it returns the
+// resolution error.
 func (c *conn) isAuthorized(ctx ssh.Context) error {
 	action := c.currentAction
 	for {
@@ -525,7 +540,7 @@ func (c *conn) setInfo(ctx ssh.Context) error {
 		return fmt.Errorf("unknown Tailscale identity from src %v", ci.src)
 	}
 	ci.node = node
-	ci.uprof = &uprof
+	ci.uprof = uprof
 
 	c.idH = ctx.SessionID()
 	c.info = ci
@@ -743,12 +758,8 @@ func (c *conn) expandPublicKeyURL(pubKeyURL string) string {
 	if !strings.Contains(pubKeyURL, "$") {
 		return pubKeyURL
 	}
-	var localPart string
-	var loginName string
-	if c.info.uprof != nil {
-		loginName = c.info.uprof.LoginName
-		localPart, _, _ = strings.Cut(loginName, "@")
-	}
+	loginName := c.info.uprof.LoginName
+	localPart, _, _ := strings.Cut(loginName, "@")
 	return strings.NewReplacer(
 		"$LOGINNAME_EMAIL", loginName,
 		"$LOGINNAME_LOCALPART", localPart,
@@ -1108,7 +1119,7 @@ type sshConnInfo struct {
 	node *tailcfg.Node
 
 	// uprof is node's UserProfile.
-	uprof *tailcfg.UserProfile
+	uprof tailcfg.UserProfile
 }
 
 func (ci *sshConnInfo) String() string {
@@ -1223,7 +1234,7 @@ func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool {
 			return true
 		}
 	}
-	if p.UserLogin != "" && ci.uprof != nil && ci.uprof.LoginName == p.UserLogin {
+	if p.UserLogin != "" && ci.uprof.LoginName == p.UserLogin {
 		return true
 	}
 	return false

+ 257 - 3
ssh/tailssh/tailssh_test.go

@@ -9,7 +9,10 @@ package tailssh
 
 import (
 	"bytes"
+	"crypto/ed25519"
+	"crypto/rand"
 	"crypto/sha256"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -21,20 +24,27 @@ import (
 	"os/exec"
 	"os/user"
 	"reflect"
+	"runtime"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
 
+	gossh "github.com/tailscale/golang-x-crypto/ssh"
 	"tailscale.com/ipn/ipnlocal"
 	"tailscale.com/ipn/store/mem"
+	"tailscale.com/net/nettest"
 	"tailscale.com/net/tsdial"
 	"tailscale.com/tailcfg"
 	"tailscale.com/tempfork/gliderlabs/ssh"
 	"tailscale.com/tstest"
 	"tailscale.com/types/logger"
+	"tailscale.com/types/netmap"
 	"tailscale.com/util/cibuild"
 	"tailscale.com/util/lineread"
+	"tailscale.com/util/must"
+	"tailscale.com/util/strs"
 	"tailscale.com/wgengine"
 )
 
@@ -173,7 +183,7 @@ func TestMatchRule(t *testing.T) {
 				Principals: []*tailcfg.SSHPrincipal{{UserLogin: "[email protected]"}},
 				SSHUsers:   map[string]string{"*": "ubuntu"},
 			},
-			ci:       &sshConnInfo{uprof: &tailcfg.UserProfile{LoginName: "[email protected]"}},
+			ci:       &sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "[email protected]"}},
 			wantUser: "ubuntu",
 		},
 		{
@@ -211,6 +221,250 @@ func TestMatchRule(t *testing.T) {
 
 func timePtr(t time.Time) *time.Time { return &t }
 
+// localState implements ipnLocalBackend for testing.
+type localState struct {
+	sshEnabled   bool
+	matchingRule *tailcfg.SSHRule
+
+	// serverActions is a map of the action name to the action.
+	// It is served for paths like https://unused/ssh-action/<action-name>.
+	// The action name is the last part of the action URL.
+	serverActions map[string]*tailcfg.SSHAction
+}
+
+var (
+	currentUser    = os.Getenv("USER") // Use the current user for the test.
+	testSigner     gossh.Signer
+	testSignerOnce sync.Once
+)
+
+func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) {
+	testSignerOnce.Do(func() {
+		_, priv, err := ed25519.GenerateKey(rand.Reader)
+		if err != nil {
+			panic(err)
+		}
+		s, err := gossh.NewSignerFromSigner(priv)
+		if err != nil {
+			panic(err)
+		}
+		testSigner = s
+	})
+	return []gossh.Signer{testSigner}, nil
+}
+
+func (ts *localState) ShouldRunSSH() bool {
+	return ts.sshEnabled
+}
+
+func (ts *localState) NetMap() *netmap.NetworkMap {
+	var policy *tailcfg.SSHPolicy
+	if ts.matchingRule != nil {
+		policy = &tailcfg.SSHPolicy{
+			Rules: []*tailcfg.SSHRule{
+				ts.matchingRule,
+			},
+		}
+	}
+
+	return &netmap.NetworkMap{
+		SelfNode: &tailcfg.Node{
+			ID: 1,
+		},
+		SSHPolicy: policy,
+	}
+}
+
+func (ts *localState) WhoIs(ipp netip.AddrPort) (n *tailcfg.Node, u tailcfg.UserProfile, ok bool) {
+	return &tailcfg.Node{
+			ID:       2,
+			StableID: "peer-id",
+		}, tailcfg.UserProfile{
+			LoginName: "peer",
+		}, true
+
+}
+
+func (ts *localState) DoNoiseRequest(req *http.Request) (*http.Response, error) {
+	rec := httptest.NewRecorder()
+	k, ok := strs.CutPrefix(req.URL.Path, "/ssh-action/")
+	if !ok {
+		rec.WriteHeader(http.StatusNotFound)
+	}
+	a, ok := ts.serverActions[k]
+	if !ok {
+		rec.WriteHeader(http.StatusNotFound)
+		return rec.Result(), nil
+	}
+	rec.WriteHeader(http.StatusOK)
+	if err := json.NewEncoder(rec).Encode(a); err != nil {
+		return nil, err
+	}
+	return rec.Result(), nil
+}
+
+func (ts *localState) TailscaleVarRoot() string {
+	return ""
+}
+
+func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule {
+	return &tailcfg.SSHRule{
+		SSHUsers: map[string]string{
+			"*": currentUser,
+		},
+		Action: action,
+		Principals: []*tailcfg.SSHPrincipal{
+			{
+				Any: true,
+			},
+		},
+	}
+}
+
+func TestSSHAuthFlow(t *testing.T) {
+	if runtime.GOOS != "linux" {
+		t.Skip("Not running on Linux, skipping")
+	}
+	acceptRule := newSSHRule(&tailcfg.SSHAction{
+		Accept:  true,
+		Message: "Welcome to Tailscale SSH!",
+	})
+	rejectRule := newSSHRule(&tailcfg.SSHAction{
+		Reject:  true,
+		Message: "Go Away!",
+	})
+
+	tests := []struct {
+		name       string
+		state      *localState
+		wantBanner string
+		authErr    bool
+	}{
+		{
+			name: "no-policy",
+			state: &localState{
+				sshEnabled: true,
+			},
+			authErr: true,
+		},
+		{
+			name: "accept",
+			state: &localState{
+				sshEnabled:   true,
+				matchingRule: acceptRule,
+			},
+			wantBanner: "Welcome to Tailscale SSH!",
+		},
+		{
+			name: "reject",
+			state: &localState{
+				sshEnabled:   true,
+				matchingRule: rejectRule,
+			},
+			wantBanner: "Go Away!",
+			authErr:    true,
+		},
+		{
+			name: "simple-check",
+			state: &localState{
+				sshEnabled: true,
+				matchingRule: newSSHRule(&tailcfg.SSHAction{
+					HoldAndDelegate: "https://unused/ssh-action/accept",
+				}),
+				serverActions: map[string]*tailcfg.SSHAction{
+					"accept": acceptRule.Action,
+				},
+			},
+			wantBanner: "Welcome to Tailscale SSH!",
+		},
+		{
+			name: "multi-check",
+			state: &localState{
+				sshEnabled: true,
+				matchingRule: newSSHRule(&tailcfg.SSHAction{
+					HoldAndDelegate: "https://unused/ssh-action/check1",
+				}),
+				serverActions: map[string]*tailcfg.SSHAction{
+					"check1": {
+						Message:         "url-here",
+						HoldAndDelegate: "https://unused/ssh-action/check2",
+					},
+					"check2": acceptRule.Action,
+				},
+			},
+			wantBanner: "url-here",
+		},
+		{
+			name: "check-reject",
+			state: &localState{
+				sshEnabled: true,
+				matchingRule: newSSHRule(&tailcfg.SSHAction{
+					HoldAndDelegate: "https://unused/ssh-action/reject",
+				}),
+				serverActions: map[string]*tailcfg.SSHAction{
+					"reject": rejectRule.Action,
+				},
+			},
+			wantBanner: "Go Away!",
+			authErr:    true,
+		},
+	}
+	s := &server{
+		logf: logger.Discard,
+	}
+	defer s.Shutdown()
+	src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			sc, dc := nettest.NewTCPConn(src, dst, 1024)
+			s.lb = tc.state
+			cfg := &gossh.ClientConfig{
+				User:            "alice",
+				HostKeyCallback: gossh.InsecureIgnoreHostKey(),
+				BannerCallback: func(message string) error {
+					if message != tc.wantBanner {
+						t.Errorf("BannerCallback = %q; want %q", message, tc.wantBanner)
+					}
+					return nil
+				},
+			}
+			var wg sync.WaitGroup
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
+				if err != nil {
+					if !tc.authErr {
+						t.Errorf("client: %v", err)
+					}
+					return
+				} else if tc.authErr {
+					c.Close()
+					t.Errorf("client: expected error, got nil")
+					return
+				}
+				client := gossh.NewClient(c, chans, reqs)
+				defer client.Close()
+				session, err := client.NewSession()
+				if err != nil {
+					t.Errorf("client: %v", err)
+					return
+				}
+				defer session.Close()
+				o, err := session.CombinedOutput("echo Ran echo!")
+				if err != nil {
+					t.Errorf("client: %v", err)
+				}
+				t.Logf("output: %s", o)
+			}()
+			if err := s.HandleSSHConn(dc); err != nil {
+				t.Errorf("unexpected error: %v", err)
+			}
+			wg.Wait()
+		})
+	}
+}
+
 func TestSSH(t *testing.T) {
 	var logf logger.Logf = t.Logf
 	eng, err := wgengine.NewFakeUserspaceEngine(logf, 0)
@@ -249,7 +503,7 @@ func TestSSH(t *testing.T) {
 		src:     netip.MustParseAddrPort("1.2.3.4:32342"),
 		dst:     netip.MustParseAddrPort("1.2.3.5:22"),
 		node:    &tailcfg.Node{},
-		uprof:   &tailcfg.UserProfile{},
+		uprof:   tailcfg.UserProfile{},
 	}
 	sc.finalAction = &tailcfg.SSHAction{Accept: true}
 
@@ -428,7 +682,7 @@ func TestPublicKeyFetching(t *testing.T) {
 func TestExpandPublicKeyURL(t *testing.T) {
 	c := &conn{
 		info: &sshConnInfo{
-			uprof: &tailcfg.UserProfile{
+			uprof: tailcfg.UserProfile{
 				LoginName: "[email protected]",
 			},
 		},