tailssh_test.go 30 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build linux || darwin
  4. package tailssh
  5. import (
  6. "bytes"
  7. "context"
  8. "crypto/ed25519"
  9. "crypto/rand"
  10. "crypto/sha256"
  11. "encoding/json"
  12. "errors"
  13. "fmt"
  14. "io"
  15. "net"
  16. "net/http"
  17. "net/http/httptest"
  18. "net/netip"
  19. "os"
  20. "os/exec"
  21. "os/user"
  22. "reflect"
  23. "runtime"
  24. "strconv"
  25. "strings"
  26. "sync"
  27. "sync/atomic"
  28. "testing"
  29. "time"
  30. gossh "github.com/tailscale/golang-x-crypto/ssh"
  31. "tailscale.com/ipn/ipnlocal"
  32. "tailscale.com/ipn/store/mem"
  33. "tailscale.com/net/memnet"
  34. "tailscale.com/net/tsdial"
  35. "tailscale.com/sessionrecording"
  36. "tailscale.com/tailcfg"
  37. "tailscale.com/tempfork/gliderlabs/ssh"
  38. "tailscale.com/tsd"
  39. "tailscale.com/tstest"
  40. "tailscale.com/types/key"
  41. "tailscale.com/types/logger"
  42. "tailscale.com/types/logid"
  43. "tailscale.com/types/netmap"
  44. "tailscale.com/types/ptr"
  45. "tailscale.com/util/cibuild"
  46. "tailscale.com/util/lineread"
  47. "tailscale.com/util/must"
  48. "tailscale.com/version/distro"
  49. "tailscale.com/wgengine"
  50. )
  51. func TestMatchRule(t *testing.T) {
  52. someAction := new(tailcfg.SSHAction)
  53. tests := []struct {
  54. name string
  55. rule *tailcfg.SSHRule
  56. ci *sshConnInfo
  57. wantErr error
  58. wantUser string
  59. }{
  60. {
  61. name: "invalid-conn",
  62. rule: &tailcfg.SSHRule{
  63. Action: someAction,
  64. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  65. SSHUsers: map[string]string{
  66. "*": "ubuntu",
  67. },
  68. },
  69. wantErr: errInvalidConn,
  70. },
  71. {
  72. name: "nil-rule",
  73. ci: &sshConnInfo{},
  74. rule: nil,
  75. wantErr: errNilRule,
  76. },
  77. {
  78. name: "nil-action",
  79. ci: &sshConnInfo{},
  80. rule: &tailcfg.SSHRule{},
  81. wantErr: errNilAction,
  82. },
  83. {
  84. name: "expired",
  85. rule: &tailcfg.SSHRule{
  86. Action: someAction,
  87. RuleExpires: ptr.To(time.Unix(100, 0)),
  88. },
  89. ci: &sshConnInfo{},
  90. wantErr: errRuleExpired,
  91. },
  92. {
  93. name: "no-principal",
  94. rule: &tailcfg.SSHRule{
  95. Action: someAction,
  96. SSHUsers: map[string]string{
  97. "*": "ubuntu",
  98. }},
  99. ci: &sshConnInfo{},
  100. wantErr: errPrincipalMatch,
  101. },
  102. {
  103. name: "no-user-match",
  104. rule: &tailcfg.SSHRule{
  105. Action: someAction,
  106. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  107. },
  108. ci: &sshConnInfo{sshUser: "alice"},
  109. wantErr: errUserMatch,
  110. },
  111. {
  112. name: "ok-wildcard",
  113. rule: &tailcfg.SSHRule{
  114. Action: someAction,
  115. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  116. SSHUsers: map[string]string{
  117. "*": "ubuntu",
  118. },
  119. },
  120. ci: &sshConnInfo{sshUser: "alice"},
  121. wantUser: "ubuntu",
  122. },
  123. {
  124. name: "ok-wildcard-and-nil-principal",
  125. rule: &tailcfg.SSHRule{
  126. Action: someAction,
  127. Principals: []*tailcfg.SSHPrincipal{
  128. nil, // don't crash on this
  129. {Any: true},
  130. },
  131. SSHUsers: map[string]string{
  132. "*": "ubuntu",
  133. },
  134. },
  135. ci: &sshConnInfo{sshUser: "alice"},
  136. wantUser: "ubuntu",
  137. },
  138. {
  139. name: "ok-exact",
  140. rule: &tailcfg.SSHRule{
  141. Action: someAction,
  142. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  143. SSHUsers: map[string]string{
  144. "*": "ubuntu",
  145. "alice": "thealice",
  146. },
  147. },
  148. ci: &sshConnInfo{sshUser: "alice"},
  149. wantUser: "thealice",
  150. },
  151. {
  152. name: "no-users-for-reject",
  153. rule: &tailcfg.SSHRule{
  154. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  155. Action: &tailcfg.SSHAction{Reject: true},
  156. },
  157. ci: &sshConnInfo{sshUser: "alice"},
  158. },
  159. {
  160. name: "match-principal-node-ip",
  161. rule: &tailcfg.SSHRule{
  162. Action: someAction,
  163. Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}},
  164. SSHUsers: map[string]string{"*": "ubuntu"},
  165. },
  166. ci: &sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")},
  167. wantUser: "ubuntu",
  168. },
  169. {
  170. name: "match-principal-node-id",
  171. rule: &tailcfg.SSHRule{
  172. Action: someAction,
  173. Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}},
  174. SSHUsers: map[string]string{"*": "ubuntu"},
  175. },
  176. ci: &sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()},
  177. wantUser: "ubuntu",
  178. },
  179. {
  180. name: "match-principal-userlogin",
  181. rule: &tailcfg.SSHRule{
  182. Action: someAction,
  183. Principals: []*tailcfg.SSHPrincipal{{UserLogin: "[email protected]"}},
  184. SSHUsers: map[string]string{"*": "ubuntu"},
  185. },
  186. ci: &sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "[email protected]"}},
  187. wantUser: "ubuntu",
  188. },
  189. {
  190. name: "ssh-user-equal",
  191. rule: &tailcfg.SSHRule{
  192. Action: someAction,
  193. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  194. SSHUsers: map[string]string{
  195. "*": "=",
  196. },
  197. },
  198. ci: &sshConnInfo{sshUser: "alice"},
  199. wantUser: "alice",
  200. },
  201. }
  202. for _, tt := range tests {
  203. t.Run(tt.name, func(t *testing.T) {
  204. c := &conn{
  205. info: tt.ci,
  206. srv: &server{logf: t.Logf},
  207. }
  208. got, gotUser, err := c.matchRule(tt.rule, nil)
  209. if err != tt.wantErr {
  210. t.Errorf("err = %v; want %v", err, tt.wantErr)
  211. }
  212. if gotUser != tt.wantUser {
  213. t.Errorf("user = %q; want %q", gotUser, tt.wantUser)
  214. }
  215. if err == nil && got == nil {
  216. t.Errorf("expected non-nil action on success")
  217. }
  218. })
  219. }
  220. }
  221. // localState implements ipnLocalBackend for testing.
  222. type localState struct {
  223. sshEnabled bool
  224. matchingRule *tailcfg.SSHRule
  225. // serverActions is a map of the action name to the action.
  226. // It is served for paths like https://unused/ssh-action/<action-name>.
  227. // The action name is the last part of the action URL.
  228. serverActions map[string]*tailcfg.SSHAction
  229. }
  230. var (
  231. currentUser = os.Getenv("USER") // Use the current user for the test.
  232. testSigner gossh.Signer
  233. testSignerOnce sync.Once
  234. )
  235. func (ts *localState) Dialer() *tsdial.Dialer {
  236. return &tsdial.Dialer{}
  237. }
  238. func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) {
  239. testSignerOnce.Do(func() {
  240. _, priv, err := ed25519.GenerateKey(rand.Reader)
  241. if err != nil {
  242. panic(err)
  243. }
  244. s, err := gossh.NewSignerFromSigner(priv)
  245. if err != nil {
  246. panic(err)
  247. }
  248. testSigner = s
  249. })
  250. return []gossh.Signer{testSigner}, nil
  251. }
  252. func (ts *localState) ShouldRunSSH() bool {
  253. return ts.sshEnabled
  254. }
  255. func (ts *localState) NetMap() *netmap.NetworkMap {
  256. var policy *tailcfg.SSHPolicy
  257. if ts.matchingRule != nil {
  258. policy = &tailcfg.SSHPolicy{
  259. Rules: []*tailcfg.SSHRule{
  260. ts.matchingRule,
  261. },
  262. }
  263. }
  264. return &netmap.NetworkMap{
  265. SelfNode: (&tailcfg.Node{
  266. ID: 1,
  267. }).View(),
  268. SSHPolicy: policy,
  269. }
  270. }
  271. func (ts *localState) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
  272. if proto != "tcp" {
  273. return tailcfg.NodeView{}, tailcfg.UserProfile{}, false
  274. }
  275. return (&tailcfg.Node{
  276. ID: 2,
  277. StableID: "peer-id",
  278. }).View(), tailcfg.UserProfile{
  279. LoginName: "peer",
  280. }, true
  281. }
  282. func (ts *localState) DoNoiseRequest(req *http.Request) (*http.Response, error) {
  283. rec := httptest.NewRecorder()
  284. k, ok := strings.CutPrefix(req.URL.Path, "/ssh-action/")
  285. if !ok {
  286. rec.WriteHeader(http.StatusNotFound)
  287. }
  288. a, ok := ts.serverActions[k]
  289. if !ok {
  290. rec.WriteHeader(http.StatusNotFound)
  291. return rec.Result(), nil
  292. }
  293. rec.WriteHeader(http.StatusOK)
  294. if err := json.NewEncoder(rec).Encode(a); err != nil {
  295. return nil, err
  296. }
  297. return rec.Result(), nil
  298. }
  299. func (ts *localState) TailscaleVarRoot() string {
  300. return ""
  301. }
  302. func (ts *localState) NodeKey() key.NodePublic {
  303. return key.NewNode().Public()
  304. }
  305. func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule {
  306. return &tailcfg.SSHRule{
  307. SSHUsers: map[string]string{
  308. "*": currentUser,
  309. },
  310. Action: action,
  311. Principals: []*tailcfg.SSHPrincipal{
  312. {
  313. Any: true,
  314. },
  315. },
  316. }
  317. }
  318. func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
  319. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  320. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  321. }
  322. var handler http.HandlerFunc
  323. recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  324. handler(w, r)
  325. }))
  326. defer recordingServer.Close()
  327. s := &server{
  328. logf: t.Logf,
  329. lb: &localState{
  330. sshEnabled: true,
  331. matchingRule: newSSHRule(
  332. &tailcfg.SSHAction{
  333. Accept: true,
  334. Recorders: []netip.AddrPort{
  335. netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
  336. },
  337. OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
  338. RejectSessionWithMessage: "session rejected",
  339. TerminateSessionWithMessage: "session terminated",
  340. },
  341. },
  342. ),
  343. },
  344. }
  345. defer s.Shutdown()
  346. const sshUser = "alice"
  347. cfg := &gossh.ClientConfig{
  348. User: sshUser,
  349. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  350. }
  351. tests := []struct {
  352. name string
  353. handler func(w http.ResponseWriter, r *http.Request)
  354. sshCommand string
  355. wantClientOutput string
  356. clientOutputMustNotContain []string
  357. }{
  358. {
  359. name: "upload-denied",
  360. handler: func(w http.ResponseWriter, r *http.Request) {
  361. w.WriteHeader(http.StatusForbidden)
  362. },
  363. sshCommand: "echo hello",
  364. wantClientOutput: "session rejected\r\n",
  365. clientOutputMustNotContain: []string{"hello"},
  366. },
  367. {
  368. name: "upload-fails-after-starting",
  369. handler: func(w http.ResponseWriter, r *http.Request) {
  370. r.Body.Read(make([]byte, 1))
  371. time.Sleep(100 * time.Millisecond)
  372. w.WriteHeader(http.StatusInternalServerError)
  373. },
  374. sshCommand: "echo hello && sleep 1 && echo world",
  375. wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n",
  376. clientOutputMustNotContain: []string{"world"},
  377. },
  378. }
  379. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  380. for _, tt := range tests {
  381. t.Run(tt.name, func(t *testing.T) {
  382. tstest.Replace(t, &handler, tt.handler)
  383. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  384. var wg sync.WaitGroup
  385. wg.Add(1)
  386. go func() {
  387. defer wg.Done()
  388. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  389. if err != nil {
  390. t.Errorf("client: %v", err)
  391. return
  392. }
  393. client := gossh.NewClient(c, chans, reqs)
  394. defer client.Close()
  395. session, err := client.NewSession()
  396. if err != nil {
  397. t.Errorf("client: %v", err)
  398. return
  399. }
  400. defer session.Close()
  401. t.Logf("client established session")
  402. got, err := session.CombinedOutput(tt.sshCommand)
  403. if err != nil {
  404. t.Logf("client got: %q: %v", got, err)
  405. } else {
  406. t.Errorf("client did not get kicked out: %q", got)
  407. }
  408. gotStr := string(got)
  409. if !strings.HasSuffix(gotStr, tt.wantClientOutput) {
  410. t.Errorf("client got %q, want %q", got, tt.wantClientOutput)
  411. }
  412. for _, x := range tt.clientOutputMustNotContain {
  413. if strings.Contains(gotStr, x) {
  414. t.Errorf("client output must not contain %q", x)
  415. }
  416. }
  417. }()
  418. if err := s.HandleSSHConn(dc); err != nil {
  419. t.Errorf("unexpected error: %v", err)
  420. }
  421. wg.Wait()
  422. })
  423. }
  424. }
  425. func TestMultipleRecorders(t *testing.T) {
  426. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  427. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  428. }
  429. done := make(chan struct{})
  430. recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  431. defer close(done)
  432. io.ReadAll(r.Body)
  433. w.WriteHeader(http.StatusOK)
  434. }))
  435. defer recordingServer.Close()
  436. badRecorder, err := net.Listen("tcp", ":0")
  437. if err != nil {
  438. t.Fatal(err)
  439. }
  440. badRecorderAddr := badRecorder.Addr().String()
  441. badRecorder.Close()
  442. badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  443. w.WriteHeader(500)
  444. }))
  445. defer badRecordingServer500.Close()
  446. badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  447. w.WriteHeader(200)
  448. }))
  449. defer badRecordingServer200.Close()
  450. s := &server{
  451. logf: t.Logf,
  452. lb: &localState{
  453. sshEnabled: true,
  454. matchingRule: newSSHRule(
  455. &tailcfg.SSHAction{
  456. Accept: true,
  457. Recorders: []netip.AddrPort{
  458. netip.MustParseAddrPort(badRecorderAddr),
  459. netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()),
  460. netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()),
  461. netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
  462. },
  463. OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
  464. RejectSessionWithMessage: "session rejected",
  465. TerminateSessionWithMessage: "session terminated",
  466. },
  467. },
  468. ),
  469. },
  470. }
  471. defer s.Shutdown()
  472. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  473. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  474. const sshUser = "alice"
  475. cfg := &gossh.ClientConfig{
  476. User: sshUser,
  477. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  478. }
  479. var wg sync.WaitGroup
  480. wg.Add(1)
  481. go func() {
  482. defer wg.Done()
  483. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  484. if err != nil {
  485. t.Errorf("client: %v", err)
  486. return
  487. }
  488. client := gossh.NewClient(c, chans, reqs)
  489. defer client.Close()
  490. session, err := client.NewSession()
  491. if err != nil {
  492. t.Errorf("client: %v", err)
  493. return
  494. }
  495. defer session.Close()
  496. t.Logf("client established session")
  497. out, err := session.CombinedOutput("echo Ran echo!")
  498. if err != nil {
  499. t.Errorf("client: %v", err)
  500. }
  501. if string(out) != "Ran echo!\n" {
  502. t.Errorf("client: unexpected output: %q", out)
  503. }
  504. }()
  505. if err := s.HandleSSHConn(dc); err != nil {
  506. t.Errorf("unexpected error: %v", err)
  507. }
  508. wg.Wait()
  509. select {
  510. case <-done:
  511. case <-time.After(1 * time.Second):
  512. t.Fatal("timed out waiting for recording")
  513. }
  514. }
  515. // TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
  516. // when the client is not interactive (i.e. no PTY).
  517. // It starts a local SSH server and a recording server. The recording server
  518. // records the SSH session and returns it to the test.
  519. // The test then verifies that the recording has a valid CastHeader, it does not
  520. // validate the contents of the recording.
  521. func TestSSHRecordingNonInteractive(t *testing.T) {
  522. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  523. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  524. }
  525. var recording []byte
  526. ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  527. recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  528. defer cancel()
  529. var err error
  530. recording, err = io.ReadAll(r.Body)
  531. if err != nil {
  532. t.Error(err)
  533. return
  534. }
  535. }))
  536. defer recordingServer.Close()
  537. s := &server{
  538. logf: logger.Discard,
  539. lb: &localState{
  540. sshEnabled: true,
  541. matchingRule: newSSHRule(
  542. &tailcfg.SSHAction{
  543. Accept: true,
  544. Recorders: []netip.AddrPort{
  545. must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
  546. },
  547. OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
  548. RejectSessionWithMessage: "session rejected",
  549. TerminateSessionWithMessage: "session terminated",
  550. },
  551. },
  552. ),
  553. },
  554. }
  555. defer s.Shutdown()
  556. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  557. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  558. const sshUser = "alice"
  559. cfg := &gossh.ClientConfig{
  560. User: sshUser,
  561. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  562. }
  563. var wg sync.WaitGroup
  564. wg.Add(1)
  565. go func() {
  566. defer wg.Done()
  567. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  568. if err != nil {
  569. t.Errorf("client: %v", err)
  570. return
  571. }
  572. client := gossh.NewClient(c, chans, reqs)
  573. defer client.Close()
  574. session, err := client.NewSession()
  575. if err != nil {
  576. t.Errorf("client: %v", err)
  577. return
  578. }
  579. defer session.Close()
  580. t.Logf("client established session")
  581. _, err = session.CombinedOutput("echo Ran echo!")
  582. if err != nil {
  583. t.Errorf("client: %v", err)
  584. }
  585. }()
  586. if err := s.HandleSSHConn(dc); err != nil {
  587. t.Errorf("unexpected error: %v", err)
  588. }
  589. wg.Wait()
  590. <-ctx.Done() // wait for recording to finish
  591. var ch sessionrecording.CastHeader
  592. if err := json.NewDecoder(bytes.NewReader(recording)).Decode(&ch); err != nil {
  593. t.Fatal(err)
  594. }
  595. if ch.SSHUser != sshUser {
  596. t.Errorf("SSHUser = %q; want %q", ch.SSHUser, sshUser)
  597. }
  598. if ch.Command != "echo Ran echo!" {
  599. t.Errorf("Command = %q; want %q", ch.Command, "echo Ran echo!")
  600. }
  601. }
  602. func TestSSHAuthFlow(t *testing.T) {
  603. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  604. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  605. }
  606. acceptRule := newSSHRule(&tailcfg.SSHAction{
  607. Accept: true,
  608. Message: "Welcome to Tailscale SSH!",
  609. })
  610. rejectRule := newSSHRule(&tailcfg.SSHAction{
  611. Reject: true,
  612. Message: "Go Away!",
  613. })
  614. tests := []struct {
  615. name string
  616. sshUser string // defaults to alice
  617. state *localState
  618. wantBanners []string
  619. usesPassword bool
  620. authErr bool
  621. }{
  622. {
  623. name: "no-policy",
  624. state: &localState{
  625. sshEnabled: true,
  626. },
  627. authErr: true,
  628. },
  629. {
  630. name: "accept",
  631. state: &localState{
  632. sshEnabled: true,
  633. matchingRule: acceptRule,
  634. },
  635. wantBanners: []string{"Welcome to Tailscale SSH!"},
  636. },
  637. {
  638. name: "reject",
  639. state: &localState{
  640. sshEnabled: true,
  641. matchingRule: rejectRule,
  642. },
  643. wantBanners: []string{"Go Away!"},
  644. authErr: true,
  645. },
  646. {
  647. name: "simple-check",
  648. state: &localState{
  649. sshEnabled: true,
  650. matchingRule: newSSHRule(&tailcfg.SSHAction{
  651. HoldAndDelegate: "https://unused/ssh-action/accept",
  652. }),
  653. serverActions: map[string]*tailcfg.SSHAction{
  654. "accept": acceptRule.Action,
  655. },
  656. },
  657. wantBanners: []string{"Welcome to Tailscale SSH!"},
  658. },
  659. {
  660. name: "multi-check",
  661. state: &localState{
  662. sshEnabled: true,
  663. matchingRule: newSSHRule(&tailcfg.SSHAction{
  664. Message: "First",
  665. HoldAndDelegate: "https://unused/ssh-action/check1",
  666. }),
  667. serverActions: map[string]*tailcfg.SSHAction{
  668. "check1": {
  669. Message: "url-here",
  670. HoldAndDelegate: "https://unused/ssh-action/check2",
  671. },
  672. "check2": acceptRule.Action,
  673. },
  674. },
  675. wantBanners: []string{"First", "url-here", "Welcome to Tailscale SSH!"},
  676. },
  677. {
  678. name: "check-reject",
  679. state: &localState{
  680. sshEnabled: true,
  681. matchingRule: newSSHRule(&tailcfg.SSHAction{
  682. Message: "First",
  683. HoldAndDelegate: "https://unused/ssh-action/reject",
  684. }),
  685. serverActions: map[string]*tailcfg.SSHAction{
  686. "reject": rejectRule.Action,
  687. },
  688. },
  689. wantBanners: []string{"First", "Go Away!"},
  690. authErr: true,
  691. },
  692. {
  693. name: "force-password-auth",
  694. sshUser: "alice+password",
  695. state: &localState{
  696. sshEnabled: true,
  697. matchingRule: acceptRule,
  698. },
  699. usesPassword: true,
  700. wantBanners: []string{"Welcome to Tailscale SSH!"},
  701. },
  702. }
  703. s := &server{
  704. logf: logger.Discard,
  705. }
  706. defer s.Shutdown()
  707. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  708. for _, tc := range tests {
  709. t.Run(tc.name, func(t *testing.T) {
  710. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  711. s.lb = tc.state
  712. sshUser := "alice"
  713. if tc.sshUser != "" {
  714. sshUser = tc.sshUser
  715. }
  716. var passwordUsed atomic.Bool
  717. cfg := &gossh.ClientConfig{
  718. User: sshUser,
  719. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  720. Auth: []gossh.AuthMethod{
  721. gossh.PasswordCallback(func() (secret string, err error) {
  722. if !tc.usesPassword {
  723. t.Error("unexpected use of PasswordCallback")
  724. return "", errors.New("unexpected use of PasswordCallback")
  725. }
  726. passwordUsed.Store(true)
  727. return "any-pass", nil
  728. }),
  729. },
  730. BannerCallback: func(message string) error {
  731. if len(tc.wantBanners) == 0 {
  732. t.Errorf("unexpected banner: %q", message)
  733. } else if message != tc.wantBanners[0] {
  734. t.Errorf("banner = %q; want %q", message, tc.wantBanners[0])
  735. } else {
  736. t.Logf("banner = %q", message)
  737. tc.wantBanners = tc.wantBanners[1:]
  738. }
  739. return nil
  740. },
  741. }
  742. var wg sync.WaitGroup
  743. wg.Add(1)
  744. go func() {
  745. defer wg.Done()
  746. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  747. if err != nil {
  748. if !tc.authErr {
  749. t.Errorf("client: %v", err)
  750. }
  751. return
  752. } else if tc.authErr {
  753. c.Close()
  754. t.Errorf("client: expected error, got nil")
  755. return
  756. }
  757. client := gossh.NewClient(c, chans, reqs)
  758. defer client.Close()
  759. session, err := client.NewSession()
  760. if err != nil {
  761. t.Errorf("client: %v", err)
  762. return
  763. }
  764. defer session.Close()
  765. _, err = session.CombinedOutput("echo Ran echo!")
  766. if err != nil {
  767. t.Errorf("client: %v", err)
  768. }
  769. }()
  770. if err := s.HandleSSHConn(dc); err != nil {
  771. t.Errorf("unexpected error: %v", err)
  772. }
  773. wg.Wait()
  774. if len(tc.wantBanners) > 0 {
  775. t.Errorf("missing banners: %v", tc.wantBanners)
  776. }
  777. })
  778. }
  779. }
  780. func TestSSH(t *testing.T) {
  781. var logf logger.Logf = t.Logf
  782. sys := &tsd.System{}
  783. eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker())
  784. if err != nil {
  785. t.Fatal(err)
  786. }
  787. sys.Set(eng)
  788. sys.Set(new(mem.Store))
  789. lb, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{}, sys, 0)
  790. if err != nil {
  791. t.Fatal(err)
  792. }
  793. defer lb.Shutdown()
  794. dir := t.TempDir()
  795. lb.SetVarRoot(dir)
  796. srv := &server{
  797. lb: lb,
  798. logf: logf,
  799. }
  800. sc, err := srv.newConn()
  801. if err != nil {
  802. t.Fatal(err)
  803. }
  804. // Remove the auth checks for the test
  805. sc.insecureSkipTailscaleAuth = true
  806. u, err := user.Current()
  807. if err != nil {
  808. t.Fatal(err)
  809. }
  810. um, err := userLookup(u.Username)
  811. if err != nil {
  812. t.Fatal(err)
  813. }
  814. sc.localUser = um
  815. sc.info = &sshConnInfo{
  816. sshUser: "test",
  817. src: netip.MustParseAddrPort("1.2.3.4:32342"),
  818. dst: netip.MustParseAddrPort("1.2.3.5:22"),
  819. node: (&tailcfg.Node{}).View(),
  820. uprof: tailcfg.UserProfile{},
  821. }
  822. sc.action0 = &tailcfg.SSHAction{Accept: true}
  823. sc.finalAction = sc.action0
  824. sc.Handler = func(s ssh.Session) {
  825. sc.newSSHSession(s).run()
  826. }
  827. ln, err := net.Listen("tcp4", "127.0.0.1:0")
  828. if err != nil {
  829. t.Fatal(err)
  830. }
  831. defer ln.Close()
  832. port := ln.Addr().(*net.TCPAddr).Port
  833. go func() {
  834. for {
  835. c, err := ln.Accept()
  836. if err != nil {
  837. if !errors.Is(err, net.ErrClosed) {
  838. t.Errorf("Accept: %v", err)
  839. }
  840. return
  841. }
  842. go sc.HandleConn(c)
  843. }
  844. }()
  845. execSSH := func(args ...string) *exec.Cmd {
  846. cmd := exec.Command("ssh",
  847. "-F",
  848. "none",
  849. "-v",
  850. "-p", fmt.Sprint(port),
  851. "-o", "StrictHostKeyChecking=no",
  852. "[email protected]")
  853. cmd.Args = append(cmd.Args, args...)
  854. return cmd
  855. }
  856. t.Run("env", func(t *testing.T) {
  857. if cibuild.On() {
  858. t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
  859. }
  860. cmd := execSSH("LANG=foo env")
  861. cmd.Env = append(os.Environ(), "LOCAL_ENV=bar")
  862. got, err := cmd.CombinedOutput()
  863. if err != nil {
  864. t.Fatal(err, string(got))
  865. }
  866. m := parseEnv(got)
  867. if got := m["USER"]; got == "" || got != u.Username {
  868. t.Errorf("USER = %q; want %q", got, u.Username)
  869. }
  870. if got := m["HOME"]; got == "" || got != u.HomeDir {
  871. t.Errorf("HOME = %q; want %q", got, u.HomeDir)
  872. }
  873. if got := m["PWD"]; got == "" || got != u.HomeDir {
  874. t.Errorf("PWD = %q; want %q", got, u.HomeDir)
  875. }
  876. if got := m["SHELL"]; got == "" {
  877. t.Errorf("no SHELL")
  878. }
  879. if got, want := m["LANG"], "foo"; got != want {
  880. t.Errorf("LANG = %q; want %q", got, want)
  881. }
  882. if got := m["LOCAL_ENV"]; got != "" {
  883. t.Errorf("LOCAL_ENV leaked over ssh: %v", got)
  884. }
  885. t.Logf("got: %+v", m)
  886. })
  887. t.Run("stdout_stderr", func(t *testing.T) {
  888. cmd := execSSH("sh", "-c", "echo foo; echo bar >&2")
  889. var outBuf, errBuf bytes.Buffer
  890. cmd.Stdout = &outBuf
  891. cmd.Stderr = &errBuf
  892. if err := cmd.Run(); err != nil {
  893. t.Fatal(err)
  894. }
  895. t.Logf("Got: %q and %q", outBuf.Bytes(), errBuf.Bytes())
  896. // TODO: figure out why these aren't right. should be
  897. // "foo\n" and "bar\n", not "\n" and "bar\n".
  898. })
  899. t.Run("large_file", func(t *testing.T) {
  900. const wantSize = 1e6
  901. var outBuf bytes.Buffer
  902. cmd := execSSH("head", "-c", strconv.Itoa(wantSize), "/dev/zero")
  903. cmd.Stdout = &outBuf
  904. if err := cmd.Run(); err != nil {
  905. t.Fatal(err)
  906. }
  907. if gotSize := outBuf.Len(); gotSize != wantSize {
  908. t.Fatalf("got %d, want %d", gotSize, int(wantSize))
  909. }
  910. })
  911. t.Run("stdin", func(t *testing.T) {
  912. if cibuild.On() {
  913. t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
  914. }
  915. cmd := execSSH("cat")
  916. var outBuf bytes.Buffer
  917. cmd.Stdout = &outBuf
  918. const str = "foo\nbar\n"
  919. cmd.Stdin = strings.NewReader(str)
  920. if err := cmd.Run(); err != nil {
  921. t.Fatal(err)
  922. }
  923. if got := outBuf.String(); got != str {
  924. t.Errorf("got %q; want %q", got, str)
  925. }
  926. })
  927. }
  928. func parseEnv(out []byte) map[string]string {
  929. e := map[string]string{}
  930. lineread.Reader(bytes.NewReader(out), func(line []byte) error {
  931. i := bytes.IndexByte(line, '=')
  932. if i == -1 {
  933. return nil
  934. }
  935. e[string(line[:i])] = string(line[i+1:])
  936. return nil
  937. })
  938. return e
  939. }
  940. func TestPublicKeyFetching(t *testing.T) {
  941. var reqsTotal, reqsIfNoneMatchHit, reqsIfNoneMatchMiss int32
  942. ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  943. atomic.AddInt32((&reqsTotal), 1)
  944. etag := fmt.Sprintf("W/%q", sha256.Sum256([]byte(r.URL.Path)))
  945. w.Header().Set("Etag", etag)
  946. if v := r.Header.Get("If-None-Match"); v != "" {
  947. if v == etag {
  948. atomic.AddInt32(&reqsIfNoneMatchHit, 1)
  949. w.WriteHeader(304)
  950. return
  951. }
  952. atomic.AddInt32(&reqsIfNoneMatchMiss, 1)
  953. }
  954. io.WriteString(w, "foo\nbar\n"+string(r.URL.Path)+"\n")
  955. }))
  956. ts.StartTLS()
  957. defer ts.Close()
  958. keys := ts.URL
  959. clock := &tstest.Clock{}
  960. srv := &server{
  961. pubKeyHTTPClient: ts.Client(),
  962. timeNow: clock.Now,
  963. }
  964. for range 2 {
  965. got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
  966. if err != nil {
  967. t.Fatal(err)
  968. }
  969. if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
  970. t.Errorf("got %q; want %q", got, want)
  971. }
  972. }
  973. if got, want := atomic.LoadInt32(&reqsTotal), int32(1); got != want {
  974. t.Errorf("got %d requests; want %d", got, want)
  975. }
  976. if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(0); got != want {
  977. t.Errorf("got %d etag hits; want %d", got, want)
  978. }
  979. clock.Advance(5 * time.Minute)
  980. got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
  981. if err != nil {
  982. t.Fatal(err)
  983. }
  984. if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
  985. t.Errorf("got %q; want %q", got, want)
  986. }
  987. if got, want := atomic.LoadInt32(&reqsTotal), int32(2); got != want {
  988. t.Errorf("got %d requests; want %d", got, want)
  989. }
  990. if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(1); got != want {
  991. t.Errorf("got %d etag hits; want %d", got, want)
  992. }
  993. if got, want := atomic.LoadInt32(&reqsIfNoneMatchMiss), int32(0); got != want {
  994. t.Errorf("got %d etag misses; want %d", got, want)
  995. }
  996. }
  997. func TestExpandPublicKeyURL(t *testing.T) {
  998. c := &conn{
  999. info: &sshConnInfo{
  1000. uprof: tailcfg.UserProfile{
  1001. LoginName: "[email protected]",
  1002. },
  1003. },
  1004. }
  1005. if got, want := c.expandPublicKeyURL("foo"), "foo"; got != want {
  1006. t.Errorf("basic: got %q; want %q", got, want)
  1007. }
  1008. if got, want := c.expandPublicKeyURL("https://example.com/$LOGINNAME_LOCALPART.keys"), "https://example.com/bar.keys"; got != want {
  1009. t.Errorf("localpart: got %q; want %q", got, want)
  1010. }
  1011. if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/[email protected]"; got != want {
  1012. t.Errorf("email: got %q; want %q", got, want)
  1013. }
  1014. c.info = new(sshConnInfo)
  1015. if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/keys?email="; got != want {
  1016. t.Errorf("on empty: got %q; want %q", got, want)
  1017. }
  1018. }
  1019. func TestAcceptEnvPair(t *testing.T) {
  1020. tests := []struct {
  1021. in string
  1022. want bool
  1023. }{
  1024. {"TERM=x", true},
  1025. {"term=x", false},
  1026. {"TERM", false},
  1027. {"LC_FOO=x", true},
  1028. {"LD_PRELOAD=naah", false},
  1029. {"TERM=screen-256color", true},
  1030. }
  1031. for _, tt := range tests {
  1032. if got := acceptEnvPair(tt.in); got != tt.want {
  1033. t.Errorf("for %q, got %v; want %v", tt.in, got, tt.want)
  1034. }
  1035. }
  1036. }
  1037. func TestPathFromPAMEnvLine(t *testing.T) {
  1038. u := &user.User{Username: "foo", HomeDir: "/Homes/Foo"}
  1039. tests := []struct {
  1040. line string
  1041. u *user.User
  1042. want string
  1043. }{
  1044. {"", u, ""},
  1045. {`PATH DEFAULT="/run/wrappers/bin:@{HOME}/.nix-profile/bin:/etc/profiles/per-user/@{PAM_USER}/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"`,
  1046. u, "/run/wrappers/bin:/Homes/Foo/.nix-profile/bin:/etc/profiles/per-user/foo/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"},
  1047. {`PATH DEFAULT="@{SOMETHING_ELSE}:nope:@{HOME}"`,
  1048. u, ""},
  1049. }
  1050. for i, tt := range tests {
  1051. got := pathFromPAMEnvLine([]byte(tt.line), tt.u)
  1052. if got != tt.want {
  1053. t.Errorf("%d. got %q; want %q", i, got, tt.want)
  1054. }
  1055. }
  1056. }
  1057. func TestExpandDefaultPathTmpl(t *testing.T) {
  1058. u := &user.User{Username: "foo", HomeDir: "/Homes/Foo"}
  1059. tests := []struct {
  1060. t string
  1061. u *user.User
  1062. want string
  1063. }{
  1064. {"", u, ""},
  1065. {`/run/wrappers/bin:@{HOME}/.nix-profile/bin:/etc/profiles/per-user/@{PAM_USER}/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin`,
  1066. u, "/run/wrappers/bin:/Homes/Foo/.nix-profile/bin:/etc/profiles/per-user/foo/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"},
  1067. {`@{SOMETHING_ELSE}:nope:@{HOME}`, u, ""},
  1068. }
  1069. for i, tt := range tests {
  1070. got := expandDefaultPathTmpl(tt.t, tt.u)
  1071. if got != tt.want {
  1072. t.Errorf("%d. got %q; want %q", i, got, tt.want)
  1073. }
  1074. }
  1075. }
  1076. func TestPathFromPAMEnvLineOnNixOS(t *testing.T) {
  1077. if runtime.GOOS != "linux" {
  1078. t.Skip("skipping on non-linux")
  1079. }
  1080. if distro.Get() != distro.NixOS {
  1081. t.Skip("skipping on non-NixOS")
  1082. }
  1083. u, err := user.Current()
  1084. if err != nil {
  1085. t.Fatal(err)
  1086. }
  1087. got := defaultPathForUserOnNixOS(u)
  1088. if got == "" {
  1089. x, err := os.ReadFile("/etc/pam/environment")
  1090. t.Fatalf("no result. file was: err=%v, contents=%s", err, x)
  1091. }
  1092. t.Logf("success; got=%q", got)
  1093. }
  1094. func TestStdOsUserUserAssumptions(t *testing.T) {
  1095. v := reflect.TypeFor[user.User]()
  1096. if got, want := v.NumField(), 5; got != want {
  1097. t.Errorf("os/user.User has %v fields; this package assumes %v", got, want)
  1098. }
  1099. }