|
|
@@ -7,6 +7,7 @@ package tailssh
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
+ "context"
|
|
|
"crypto/ed25519"
|
|
|
"crypto/rand"
|
|
|
"crypto/sha256"
|
|
|
@@ -14,6 +15,7 @@ import (
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
+ "io/ioutil"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
|
@@ -324,9 +326,101 @@ func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
|
|
|
+// when the client is not interactive (i.e. no PTY).
|
|
|
+// It starts a local SSH server and a recording server. The recording server
|
|
|
+// records the SSH session and returns it to the test.
|
|
|
+// The test then verifies that the recording has a valid CastHeader, it does not
|
|
|
+// validate the contents of the recording.
|
|
|
+func TestSSHRecordingNonInteractive(t *testing.T) {
|
|
|
+ if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
|
|
|
+ t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
|
|
|
+ }
|
|
|
+ var recording []byte
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
|
+ recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ defer cancel()
|
|
|
+ var err error
|
|
|
+ recording, err = ioutil.ReadAll(r.Body)
|
|
|
+ if err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ w.WriteHeader(http.StatusOK)
|
|
|
+ }))
|
|
|
+ defer recordingServer.Close()
|
|
|
+
|
|
|
+ state := &localState{
|
|
|
+ sshEnabled: true,
|
|
|
+ matchingRule: newSSHRule(
|
|
|
+ &tailcfg.SSHAction{
|
|
|
+ Accept: true,
|
|
|
+ Recorders: []netip.AddrPort{
|
|
|
+ must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ),
|
|
|
+ }
|
|
|
+ s := &server{
|
|
|
+ logf: t.Logf,
|
|
|
+ httpc: recordingServer.Client(),
|
|
|
+ }
|
|
|
+ defer s.Shutdown()
|
|
|
+
|
|
|
+ src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
|
|
|
+ sc, dc := memnet.NewTCPConn(src, dst, 1024)
|
|
|
+ s.lb = state
|
|
|
+
|
|
|
+ const sshUser = "alice"
|
|
|
+ cfg := &gossh.ClientConfig{
|
|
|
+ User: sshUser,
|
|
|
+ HostKeyCallback: gossh.InsecureIgnoreHostKey(),
|
|
|
+ }
|
|
|
+
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ wg.Add(1)
|
|
|
+ go func() {
|
|
|
+ defer wg.Done()
|
|
|
+ c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("client: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ client := gossh.NewClient(c, chans, reqs)
|
|
|
+ defer client.Close()
|
|
|
+ session, err := client.NewSession()
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("client: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ defer session.Close()
|
|
|
+ t.Logf("client established session")
|
|
|
+ _, err = session.CombinedOutput("echo Ran echo!")
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("client: %v", err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ if err := s.HandleSSHConn(dc); err != nil {
|
|
|
+ t.Errorf("unexpected error: %v", err)
|
|
|
+ }
|
|
|
+ wg.Wait()
|
|
|
+
|
|
|
+ <-ctx.Done() // wait for recording to finish
|
|
|
+ var ch CastHeader
|
|
|
+ if err := json.NewDecoder(bytes.NewReader(recording)).Decode(&ch); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ if ch.SSHUser != sshUser {
|
|
|
+ t.Errorf("SSHUser = %q; want %q", ch.SSHUser, sshUser)
|
|
|
+ }
|
|
|
+ if ch.Command != "echo Ran echo!" {
|
|
|
+ t.Errorf("Command = %q; want %q", ch.Command, "echo Ran echo!")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func TestSSHAuthFlow(t *testing.T) {
|
|
|
- if runtime.GOOS != "linux" {
|
|
|
- t.Skip("Not running on Linux, skipping")
|
|
|
+ if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
|
|
|
+ t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
|
|
|
}
|
|
|
acceptRule := newSSHRule(&tailcfg.SSHAction{
|
|
|
Accept: true,
|