Ver código fonte

ssh/tailssh: add session recording test for non-pty sessions

Updates tailscale/corp#9967

Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 3 anos atrás
pai
commit
09d0b632d4
1 arquivos alterados com 96 adições e 2 exclusões
  1. 96 2
      ssh/tailssh/tailssh_test.go

+ 96 - 2
ssh/tailssh/tailssh_test.go

@@ -7,6 +7,7 @@ package tailssh
 
 import (
 	"bytes"
+	"context"
 	"crypto/ed25519"
 	"crypto/rand"
 	"crypto/sha256"
@@ -14,6 +15,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"io/ioutil"
 	"net"
 	"net/http"
 	"net/http/httptest"
@@ -324,9 +326,101 @@ func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule {
 	}
 }
 
+// TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
+// when the client is not interactive (i.e. no PTY).
+// It starts a local SSH server and a recording server. The recording server
+// records the SSH session and returns it to the test.
+// The test then verifies that the recording has a valid CastHeader, it does not
+// validate the contents of the recording.
+func TestSSHRecordingNonInteractive(t *testing.T) {
+	if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
+		t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
+	}
+	var recording []byte
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		defer cancel()
+		var err error
+		recording, err = ioutil.ReadAll(r.Body)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer recordingServer.Close()
+
+	state := &localState{
+		sshEnabled: true,
+		matchingRule: newSSHRule(
+			&tailcfg.SSHAction{
+				Accept: true,
+				Recorders: []netip.AddrPort{
+					must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
+				},
+			},
+		),
+	}
+	s := &server{
+		logf:  t.Logf,
+		httpc: recordingServer.Client(),
+	}
+	defer s.Shutdown()
+
+	src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
+	sc, dc := memnet.NewTCPConn(src, dst, 1024)
+	s.lb = state
+
+	const sshUser = "alice"
+	cfg := &gossh.ClientConfig{
+		User:            sshUser,
+		HostKeyCallback: gossh.InsecureIgnoreHostKey(),
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
+		if err != nil {
+			t.Errorf("client: %v", err)
+			return
+		}
+		client := gossh.NewClient(c, chans, reqs)
+		defer client.Close()
+		session, err := client.NewSession()
+		if err != nil {
+			t.Errorf("client: %v", err)
+			return
+		}
+		defer session.Close()
+		t.Logf("client established session")
+		_, err = session.CombinedOutput("echo Ran echo!")
+		if err != nil {
+			t.Errorf("client: %v", err)
+		}
+	}()
+	if err := s.HandleSSHConn(dc); err != nil {
+		t.Errorf("unexpected error: %v", err)
+	}
+	wg.Wait()
+
+	<-ctx.Done() // wait for recording to finish
+	var ch CastHeader
+	if err := json.NewDecoder(bytes.NewReader(recording)).Decode(&ch); err != nil {
+		t.Fatal(err)
+	}
+	if ch.SSHUser != sshUser {
+		t.Errorf("SSHUser = %q; want %q", ch.SSHUser, sshUser)
+	}
+	if ch.Command != "echo Ran echo!" {
+		t.Errorf("Command = %q; want %q", ch.Command, "echo Ran echo!")
+	}
+}
+
 func TestSSHAuthFlow(t *testing.T) {
-	if runtime.GOOS != "linux" {
-		t.Skip("Not running on Linux, skipping")
+	if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
+		t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
 	}
 	acceptRule := newSSHRule(&tailcfg.SSHAction{
 		Accept:  true,