tailssh_test.go 30 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build linux || darwin
  4. package tailssh
  5. import (
  6. "bytes"
  7. "context"
  8. "crypto/ed25519"
  9. "crypto/rand"
  10. "crypto/sha256"
  11. "encoding/json"
  12. "errors"
  13. "fmt"
  14. "io"
  15. "io/ioutil"
  16. "net"
  17. "net/http"
  18. "net/http/httptest"
  19. "net/netip"
  20. "os"
  21. "os/exec"
  22. "os/user"
  23. "reflect"
  24. "runtime"
  25. "strconv"
  26. "strings"
  27. "sync"
  28. "sync/atomic"
  29. "testing"
  30. "time"
  31. gossh "github.com/tailscale/golang-x-crypto/ssh"
  32. "tailscale.com/ipn/ipnlocal"
  33. "tailscale.com/ipn/store/mem"
  34. "tailscale.com/net/memnet"
  35. "tailscale.com/net/tsdial"
  36. "tailscale.com/tailcfg"
  37. "tailscale.com/tempfork/gliderlabs/ssh"
  38. "tailscale.com/tsd"
  39. "tailscale.com/tstest"
  40. "tailscale.com/types/key"
  41. "tailscale.com/types/logger"
  42. "tailscale.com/types/logid"
  43. "tailscale.com/types/netmap"
  44. "tailscale.com/util/cibuild"
  45. "tailscale.com/util/lineread"
  46. "tailscale.com/util/must"
  47. "tailscale.com/version/distro"
  48. "tailscale.com/wgengine"
  49. )
  50. func TestMatchRule(t *testing.T) {
  51. someAction := new(tailcfg.SSHAction)
  52. tests := []struct {
  53. name string
  54. rule *tailcfg.SSHRule
  55. ci *sshConnInfo
  56. wantErr error
  57. wantUser string
  58. }{
  59. {
  60. name: "invalid-conn",
  61. rule: &tailcfg.SSHRule{
  62. Action: someAction,
  63. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  64. SSHUsers: map[string]string{
  65. "*": "ubuntu",
  66. },
  67. },
  68. wantErr: errInvalidConn,
  69. },
  70. {
  71. name: "nil-rule",
  72. ci: &sshConnInfo{},
  73. rule: nil,
  74. wantErr: errNilRule,
  75. },
  76. {
  77. name: "nil-action",
  78. ci: &sshConnInfo{},
  79. rule: &tailcfg.SSHRule{},
  80. wantErr: errNilAction,
  81. },
  82. {
  83. name: "expired",
  84. rule: &tailcfg.SSHRule{
  85. Action: someAction,
  86. RuleExpires: timePtr(time.Unix(100, 0)),
  87. },
  88. ci: &sshConnInfo{},
  89. wantErr: errRuleExpired,
  90. },
  91. {
  92. name: "no-principal",
  93. rule: &tailcfg.SSHRule{
  94. Action: someAction,
  95. SSHUsers: map[string]string{
  96. "*": "ubuntu",
  97. }},
  98. ci: &sshConnInfo{},
  99. wantErr: errPrincipalMatch,
  100. },
  101. {
  102. name: "no-user-match",
  103. rule: &tailcfg.SSHRule{
  104. Action: someAction,
  105. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  106. },
  107. ci: &sshConnInfo{sshUser: "alice"},
  108. wantErr: errUserMatch,
  109. },
  110. {
  111. name: "ok-wildcard",
  112. rule: &tailcfg.SSHRule{
  113. Action: someAction,
  114. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  115. SSHUsers: map[string]string{
  116. "*": "ubuntu",
  117. },
  118. },
  119. ci: &sshConnInfo{sshUser: "alice"},
  120. wantUser: "ubuntu",
  121. },
  122. {
  123. name: "ok-wildcard-and-nil-principal",
  124. rule: &tailcfg.SSHRule{
  125. Action: someAction,
  126. Principals: []*tailcfg.SSHPrincipal{
  127. nil, // don't crash on this
  128. {Any: true},
  129. },
  130. SSHUsers: map[string]string{
  131. "*": "ubuntu",
  132. },
  133. },
  134. ci: &sshConnInfo{sshUser: "alice"},
  135. wantUser: "ubuntu",
  136. },
  137. {
  138. name: "ok-exact",
  139. rule: &tailcfg.SSHRule{
  140. Action: someAction,
  141. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  142. SSHUsers: map[string]string{
  143. "*": "ubuntu",
  144. "alice": "thealice",
  145. },
  146. },
  147. ci: &sshConnInfo{sshUser: "alice"},
  148. wantUser: "thealice",
  149. },
  150. {
  151. name: "no-users-for-reject",
  152. rule: &tailcfg.SSHRule{
  153. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  154. Action: &tailcfg.SSHAction{Reject: true},
  155. },
  156. ci: &sshConnInfo{sshUser: "alice"},
  157. },
  158. {
  159. name: "match-principal-node-ip",
  160. rule: &tailcfg.SSHRule{
  161. Action: someAction,
  162. Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}},
  163. SSHUsers: map[string]string{"*": "ubuntu"},
  164. },
  165. ci: &sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")},
  166. wantUser: "ubuntu",
  167. },
  168. {
  169. name: "match-principal-node-id",
  170. rule: &tailcfg.SSHRule{
  171. Action: someAction,
  172. Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}},
  173. SSHUsers: map[string]string{"*": "ubuntu"},
  174. },
  175. ci: &sshConnInfo{node: &tailcfg.Node{StableID: "some-node-ID"}},
  176. wantUser: "ubuntu",
  177. },
  178. {
  179. name: "match-principal-userlogin",
  180. rule: &tailcfg.SSHRule{
  181. Action: someAction,
  182. Principals: []*tailcfg.SSHPrincipal{{UserLogin: "[email protected]"}},
  183. SSHUsers: map[string]string{"*": "ubuntu"},
  184. },
  185. ci: &sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "[email protected]"}},
  186. wantUser: "ubuntu",
  187. },
  188. {
  189. name: "ssh-user-equal",
  190. rule: &tailcfg.SSHRule{
  191. Action: someAction,
  192. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  193. SSHUsers: map[string]string{
  194. "*": "=",
  195. },
  196. },
  197. ci: &sshConnInfo{sshUser: "alice"},
  198. wantUser: "alice",
  199. },
  200. }
  201. for _, tt := range tests {
  202. t.Run(tt.name, func(t *testing.T) {
  203. c := &conn{
  204. info: tt.ci,
  205. srv: &server{logf: t.Logf},
  206. }
  207. got, gotUser, err := c.matchRule(tt.rule, nil)
  208. if err != tt.wantErr {
  209. t.Errorf("err = %v; want %v", err, tt.wantErr)
  210. }
  211. if gotUser != tt.wantUser {
  212. t.Errorf("user = %q; want %q", gotUser, tt.wantUser)
  213. }
  214. if err == nil && got == nil {
  215. t.Errorf("expected non-nil action on success")
  216. }
  217. })
  218. }
  219. }
  220. func timePtr(t time.Time) *time.Time { return &t }
  221. // localState implements ipnLocalBackend for testing.
  222. type localState struct {
  223. sshEnabled bool
  224. matchingRule *tailcfg.SSHRule
  225. // serverActions is a map of the action name to the action.
  226. // It is served for paths like https://unused/ssh-action/<action-name>.
  227. // The action name is the last part of the action URL.
  228. serverActions map[string]*tailcfg.SSHAction
  229. }
  230. var (
  231. currentUser = os.Getenv("USER") // Use the current user for the test.
  232. testSigner gossh.Signer
  233. testSignerOnce sync.Once
  234. )
  235. func (ts *localState) Dialer() *tsdial.Dialer {
  236. return &tsdial.Dialer{}
  237. }
  238. func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) {
  239. testSignerOnce.Do(func() {
  240. _, priv, err := ed25519.GenerateKey(rand.Reader)
  241. if err != nil {
  242. panic(err)
  243. }
  244. s, err := gossh.NewSignerFromSigner(priv)
  245. if err != nil {
  246. panic(err)
  247. }
  248. testSigner = s
  249. })
  250. return []gossh.Signer{testSigner}, nil
  251. }
  252. func (ts *localState) ShouldRunSSH() bool {
  253. return ts.sshEnabled
  254. }
  255. func (ts *localState) NetMap() *netmap.NetworkMap {
  256. var policy *tailcfg.SSHPolicy
  257. if ts.matchingRule != nil {
  258. policy = &tailcfg.SSHPolicy{
  259. Rules: []*tailcfg.SSHRule{
  260. ts.matchingRule,
  261. },
  262. }
  263. }
  264. return &netmap.NetworkMap{
  265. SelfNode: &tailcfg.Node{
  266. ID: 1,
  267. },
  268. SSHPolicy: policy,
  269. }
  270. }
  271. func (ts *localState) WhoIs(ipp netip.AddrPort) (n *tailcfg.Node, u tailcfg.UserProfile, ok bool) {
  272. return &tailcfg.Node{
  273. ID: 2,
  274. StableID: "peer-id",
  275. }, tailcfg.UserProfile{
  276. LoginName: "peer",
  277. }, true
  278. }
  279. func (ts *localState) DoNoiseRequest(req *http.Request) (*http.Response, error) {
  280. rec := httptest.NewRecorder()
  281. k, ok := strings.CutPrefix(req.URL.Path, "/ssh-action/")
  282. if !ok {
  283. rec.WriteHeader(http.StatusNotFound)
  284. }
  285. a, ok := ts.serverActions[k]
  286. if !ok {
  287. rec.WriteHeader(http.StatusNotFound)
  288. return rec.Result(), nil
  289. }
  290. rec.WriteHeader(http.StatusOK)
  291. if err := json.NewEncoder(rec).Encode(a); err != nil {
  292. return nil, err
  293. }
  294. return rec.Result(), nil
  295. }
  296. func (ts *localState) TailscaleVarRoot() string {
  297. return ""
  298. }
  299. func (ts *localState) NodeKey() key.NodePublic {
  300. return key.NewNode().Public()
  301. }
  302. func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule {
  303. return &tailcfg.SSHRule{
  304. SSHUsers: map[string]string{
  305. "*": currentUser,
  306. },
  307. Action: action,
  308. Principals: []*tailcfg.SSHPrincipal{
  309. {
  310. Any: true,
  311. },
  312. },
  313. }
  314. }
  315. func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
  316. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  317. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  318. }
  319. var handler http.HandlerFunc
  320. recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  321. handler(w, r)
  322. }))
  323. defer recordingServer.Close()
  324. s := &server{
  325. logf: t.Logf,
  326. lb: &localState{
  327. sshEnabled: true,
  328. matchingRule: newSSHRule(
  329. &tailcfg.SSHAction{
  330. Accept: true,
  331. Recorders: []netip.AddrPort{
  332. netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
  333. },
  334. OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
  335. RejectSessionWithMessage: "session rejected",
  336. TerminateSessionWithMessage: "session terminated",
  337. },
  338. },
  339. ),
  340. },
  341. }
  342. defer s.Shutdown()
  343. const sshUser = "alice"
  344. cfg := &gossh.ClientConfig{
  345. User: sshUser,
  346. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  347. }
  348. tests := []struct {
  349. name string
  350. handler func(w http.ResponseWriter, r *http.Request)
  351. sshCommand string
  352. wantClientOutput string
  353. clientOutputMustNotContain []string
  354. }{
  355. {
  356. name: "upload-denied",
  357. handler: func(w http.ResponseWriter, r *http.Request) {
  358. w.WriteHeader(http.StatusForbidden)
  359. },
  360. sshCommand: "echo hello",
  361. wantClientOutput: "session rejected\r\n",
  362. clientOutputMustNotContain: []string{"hello"},
  363. },
  364. {
  365. name: "upload-fails-after-starting",
  366. handler: func(w http.ResponseWriter, r *http.Request) {
  367. r.Body.Read(make([]byte, 1))
  368. time.Sleep(100 * time.Millisecond)
  369. w.WriteHeader(http.StatusInternalServerError)
  370. },
  371. sshCommand: "echo hello && sleep 1 && echo world",
  372. wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n",
  373. clientOutputMustNotContain: []string{"world"},
  374. },
  375. }
  376. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  377. for _, tt := range tests {
  378. t.Run(tt.name, func(t *testing.T) {
  379. tstest.Replace(t, &handler, tt.handler)
  380. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  381. var wg sync.WaitGroup
  382. wg.Add(1)
  383. go func() {
  384. defer wg.Done()
  385. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  386. if err != nil {
  387. t.Errorf("client: %v", err)
  388. return
  389. }
  390. client := gossh.NewClient(c, chans, reqs)
  391. defer client.Close()
  392. session, err := client.NewSession()
  393. if err != nil {
  394. t.Errorf("client: %v", err)
  395. return
  396. }
  397. defer session.Close()
  398. t.Logf("client established session")
  399. got, err := session.CombinedOutput(tt.sshCommand)
  400. if err != nil {
  401. t.Logf("client got: %q: %v", got, err)
  402. } else {
  403. t.Errorf("client did not get kicked out: %q", got)
  404. }
  405. gotStr := string(got)
  406. if !strings.HasSuffix(gotStr, tt.wantClientOutput) {
  407. t.Errorf("client got %q, want %q", got, tt.wantClientOutput)
  408. }
  409. for _, x := range tt.clientOutputMustNotContain {
  410. if strings.Contains(gotStr, x) {
  411. t.Errorf("client output must not contain %q", x)
  412. }
  413. }
  414. }()
  415. if err := s.HandleSSHConn(dc); err != nil {
  416. t.Errorf("unexpected error: %v", err)
  417. }
  418. wg.Wait()
  419. })
  420. }
  421. }
  422. func TestMultipleRecorders(t *testing.T) {
  423. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  424. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  425. }
  426. done := make(chan struct{})
  427. recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  428. defer close(done)
  429. io.ReadAll(r.Body)
  430. w.WriteHeader(http.StatusOK)
  431. }))
  432. defer recordingServer.Close()
  433. badRecorder, err := net.Listen("tcp", ":0")
  434. if err != nil {
  435. t.Fatal(err)
  436. }
  437. badRecorderAddr := badRecorder.Addr().String()
  438. badRecorder.Close()
  439. badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  440. w.WriteHeader(500)
  441. }))
  442. defer badRecordingServer500.Close()
  443. badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  444. w.WriteHeader(200)
  445. }))
  446. defer badRecordingServer200.Close()
  447. s := &server{
  448. logf: t.Logf,
  449. lb: &localState{
  450. sshEnabled: true,
  451. matchingRule: newSSHRule(
  452. &tailcfg.SSHAction{
  453. Accept: true,
  454. Recorders: []netip.AddrPort{
  455. netip.MustParseAddrPort(badRecorderAddr),
  456. netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()),
  457. netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()),
  458. netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
  459. },
  460. OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
  461. RejectSessionWithMessage: "session rejected",
  462. TerminateSessionWithMessage: "session terminated",
  463. },
  464. },
  465. ),
  466. },
  467. }
  468. defer s.Shutdown()
  469. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  470. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  471. const sshUser = "alice"
  472. cfg := &gossh.ClientConfig{
  473. User: sshUser,
  474. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  475. }
  476. var wg sync.WaitGroup
  477. wg.Add(1)
  478. go func() {
  479. defer wg.Done()
  480. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  481. if err != nil {
  482. t.Errorf("client: %v", err)
  483. return
  484. }
  485. client := gossh.NewClient(c, chans, reqs)
  486. defer client.Close()
  487. session, err := client.NewSession()
  488. if err != nil {
  489. t.Errorf("client: %v", err)
  490. return
  491. }
  492. defer session.Close()
  493. t.Logf("client established session")
  494. out, err := session.CombinedOutput("echo Ran echo!")
  495. if err != nil {
  496. t.Errorf("client: %v", err)
  497. }
  498. if string(out) != "Ran echo!\n" {
  499. t.Errorf("client: unexpected output: %q", out)
  500. }
  501. }()
  502. if err := s.HandleSSHConn(dc); err != nil {
  503. t.Errorf("unexpected error: %v", err)
  504. }
  505. wg.Wait()
  506. select {
  507. case <-done:
  508. case <-time.After(1 * time.Second):
  509. t.Fatal("timed out waiting for recording")
  510. }
  511. }
  512. // TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
  513. // when the client is not interactive (i.e. no PTY).
  514. // It starts a local SSH server and a recording server. The recording server
  515. // records the SSH session and returns it to the test.
  516. // The test then verifies that the recording has a valid CastHeader, it does not
  517. // validate the contents of the recording.
  518. func TestSSHRecordingNonInteractive(t *testing.T) {
  519. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  520. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  521. }
  522. var recording []byte
  523. ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  524. recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  525. defer cancel()
  526. var err error
  527. recording, err = ioutil.ReadAll(r.Body)
  528. if err != nil {
  529. t.Error(err)
  530. return
  531. }
  532. }))
  533. defer recordingServer.Close()
  534. s := &server{
  535. logf: logger.Discard,
  536. lb: &localState{
  537. sshEnabled: true,
  538. matchingRule: newSSHRule(
  539. &tailcfg.SSHAction{
  540. Accept: true,
  541. Recorders: []netip.AddrPort{
  542. must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
  543. },
  544. OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
  545. RejectSessionWithMessage: "session rejected",
  546. TerminateSessionWithMessage: "session terminated",
  547. },
  548. },
  549. ),
  550. },
  551. }
  552. defer s.Shutdown()
  553. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  554. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  555. const sshUser = "alice"
  556. cfg := &gossh.ClientConfig{
  557. User: sshUser,
  558. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  559. }
  560. var wg sync.WaitGroup
  561. wg.Add(1)
  562. go func() {
  563. defer wg.Done()
  564. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  565. if err != nil {
  566. t.Errorf("client: %v", err)
  567. return
  568. }
  569. client := gossh.NewClient(c, chans, reqs)
  570. defer client.Close()
  571. session, err := client.NewSession()
  572. if err != nil {
  573. t.Errorf("client: %v", err)
  574. return
  575. }
  576. defer session.Close()
  577. t.Logf("client established session")
  578. _, err = session.CombinedOutput("echo Ran echo!")
  579. if err != nil {
  580. t.Errorf("client: %v", err)
  581. }
  582. }()
  583. if err := s.HandleSSHConn(dc); err != nil {
  584. t.Errorf("unexpected error: %v", err)
  585. }
  586. wg.Wait()
  587. <-ctx.Done() // wait for recording to finish
  588. var ch CastHeader
  589. if err := json.NewDecoder(bytes.NewReader(recording)).Decode(&ch); err != nil {
  590. t.Fatal(err)
  591. }
  592. if ch.SSHUser != sshUser {
  593. t.Errorf("SSHUser = %q; want %q", ch.SSHUser, sshUser)
  594. }
  595. if ch.Command != "echo Ran echo!" {
  596. t.Errorf("Command = %q; want %q", ch.Command, "echo Ran echo!")
  597. }
  598. }
  599. func TestSSHAuthFlow(t *testing.T) {
  600. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  601. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  602. }
  603. acceptRule := newSSHRule(&tailcfg.SSHAction{
  604. Accept: true,
  605. Message: "Welcome to Tailscale SSH!",
  606. })
  607. rejectRule := newSSHRule(&tailcfg.SSHAction{
  608. Reject: true,
  609. Message: "Go Away!",
  610. })
  611. tests := []struct {
  612. name string
  613. sshUser string // defaults to alice
  614. state *localState
  615. wantBanners []string
  616. usesPassword bool
  617. authErr bool
  618. }{
  619. {
  620. name: "no-policy",
  621. state: &localState{
  622. sshEnabled: true,
  623. },
  624. authErr: true,
  625. },
  626. {
  627. name: "accept",
  628. state: &localState{
  629. sshEnabled: true,
  630. matchingRule: acceptRule,
  631. },
  632. wantBanners: []string{"Welcome to Tailscale SSH!"},
  633. },
  634. {
  635. name: "reject",
  636. state: &localState{
  637. sshEnabled: true,
  638. matchingRule: rejectRule,
  639. },
  640. wantBanners: []string{"Go Away!"},
  641. authErr: true,
  642. },
  643. {
  644. name: "simple-check",
  645. state: &localState{
  646. sshEnabled: true,
  647. matchingRule: newSSHRule(&tailcfg.SSHAction{
  648. HoldAndDelegate: "https://unused/ssh-action/accept",
  649. }),
  650. serverActions: map[string]*tailcfg.SSHAction{
  651. "accept": acceptRule.Action,
  652. },
  653. },
  654. wantBanners: []string{"Welcome to Tailscale SSH!"},
  655. },
  656. {
  657. name: "multi-check",
  658. state: &localState{
  659. sshEnabled: true,
  660. matchingRule: newSSHRule(&tailcfg.SSHAction{
  661. Message: "First",
  662. HoldAndDelegate: "https://unused/ssh-action/check1",
  663. }),
  664. serverActions: map[string]*tailcfg.SSHAction{
  665. "check1": {
  666. Message: "url-here",
  667. HoldAndDelegate: "https://unused/ssh-action/check2",
  668. },
  669. "check2": acceptRule.Action,
  670. },
  671. },
  672. wantBanners: []string{"First", "url-here", "Welcome to Tailscale SSH!"},
  673. },
  674. {
  675. name: "check-reject",
  676. state: &localState{
  677. sshEnabled: true,
  678. matchingRule: newSSHRule(&tailcfg.SSHAction{
  679. Message: "First",
  680. HoldAndDelegate: "https://unused/ssh-action/reject",
  681. }),
  682. serverActions: map[string]*tailcfg.SSHAction{
  683. "reject": rejectRule.Action,
  684. },
  685. },
  686. wantBanners: []string{"First", "Go Away!"},
  687. authErr: true,
  688. },
  689. {
  690. name: "force-password-auth",
  691. sshUser: "alice+password",
  692. state: &localState{
  693. sshEnabled: true,
  694. matchingRule: acceptRule,
  695. },
  696. usesPassword: true,
  697. wantBanners: []string{"Welcome to Tailscale SSH!"},
  698. },
  699. }
  700. s := &server{
  701. logf: logger.Discard,
  702. }
  703. defer s.Shutdown()
  704. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  705. for _, tc := range tests {
  706. t.Run(tc.name, func(t *testing.T) {
  707. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  708. s.lb = tc.state
  709. sshUser := "alice"
  710. if tc.sshUser != "" {
  711. sshUser = tc.sshUser
  712. }
  713. var passwordUsed atomic.Bool
  714. cfg := &gossh.ClientConfig{
  715. User: sshUser,
  716. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  717. Auth: []gossh.AuthMethod{
  718. gossh.PasswordCallback(func() (secret string, err error) {
  719. if !tc.usesPassword {
  720. t.Error("unexpected use of PasswordCallback")
  721. return "", errors.New("unexpected use of PasswordCallback")
  722. }
  723. passwordUsed.Store(true)
  724. return "any-pass", nil
  725. }),
  726. },
  727. BannerCallback: func(message string) error {
  728. if len(tc.wantBanners) == 0 {
  729. t.Errorf("unexpected banner: %q", message)
  730. } else if message != tc.wantBanners[0] {
  731. t.Errorf("banner = %q; want %q", message, tc.wantBanners[0])
  732. } else {
  733. t.Logf("banner = %q", message)
  734. tc.wantBanners = tc.wantBanners[1:]
  735. }
  736. return nil
  737. },
  738. }
  739. var wg sync.WaitGroup
  740. wg.Add(1)
  741. go func() {
  742. defer wg.Done()
  743. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  744. if err != nil {
  745. if !tc.authErr {
  746. t.Errorf("client: %v", err)
  747. }
  748. return
  749. } else if tc.authErr {
  750. c.Close()
  751. t.Errorf("client: expected error, got nil")
  752. return
  753. }
  754. client := gossh.NewClient(c, chans, reqs)
  755. defer client.Close()
  756. session, err := client.NewSession()
  757. if err != nil {
  758. t.Errorf("client: %v", err)
  759. return
  760. }
  761. defer session.Close()
  762. _, err = session.CombinedOutput("echo Ran echo!")
  763. if err != nil {
  764. t.Errorf("client: %v", err)
  765. }
  766. }()
  767. if err := s.HandleSSHConn(dc); err != nil {
  768. t.Errorf("unexpected error: %v", err)
  769. }
  770. wg.Wait()
  771. if len(tc.wantBanners) > 0 {
  772. t.Errorf("missing banners: %v", tc.wantBanners)
  773. }
  774. })
  775. }
  776. }
  777. func TestSSH(t *testing.T) {
  778. var logf logger.Logf = t.Logf
  779. sys := &tsd.System{}
  780. eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set)
  781. if err != nil {
  782. t.Fatal(err)
  783. }
  784. sys.Set(eng)
  785. sys.Set(new(mem.Store))
  786. lb, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{}, sys, 0)
  787. if err != nil {
  788. t.Fatal(err)
  789. }
  790. defer lb.Shutdown()
  791. dir := t.TempDir()
  792. lb.SetVarRoot(dir)
  793. srv := &server{
  794. lb: lb,
  795. logf: logf,
  796. }
  797. sc, err := srv.newConn()
  798. if err != nil {
  799. t.Fatal(err)
  800. }
  801. // Remove the auth checks for the test
  802. sc.insecureSkipTailscaleAuth = true
  803. u, err := user.Current()
  804. if err != nil {
  805. t.Fatal(err)
  806. }
  807. um, err := userLookup(u.Username)
  808. if err != nil {
  809. t.Fatal(err)
  810. }
  811. sc.localUser = um
  812. sc.info = &sshConnInfo{
  813. sshUser: "test",
  814. src: netip.MustParseAddrPort("1.2.3.4:32342"),
  815. dst: netip.MustParseAddrPort("1.2.3.5:22"),
  816. node: &tailcfg.Node{},
  817. uprof: tailcfg.UserProfile{},
  818. }
  819. sc.action0 = &tailcfg.SSHAction{Accept: true}
  820. sc.finalAction = sc.action0
  821. sc.Handler = func(s ssh.Session) {
  822. sc.newSSHSession(s).run()
  823. }
  824. ln, err := net.Listen("tcp4", "127.0.0.1:0")
  825. if err != nil {
  826. t.Fatal(err)
  827. }
  828. defer ln.Close()
  829. port := ln.Addr().(*net.TCPAddr).Port
  830. go func() {
  831. for {
  832. c, err := ln.Accept()
  833. if err != nil {
  834. if !errors.Is(err, net.ErrClosed) {
  835. t.Errorf("Accept: %v", err)
  836. }
  837. return
  838. }
  839. go sc.HandleConn(c)
  840. }
  841. }()
  842. execSSH := func(args ...string) *exec.Cmd {
  843. cmd := exec.Command("ssh",
  844. "-F",
  845. "none",
  846. "-v",
  847. "-p", fmt.Sprint(port),
  848. "-o", "StrictHostKeyChecking=no",
  849. "[email protected]")
  850. cmd.Args = append(cmd.Args, args...)
  851. return cmd
  852. }
  853. t.Run("env", func(t *testing.T) {
  854. if cibuild.On() {
  855. t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
  856. }
  857. cmd := execSSH("LANG=foo env")
  858. cmd.Env = append(os.Environ(), "LOCAL_ENV=bar")
  859. got, err := cmd.CombinedOutput()
  860. if err != nil {
  861. t.Fatal(err, string(got))
  862. }
  863. m := parseEnv(got)
  864. if got := m["USER"]; got == "" || got != u.Username {
  865. t.Errorf("USER = %q; want %q", got, u.Username)
  866. }
  867. if got := m["HOME"]; got == "" || got != u.HomeDir {
  868. t.Errorf("HOME = %q; want %q", got, u.HomeDir)
  869. }
  870. if got := m["PWD"]; got == "" || got != u.HomeDir {
  871. t.Errorf("PWD = %q; want %q", got, u.HomeDir)
  872. }
  873. if got := m["SHELL"]; got == "" {
  874. t.Errorf("no SHELL")
  875. }
  876. if got, want := m["LANG"], "foo"; got != want {
  877. t.Errorf("LANG = %q; want %q", got, want)
  878. }
  879. if got := m["LOCAL_ENV"]; got != "" {
  880. t.Errorf("LOCAL_ENV leaked over ssh: %v", got)
  881. }
  882. t.Logf("got: %+v", m)
  883. })
  884. t.Run("stdout_stderr", func(t *testing.T) {
  885. cmd := execSSH("sh", "-c", "echo foo; echo bar >&2")
  886. var outBuf, errBuf bytes.Buffer
  887. cmd.Stdout = &outBuf
  888. cmd.Stderr = &errBuf
  889. if err := cmd.Run(); err != nil {
  890. t.Fatal(err)
  891. }
  892. t.Logf("Got: %q and %q", outBuf.Bytes(), errBuf.Bytes())
  893. // TODO: figure out why these aren't right. should be
  894. // "foo\n" and "bar\n", not "\n" and "bar\n".
  895. })
  896. t.Run("large_file", func(t *testing.T) {
  897. const wantSize = 1e6
  898. var outBuf bytes.Buffer
  899. cmd := execSSH("head", "-c", strconv.Itoa(wantSize), "/dev/zero")
  900. cmd.Stdout = &outBuf
  901. if err := cmd.Run(); err != nil {
  902. t.Fatal(err)
  903. }
  904. if gotSize := outBuf.Len(); gotSize != wantSize {
  905. t.Fatalf("got %d, want %d", gotSize, int(wantSize))
  906. }
  907. })
  908. t.Run("stdin", func(t *testing.T) {
  909. if cibuild.On() {
  910. t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
  911. }
  912. cmd := execSSH("cat")
  913. var outBuf bytes.Buffer
  914. cmd.Stdout = &outBuf
  915. const str = "foo\nbar\n"
  916. cmd.Stdin = strings.NewReader(str)
  917. if err := cmd.Run(); err != nil {
  918. t.Fatal(err)
  919. }
  920. if got := outBuf.String(); got != str {
  921. t.Errorf("got %q; want %q", got, str)
  922. }
  923. })
  924. }
  925. func parseEnv(out []byte) map[string]string {
  926. e := map[string]string{}
  927. lineread.Reader(bytes.NewReader(out), func(line []byte) error {
  928. i := bytes.IndexByte(line, '=')
  929. if i == -1 {
  930. return nil
  931. }
  932. e[string(line[:i])] = string(line[i+1:])
  933. return nil
  934. })
  935. return e
  936. }
  937. func TestPublicKeyFetching(t *testing.T) {
  938. var reqsTotal, reqsIfNoneMatchHit, reqsIfNoneMatchMiss int32
  939. ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  940. atomic.AddInt32((&reqsTotal), 1)
  941. etag := fmt.Sprintf("W/%q", sha256.Sum256([]byte(r.URL.Path)))
  942. w.Header().Set("Etag", etag)
  943. if v := r.Header.Get("If-None-Match"); v != "" {
  944. if v == etag {
  945. atomic.AddInt32(&reqsIfNoneMatchHit, 1)
  946. w.WriteHeader(304)
  947. return
  948. }
  949. atomic.AddInt32(&reqsIfNoneMatchMiss, 1)
  950. }
  951. io.WriteString(w, "foo\nbar\n"+string(r.URL.Path)+"\n")
  952. }))
  953. ts.StartTLS()
  954. defer ts.Close()
  955. keys := ts.URL
  956. clock := &tstest.Clock{}
  957. srv := &server{
  958. pubKeyHTTPClient: ts.Client(),
  959. timeNow: clock.Now,
  960. }
  961. for i := 0; i < 2; i++ {
  962. got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
  963. if err != nil {
  964. t.Fatal(err)
  965. }
  966. if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
  967. t.Errorf("got %q; want %q", got, want)
  968. }
  969. }
  970. if got, want := atomic.LoadInt32(&reqsTotal), int32(1); got != want {
  971. t.Errorf("got %d requests; want %d", got, want)
  972. }
  973. if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(0); got != want {
  974. t.Errorf("got %d etag hits; want %d", got, want)
  975. }
  976. clock.Advance(5 * time.Minute)
  977. got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
  978. if err != nil {
  979. t.Fatal(err)
  980. }
  981. if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
  982. t.Errorf("got %q; want %q", got, want)
  983. }
  984. if got, want := atomic.LoadInt32(&reqsTotal), int32(2); got != want {
  985. t.Errorf("got %d requests; want %d", got, want)
  986. }
  987. if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(1); got != want {
  988. t.Errorf("got %d etag hits; want %d", got, want)
  989. }
  990. if got, want := atomic.LoadInt32(&reqsIfNoneMatchMiss), int32(0); got != want {
  991. t.Errorf("got %d etag misses; want %d", got, want)
  992. }
  993. }
  994. func TestExpandPublicKeyURL(t *testing.T) {
  995. c := &conn{
  996. info: &sshConnInfo{
  997. uprof: tailcfg.UserProfile{
  998. LoginName: "[email protected]",
  999. },
  1000. },
  1001. }
  1002. if got, want := c.expandPublicKeyURL("foo"), "foo"; got != want {
  1003. t.Errorf("basic: got %q; want %q", got, want)
  1004. }
  1005. if got, want := c.expandPublicKeyURL("https://example.com/$LOGINNAME_LOCALPART.keys"), "https://example.com/bar.keys"; got != want {
  1006. t.Errorf("localpart: got %q; want %q", got, want)
  1007. }
  1008. if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/[email protected]"; got != want {
  1009. t.Errorf("email: got %q; want %q", got, want)
  1010. }
  1011. c.info = new(sshConnInfo)
  1012. if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/keys?email="; got != want {
  1013. t.Errorf("on empty: got %q; want %q", got, want)
  1014. }
  1015. }
  1016. func TestAcceptEnvPair(t *testing.T) {
  1017. tests := []struct {
  1018. in string
  1019. want bool
  1020. }{
  1021. {"TERM=x", true},
  1022. {"term=x", false},
  1023. {"TERM", false},
  1024. {"LC_FOO=x", true},
  1025. {"LD_PRELOAD=naah", false},
  1026. {"TERM=screen-256color", true},
  1027. }
  1028. for _, tt := range tests {
  1029. if got := acceptEnvPair(tt.in); got != tt.want {
  1030. t.Errorf("for %q, got %v; want %v", tt.in, got, tt.want)
  1031. }
  1032. }
  1033. }
  1034. func TestPathFromPAMEnvLine(t *testing.T) {
  1035. u := &user.User{Username: "foo", HomeDir: "/Homes/Foo"}
  1036. tests := []struct {
  1037. line string
  1038. u *user.User
  1039. want string
  1040. }{
  1041. {"", u, ""},
  1042. {`PATH DEFAULT="/run/wrappers/bin:@{HOME}/.nix-profile/bin:/etc/profiles/per-user/@{PAM_USER}/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"`,
  1043. u, "/run/wrappers/bin:/Homes/Foo/.nix-profile/bin:/etc/profiles/per-user/foo/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"},
  1044. {`PATH DEFAULT="@{SOMETHING_ELSE}:nope:@{HOME}"`,
  1045. u, ""},
  1046. }
  1047. for i, tt := range tests {
  1048. got := pathFromPAMEnvLine([]byte(tt.line), tt.u)
  1049. if got != tt.want {
  1050. t.Errorf("%d. got %q; want %q", i, got, tt.want)
  1051. }
  1052. }
  1053. }
  1054. func TestExpandDefaultPathTmpl(t *testing.T) {
  1055. u := &user.User{Username: "foo", HomeDir: "/Homes/Foo"}
  1056. tests := []struct {
  1057. t string
  1058. u *user.User
  1059. want string
  1060. }{
  1061. {"", u, ""},
  1062. {`/run/wrappers/bin:@{HOME}/.nix-profile/bin:/etc/profiles/per-user/@{PAM_USER}/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin`,
  1063. u, "/run/wrappers/bin:/Homes/Foo/.nix-profile/bin:/etc/profiles/per-user/foo/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"},
  1064. {`@{SOMETHING_ELSE}:nope:@{HOME}`, u, ""},
  1065. }
  1066. for i, tt := range tests {
  1067. got := expandDefaultPathTmpl(tt.t, tt.u)
  1068. if got != tt.want {
  1069. t.Errorf("%d. got %q; want %q", i, got, tt.want)
  1070. }
  1071. }
  1072. }
  1073. func TestPathFromPAMEnvLineOnNixOS(t *testing.T) {
  1074. if runtime.GOOS != "linux" {
  1075. t.Skip("skipping on non-linux")
  1076. }
  1077. if distro.Get() != distro.NixOS {
  1078. t.Skip("skipping on non-NixOS")
  1079. }
  1080. u, err := user.Current()
  1081. if err != nil {
  1082. t.Fatal(err)
  1083. }
  1084. got := defaultPathForUserOnNixOS(u)
  1085. if got == "" {
  1086. x, err := os.ReadFile("/etc/pam/environment")
  1087. t.Fatalf("no result. file was: err=%v, contents=%s", err, x)
  1088. }
  1089. t.Logf("success; got=%q", got)
  1090. }
  1091. func TestStdOsUserUserAssumptions(t *testing.T) {
  1092. v := reflect.TypeOf(user.User{})
  1093. if got, want := v.NumField(), 5; got != want {
  1094. t.Errorf("os/user.User has %v fields; this package assumes %v", got, want)
  1095. }
  1096. }