|
|
@@ -8,12 +8,15 @@ package tailssh
|
|
|
import (
|
|
|
"bytes"
|
|
|
"context"
|
|
|
+ "crypto/ecdsa"
|
|
|
"crypto/ed25519"
|
|
|
+ "crypto/elliptic"
|
|
|
"crypto/rand"
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
+ "log"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
|
@@ -41,7 +44,7 @@ import (
|
|
|
"tailscale.com/sessionrecording"
|
|
|
"tailscale.com/tailcfg"
|
|
|
"tailscale.com/tempfork/gliderlabs/ssh"
|
|
|
- sshtest "tailscale.com/tempfork/sshtest/ssh"
|
|
|
+ testssh "tailscale.com/tempfork/sshtest/ssh"
|
|
|
"tailscale.com/tsd"
|
|
|
"tailscale.com/tstest"
|
|
|
"tailscale.com/types/key"
|
|
|
@@ -56,8 +59,6 @@ import (
|
|
|
"tailscale.com/wgengine"
|
|
|
)
|
|
|
|
|
|
-type _ = sshtest.Client // TODO(bradfitz,percy): sshtest; delete this line
|
|
|
-
|
|
|
func TestMatchRule(t *testing.T) {
|
|
|
someAction := new(tailcfg.SSHAction)
|
|
|
tests := []struct {
|
|
|
@@ -510,9 +511,9 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
|
|
defer s.Shutdown()
|
|
|
|
|
|
const sshUser = "alice"
|
|
|
- cfg := &gossh.ClientConfig{
|
|
|
+ cfg := &testssh.ClientConfig{
|
|
|
User: sshUser,
|
|
|
- HostKeyCallback: gossh.InsecureIgnoreHostKey(),
|
|
|
+ HostKeyCallback: testssh.InsecureIgnoreHostKey(),
|
|
|
}
|
|
|
|
|
|
tests := []struct {
|
|
|
@@ -559,12 +560,12 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
|
|
wg.Add(1)
|
|
|
go func() {
|
|
|
defer wg.Done()
|
|
|
- c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
|
|
|
+ c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
|
|
|
if err != nil {
|
|
|
t.Errorf("client: %v", err)
|
|
|
return
|
|
|
}
|
|
|
- client := gossh.NewClient(c, chans, reqs)
|
|
|
+ client := testssh.NewClient(c, chans, reqs)
|
|
|
defer client.Close()
|
|
|
session, err := client.NewSession()
|
|
|
if err != nil {
|
|
|
@@ -645,21 +646,21 @@ func TestMultipleRecorders(t *testing.T) {
|
|
|
sc, dc := memnet.NewTCPConn(src, dst, 1024)
|
|
|
|
|
|
const sshUser = "alice"
|
|
|
- cfg := &gossh.ClientConfig{
|
|
|
+ cfg := &testssh.ClientConfig{
|
|
|
User: sshUser,
|
|
|
- HostKeyCallback: gossh.InsecureIgnoreHostKey(),
|
|
|
+ HostKeyCallback: testssh.InsecureIgnoreHostKey(),
|
|
|
}
|
|
|
|
|
|
var wg sync.WaitGroup
|
|
|
wg.Add(1)
|
|
|
go func() {
|
|
|
defer wg.Done()
|
|
|
- c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
|
|
|
+ c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
|
|
|
if err != nil {
|
|
|
t.Errorf("client: %v", err)
|
|
|
return
|
|
|
}
|
|
|
- client := gossh.NewClient(c, chans, reqs)
|
|
|
+ client := testssh.NewClient(c, chans, reqs)
|
|
|
defer client.Close()
|
|
|
session, err := client.NewSession()
|
|
|
if err != nil {
|
|
|
@@ -736,21 +737,21 @@ func TestSSHRecordingNonInteractive(t *testing.T) {
|
|
|
sc, dc := memnet.NewTCPConn(src, dst, 1024)
|
|
|
|
|
|
const sshUser = "alice"
|
|
|
- cfg := &gossh.ClientConfig{
|
|
|
+ cfg := &testssh.ClientConfig{
|
|
|
User: sshUser,
|
|
|
- HostKeyCallback: gossh.InsecureIgnoreHostKey(),
|
|
|
+ HostKeyCallback: testssh.InsecureIgnoreHostKey(),
|
|
|
}
|
|
|
|
|
|
var wg sync.WaitGroup
|
|
|
wg.Add(1)
|
|
|
go func() {
|
|
|
defer wg.Done()
|
|
|
- c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
|
|
|
+ c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
|
|
|
if err != nil {
|
|
|
t.Errorf("client: %v", err)
|
|
|
return
|
|
|
}
|
|
|
- client := gossh.NewClient(c, chans, reqs)
|
|
|
+ client := testssh.NewClient(c, chans, reqs)
|
|
|
defer client.Close()
|
|
|
session, err := client.NewSession()
|
|
|
if err != nil {
|
|
|
@@ -886,80 +887,151 @@ func TestSSHAuthFlow(t *testing.T) {
|
|
|
},
|
|
|
}
|
|
|
s := &server{
|
|
|
- logf: logger.Discard,
|
|
|
+ logf: log.Printf,
|
|
|
}
|
|
|
defer s.Shutdown()
|
|
|
src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
|
|
|
for _, tc := range tests {
|
|
|
- t.Run(tc.name, func(t *testing.T) {
|
|
|
- sc, dc := memnet.NewTCPConn(src, dst, 1024)
|
|
|
- s.lb = tc.state
|
|
|
- sshUser := "alice"
|
|
|
- if tc.sshUser != "" {
|
|
|
- sshUser = tc.sshUser
|
|
|
- }
|
|
|
- var passwordUsed atomic.Bool
|
|
|
- cfg := &gossh.ClientConfig{
|
|
|
- User: sshUser,
|
|
|
- HostKeyCallback: gossh.InsecureIgnoreHostKey(),
|
|
|
- Auth: []gossh.AuthMethod{
|
|
|
- gossh.PasswordCallback(func() (secret string, err error) {
|
|
|
- if !tc.usesPassword {
|
|
|
- t.Error("unexpected use of PasswordCallback")
|
|
|
- return "", errors.New("unexpected use of PasswordCallback")
|
|
|
- }
|
|
|
+ for _, authMethods := range [][]string{nil, {"publickey", "password"}, {"password", "publickey"}} {
|
|
|
+ t.Run(fmt.Sprintf("%s-skip-none-auth-%v", tc.name, strings.Join(authMethods, "-then-")), func(t *testing.T) {
|
|
|
+ sc, dc := memnet.NewTCPConn(src, dst, 1024)
|
|
|
+ s.lb = tc.state
|
|
|
+ sshUser := "alice"
|
|
|
+ if tc.sshUser != "" {
|
|
|
+ sshUser = tc.sshUser
|
|
|
+ }
|
|
|
+
|
|
|
+ wantBanners := slices.Clone(tc.wantBanners)
|
|
|
+ noneAuthEnabled := len(authMethods) == 0
|
|
|
+
|
|
|
+ var publicKeyUsed atomic.Bool
|
|
|
+ var passwordUsed atomic.Bool
|
|
|
+ var methods []testssh.AuthMethod
|
|
|
+
|
|
|
+ for _, authMethod := range authMethods {
|
|
|
+ switch authMethod {
|
|
|
+ case "publickey":
|
|
|
+ methods = append(methods,
|
|
|
+ testssh.PublicKeysCallback(func() (signers []testssh.Signer, err error) {
|
|
|
+ publicKeyUsed.Store(true)
|
|
|
+ key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ sig, err := testssh.NewSignerFromKey(key)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return []testssh.Signer{sig}, nil
|
|
|
+ }))
|
|
|
+ case "password":
|
|
|
+ methods = append(methods, testssh.PasswordCallback(func() (secret string, err error) {
|
|
|
+ passwordUsed.Store(true)
|
|
|
+ return "any-pass", nil
|
|
|
+ }))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if noneAuthEnabled && tc.usesPassword {
|
|
|
+ methods = append(methods, testssh.PasswordCallback(func() (secret string, err error) {
|
|
|
passwordUsed.Store(true)
|
|
|
return "any-pass", nil
|
|
|
- }),
|
|
|
- },
|
|
|
- BannerCallback: func(message string) error {
|
|
|
- if len(tc.wantBanners) == 0 {
|
|
|
- t.Errorf("unexpected banner: %q", message)
|
|
|
- } else if message != tc.wantBanners[0] {
|
|
|
- t.Errorf("banner = %q; want %q", message, tc.wantBanners[0])
|
|
|
- } else {
|
|
|
- t.Logf("banner = %q", message)
|
|
|
- tc.wantBanners = tc.wantBanners[1:]
|
|
|
+ }))
|
|
|
+ }
|
|
|
+
|
|
|
+ cfg := &testssh.ClientConfig{
|
|
|
+ User: sshUser,
|
|
|
+ HostKeyCallback: testssh.InsecureIgnoreHostKey(),
|
|
|
+ SkipNoneAuth: !noneAuthEnabled,
|
|
|
+ Auth: methods,
|
|
|
+ BannerCallback: func(message string) error {
|
|
|
+ if len(wantBanners) == 0 {
|
|
|
+ t.Errorf("unexpected banner: %q", message)
|
|
|
+ } else if message != wantBanners[0] {
|
|
|
+ t.Errorf("banner = %q; want %q", message, wantBanners[0])
|
|
|
+ } else {
|
|
|
+ t.Logf("banner = %q", message)
|
|
|
+ wantBanners = wantBanners[1:]
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ wg.Add(1)
|
|
|
+ go func() {
|
|
|
+ defer wg.Done()
|
|
|
+ c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
|
|
|
+ if err != nil {
|
|
|
+ if !tc.authErr {
|
|
|
+ t.Errorf("client: %v", err)
|
|
|
+ }
|
|
|
+ return
|
|
|
+ } else if tc.authErr {
|
|
|
+ c.Close()
|
|
|
+ t.Errorf("client: expected error, got nil")
|
|
|
+ return
|
|
|
}
|
|
|
- return nil
|
|
|
- },
|
|
|
- }
|
|
|
- var wg sync.WaitGroup
|
|
|
- wg.Add(1)
|
|
|
- go func() {
|
|
|
- defer wg.Done()
|
|
|
- c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
|
|
|
- if err != nil {
|
|
|
- if !tc.authErr {
|
|
|
+ client := testssh.NewClient(c, chans, reqs)
|
|
|
+ defer client.Close()
|
|
|
+ session, err := client.NewSession()
|
|
|
+ if err != nil {
|
|
|
t.Errorf("client: %v", err)
|
|
|
+ return
|
|
|
}
|
|
|
- return
|
|
|
- } else if tc.authErr {
|
|
|
- c.Close()
|
|
|
- t.Errorf("client: expected error, got nil")
|
|
|
- return
|
|
|
+ defer session.Close()
|
|
|
+ _, err = session.CombinedOutput("echo Ran echo!")
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("client: %v", err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ if err := s.HandleSSHConn(dc); err != nil {
|
|
|
+ t.Errorf("unexpected error: %v", err)
|
|
|
}
|
|
|
- client := gossh.NewClient(c, chans, reqs)
|
|
|
- defer client.Close()
|
|
|
- session, err := client.NewSession()
|
|
|
- if err != nil {
|
|
|
- t.Errorf("client: %v", err)
|
|
|
- return
|
|
|
+ wg.Wait()
|
|
|
+ if len(wantBanners) > 0 {
|
|
|
+ t.Errorf("missing banners: %v", wantBanners)
|
|
|
}
|
|
|
- defer session.Close()
|
|
|
- _, err = session.CombinedOutput("echo Ran echo!")
|
|
|
- if err != nil {
|
|
|
- t.Errorf("client: %v", err)
|
|
|
+
|
|
|
+ // Check to see which callbacks were invoked.
|
|
|
+ //
|
|
|
+ // When `none` auth is enabled, the public key callback should
|
|
|
+ // never fire, and the password callback should only fire if
|
|
|
+ // authentication succeeded and the client was trying to force
|
|
|
+ // password authentication by connecting with the '-password'
|
|
|
+ // username suffix.
|
|
|
+ //
|
|
|
+ // When skipping `none` auth, the first callback should always
|
|
|
+ // fire, and the 2nd callback should fire only if
|
|
|
+ // authentication failed.
|
|
|
+ wantPublicKey := false
|
|
|
+ wantPassword := false
|
|
|
+ if noneAuthEnabled {
|
|
|
+ wantPassword = !tc.authErr && tc.usesPassword
|
|
|
+ } else {
|
|
|
+ for i, authMethod := range authMethods {
|
|
|
+ switch authMethod {
|
|
|
+ case "publickey":
|
|
|
+ wantPublicKey = i == 0 || tc.authErr
|
|
|
+ case "password":
|
|
|
+ wantPassword = i == 0 || tc.authErr
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
- }()
|
|
|
- if err := s.HandleSSHConn(dc); err != nil {
|
|
|
- t.Errorf("unexpected error: %v", err)
|
|
|
- }
|
|
|
- wg.Wait()
|
|
|
- if len(tc.wantBanners) > 0 {
|
|
|
- t.Errorf("missing banners: %v", tc.wantBanners)
|
|
|
- }
|
|
|
- })
|
|
|
+
|
|
|
+ if wantPublicKey && !publicKeyUsed.Load() {
|
|
|
+ t.Error("public key should have been attempted")
|
|
|
+ } else if !wantPublicKey && publicKeyUsed.Load() {
|
|
|
+ t.Errorf("public key should not have been attempted")
|
|
|
+ }
|
|
|
+
|
|
|
+ if wantPassword && !passwordUsed.Load() {
|
|
|
+ t.Error("password should have been attempted")
|
|
|
+ } else if !wantPassword && passwordUsed.Load() {
|
|
|
+ t.Error("password should not have been attempted")
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|