|
|
@@ -31,6 +31,7 @@ import (
|
|
|
"sync"
|
|
|
"sync/atomic"
|
|
|
"testing"
|
|
|
+ "testing/synctest"
|
|
|
"time"
|
|
|
|
|
|
gossh "golang.org/x/crypto/ssh"
|
|
|
@@ -1111,6 +1112,7 @@ func TestSSH(t *testing.T) {
|
|
|
}
|
|
|
sc.action0 = &tailcfg.SSHAction{Accept: true}
|
|
|
sc.finalAction = sc.action0
|
|
|
+ sc.authCompleted.Store(true)
|
|
|
|
|
|
sc.Handler = func(s ssh.Session) {
|
|
|
sc.newSSHSession(s).run()
|
|
|
@@ -1320,6 +1322,79 @@ func TestStdOsUserUserAssumptions(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func TestOnPolicyChangeSkipsPreAuthConns(t *testing.T) {
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ sshRule *tailcfg.SSHRule
|
|
|
+ wantCancel bool
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "accept-after-auth",
|
|
|
+ sshRule: newSSHRule(&tailcfg.SSHAction{Accept: true}),
|
|
|
+ wantCancel: false,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "reject-after-auth",
|
|
|
+ sshRule: newSSHRule(&tailcfg.SSHAction{Reject: true}),
|
|
|
+ wantCancel: true,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ synctest.Test(t, func(t *testing.T) {
|
|
|
+ srv := &server{
|
|
|
+ logf: tstest.WhileTestRunningLogger(t),
|
|
|
+ lb: &localState{
|
|
|
+ sshEnabled: true,
|
|
|
+ matchingRule: tt.sshRule,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ c := &conn{
|
|
|
+ srv: srv,
|
|
|
+ info: &sshConnInfo{
|
|
|
+ sshUser: "alice",
|
|
|
+ src: netip.MustParseAddrPort("1.2.3.4:30343"),
|
|
|
+ dst: netip.MustParseAddrPort("100.100.100.102:22"),
|
|
|
+ },
|
|
|
+ localUser: &userMeta{User: user.User{Username: currentUser}},
|
|
|
+ }
|
|
|
+ srv.activeConns = map[*conn]bool{c: true}
|
|
|
+ ctx, cancel := context.WithCancelCause(context.Background())
|
|
|
+ ss := &sshSession{ctx: ctx, cancelCtx: cancel}
|
|
|
+ c.sessions = []*sshSession{ss}
|
|
|
+
|
|
|
+ // Before authCompleted is set, OnPolicyChange should skip
|
|
|
+ // the conn entirely — no goroutine spawned.
|
|
|
+ srv.OnPolicyChange()
|
|
|
+ synctest.Wait()
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ t.Fatal("session canceled before auth completed")
|
|
|
+ default:
|
|
|
+ }
|
|
|
+
|
|
|
+ // Mark auth as completed. Now OnPolicyChange should
|
|
|
+ // evaluate the policy and act accordingly.
|
|
|
+ c.authCompleted.Store(true)
|
|
|
+
|
|
|
+ srv.OnPolicyChange()
|
|
|
+ synctest.Wait()
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ if !tt.wantCancel {
|
|
|
+ t.Fatal("valid session should not have been canceled")
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ if tt.wantCancel {
|
|
|
+ t.Fatal("invalid session should have been canceled")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server {
|
|
|
t.Helper()
|
|
|
mux := http.NewServeMux()
|