tailssh_test.go 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build linux || darwin
  4. package tailssh
  5. import (
  6. "bytes"
  7. "context"
  8. "crypto/ed25519"
  9. "crypto/rand"
  10. "crypto/sha256"
  11. "encoding/json"
  12. "errors"
  13. "fmt"
  14. "io"
  15. "net"
  16. "net/http"
  17. "net/http/httptest"
  18. "net/netip"
  19. "os"
  20. "os/exec"
  21. "os/user"
  22. "reflect"
  23. "runtime"
  24. "strconv"
  25. "strings"
  26. "sync"
  27. "sync/atomic"
  28. "testing"
  29. "time"
  30. gossh "github.com/tailscale/golang-x-crypto/ssh"
  31. "tailscale.com/ipn/ipnlocal"
  32. "tailscale.com/ipn/store/mem"
  33. "tailscale.com/net/memnet"
  34. "tailscale.com/net/tsdial"
  35. "tailscale.com/tailcfg"
  36. "tailscale.com/tempfork/gliderlabs/ssh"
  37. "tailscale.com/tsd"
  38. "tailscale.com/tstest"
  39. "tailscale.com/types/key"
  40. "tailscale.com/types/logger"
  41. "tailscale.com/types/logid"
  42. "tailscale.com/types/netmap"
  43. "tailscale.com/types/ptr"
  44. "tailscale.com/util/cibuild"
  45. "tailscale.com/util/lineread"
  46. "tailscale.com/util/must"
  47. "tailscale.com/version/distro"
  48. "tailscale.com/wgengine"
  49. )
  50. func TestMatchRule(t *testing.T) {
  51. someAction := new(tailcfg.SSHAction)
  52. tests := []struct {
  53. name string
  54. rule *tailcfg.SSHRule
  55. ci *sshConnInfo
  56. wantErr error
  57. wantUser string
  58. }{
  59. {
  60. name: "invalid-conn",
  61. rule: &tailcfg.SSHRule{
  62. Action: someAction,
  63. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  64. SSHUsers: map[string]string{
  65. "*": "ubuntu",
  66. },
  67. },
  68. wantErr: errInvalidConn,
  69. },
  70. {
  71. name: "nil-rule",
  72. ci: &sshConnInfo{},
  73. rule: nil,
  74. wantErr: errNilRule,
  75. },
  76. {
  77. name: "nil-action",
  78. ci: &sshConnInfo{},
  79. rule: &tailcfg.SSHRule{},
  80. wantErr: errNilAction,
  81. },
  82. {
  83. name: "expired",
  84. rule: &tailcfg.SSHRule{
  85. Action: someAction,
  86. RuleExpires: ptr.To(time.Unix(100, 0)),
  87. },
  88. ci: &sshConnInfo{},
  89. wantErr: errRuleExpired,
  90. },
  91. {
  92. name: "no-principal",
  93. rule: &tailcfg.SSHRule{
  94. Action: someAction,
  95. SSHUsers: map[string]string{
  96. "*": "ubuntu",
  97. }},
  98. ci: &sshConnInfo{},
  99. wantErr: errPrincipalMatch,
  100. },
  101. {
  102. name: "no-user-match",
  103. rule: &tailcfg.SSHRule{
  104. Action: someAction,
  105. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  106. },
  107. ci: &sshConnInfo{sshUser: "alice"},
  108. wantErr: errUserMatch,
  109. },
  110. {
  111. name: "ok-wildcard",
  112. rule: &tailcfg.SSHRule{
  113. Action: someAction,
  114. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  115. SSHUsers: map[string]string{
  116. "*": "ubuntu",
  117. },
  118. },
  119. ci: &sshConnInfo{sshUser: "alice"},
  120. wantUser: "ubuntu",
  121. },
  122. {
  123. name: "ok-wildcard-and-nil-principal",
  124. rule: &tailcfg.SSHRule{
  125. Action: someAction,
  126. Principals: []*tailcfg.SSHPrincipal{
  127. nil, // don't crash on this
  128. {Any: true},
  129. },
  130. SSHUsers: map[string]string{
  131. "*": "ubuntu",
  132. },
  133. },
  134. ci: &sshConnInfo{sshUser: "alice"},
  135. wantUser: "ubuntu",
  136. },
  137. {
  138. name: "ok-exact",
  139. rule: &tailcfg.SSHRule{
  140. Action: someAction,
  141. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  142. SSHUsers: map[string]string{
  143. "*": "ubuntu",
  144. "alice": "thealice",
  145. },
  146. },
  147. ci: &sshConnInfo{sshUser: "alice"},
  148. wantUser: "thealice",
  149. },
  150. {
  151. name: "no-users-for-reject",
  152. rule: &tailcfg.SSHRule{
  153. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  154. Action: &tailcfg.SSHAction{Reject: true},
  155. },
  156. ci: &sshConnInfo{sshUser: "alice"},
  157. },
  158. {
  159. name: "match-principal-node-ip",
  160. rule: &tailcfg.SSHRule{
  161. Action: someAction,
  162. Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}},
  163. SSHUsers: map[string]string{"*": "ubuntu"},
  164. },
  165. ci: &sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")},
  166. wantUser: "ubuntu",
  167. },
  168. {
  169. name: "match-principal-node-id",
  170. rule: &tailcfg.SSHRule{
  171. Action: someAction,
  172. Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}},
  173. SSHUsers: map[string]string{"*": "ubuntu"},
  174. },
  175. ci: &sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()},
  176. wantUser: "ubuntu",
  177. },
  178. {
  179. name: "match-principal-userlogin",
  180. rule: &tailcfg.SSHRule{
  181. Action: someAction,
  182. Principals: []*tailcfg.SSHPrincipal{{UserLogin: "[email protected]"}},
  183. SSHUsers: map[string]string{"*": "ubuntu"},
  184. },
  185. ci: &sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "[email protected]"}},
  186. wantUser: "ubuntu",
  187. },
  188. {
  189. name: "ssh-user-equal",
  190. rule: &tailcfg.SSHRule{
  191. Action: someAction,
  192. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  193. SSHUsers: map[string]string{
  194. "*": "=",
  195. },
  196. },
  197. ci: &sshConnInfo{sshUser: "alice"},
  198. wantUser: "alice",
  199. },
  200. }
  201. for _, tt := range tests {
  202. t.Run(tt.name, func(t *testing.T) {
  203. c := &conn{
  204. info: tt.ci,
  205. srv: &server{logf: t.Logf},
  206. }
  207. got, gotUser, err := c.matchRule(tt.rule, nil)
  208. if err != tt.wantErr {
  209. t.Errorf("err = %v; want %v", err, tt.wantErr)
  210. }
  211. if gotUser != tt.wantUser {
  212. t.Errorf("user = %q; want %q", gotUser, tt.wantUser)
  213. }
  214. if err == nil && got == nil {
  215. t.Errorf("expected non-nil action on success")
  216. }
  217. })
  218. }
  219. }
// localState implements ipnLocalBackend for testing.
//
// It reports a fixed self node, an SSH policy built from at most one rule,
// a fixed peer identity for any TCP source, and answers SSH check-action
// requests from an in-memory table.
type localState struct {
	// sshEnabled is the value reported by ShouldRunSSH.
	sshEnabled bool
	// matchingRule, when non-nil, is the only rule in the SSH policy of the
	// netmap returned by NetMap. When nil, that netmap has no SSH policy.
	matchingRule *tailcfg.SSHRule
	// serverActions is a map of the action name to the action.
	// It is served for paths like https://unused/ssh-action/<action-name>.
	// The action name is the last part of the action URL.
	serverActions map[string]*tailcfg.SSHAction
}
  229. var (
  230. currentUser = os.Getenv("USER") // Use the current user for the test.
  231. testSigner gossh.Signer
  232. testSignerOnce sync.Once
  233. )
  234. func (ts *localState) Dialer() *tsdial.Dialer {
  235. return &tsdial.Dialer{}
  236. }
  237. func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) {
  238. testSignerOnce.Do(func() {
  239. _, priv, err := ed25519.GenerateKey(rand.Reader)
  240. if err != nil {
  241. panic(err)
  242. }
  243. s, err := gossh.NewSignerFromSigner(priv)
  244. if err != nil {
  245. panic(err)
  246. }
  247. testSigner = s
  248. })
  249. return []gossh.Signer{testSigner}, nil
  250. }
  251. func (ts *localState) ShouldRunSSH() bool {
  252. return ts.sshEnabled
  253. }
  254. func (ts *localState) NetMap() *netmap.NetworkMap {
  255. var policy *tailcfg.SSHPolicy
  256. if ts.matchingRule != nil {
  257. policy = &tailcfg.SSHPolicy{
  258. Rules: []*tailcfg.SSHRule{
  259. ts.matchingRule,
  260. },
  261. }
  262. }
  263. return &netmap.NetworkMap{
  264. SelfNode: (&tailcfg.Node{
  265. ID: 1,
  266. }).View(),
  267. SSHPolicy: policy,
  268. }
  269. }
  270. func (ts *localState) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
  271. if proto != "tcp" {
  272. return tailcfg.NodeView{}, tailcfg.UserProfile{}, false
  273. }
  274. return (&tailcfg.Node{
  275. ID: 2,
  276. StableID: "peer-id",
  277. }).View(), tailcfg.UserProfile{
  278. LoginName: "peer",
  279. }, true
  280. }
  281. func (ts *localState) DoNoiseRequest(req *http.Request) (*http.Response, error) {
  282. rec := httptest.NewRecorder()
  283. k, ok := strings.CutPrefix(req.URL.Path, "/ssh-action/")
  284. if !ok {
  285. rec.WriteHeader(http.StatusNotFound)
  286. }
  287. a, ok := ts.serverActions[k]
  288. if !ok {
  289. rec.WriteHeader(http.StatusNotFound)
  290. return rec.Result(), nil
  291. }
  292. rec.WriteHeader(http.StatusOK)
  293. if err := json.NewEncoder(rec).Encode(a); err != nil {
  294. return nil, err
  295. }
  296. return rec.Result(), nil
  297. }
  298. func (ts *localState) TailscaleVarRoot() string {
  299. return ""
  300. }
  301. func (ts *localState) NodeKey() key.NodePublic {
  302. return key.NewNode().Public()
  303. }
  304. func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule {
  305. return &tailcfg.SSHRule{
  306. SSHUsers: map[string]string{
  307. "*": currentUser,
  308. },
  309. Action: action,
  310. Principals: []*tailcfg.SSHPrincipal{
  311. {
  312. Any: true,
  313. },
  314. },
  315. }
  316. }
  317. func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
  318. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  319. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  320. }
  321. var handler http.HandlerFunc
  322. recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  323. handler(w, r)
  324. }))
  325. defer recordingServer.Close()
  326. s := &server{
  327. logf: t.Logf,
  328. lb: &localState{
  329. sshEnabled: true,
  330. matchingRule: newSSHRule(
  331. &tailcfg.SSHAction{
  332. Accept: true,
  333. Recorders: []netip.AddrPort{
  334. netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
  335. },
  336. OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
  337. RejectSessionWithMessage: "session rejected",
  338. TerminateSessionWithMessage: "session terminated",
  339. },
  340. },
  341. ),
  342. },
  343. }
  344. defer s.Shutdown()
  345. const sshUser = "alice"
  346. cfg := &gossh.ClientConfig{
  347. User: sshUser,
  348. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  349. }
  350. tests := []struct {
  351. name string
  352. handler func(w http.ResponseWriter, r *http.Request)
  353. sshCommand string
  354. wantClientOutput string
  355. clientOutputMustNotContain []string
  356. }{
  357. {
  358. name: "upload-denied",
  359. handler: func(w http.ResponseWriter, r *http.Request) {
  360. w.WriteHeader(http.StatusForbidden)
  361. },
  362. sshCommand: "echo hello",
  363. wantClientOutput: "session rejected\r\n",
  364. clientOutputMustNotContain: []string{"hello"},
  365. },
  366. {
  367. name: "upload-fails-after-starting",
  368. handler: func(w http.ResponseWriter, r *http.Request) {
  369. r.Body.Read(make([]byte, 1))
  370. time.Sleep(100 * time.Millisecond)
  371. w.WriteHeader(http.StatusInternalServerError)
  372. },
  373. sshCommand: "echo hello && sleep 1 && echo world",
  374. wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n",
  375. clientOutputMustNotContain: []string{"world"},
  376. },
  377. }
  378. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  379. for _, tt := range tests {
  380. t.Run(tt.name, func(t *testing.T) {
  381. tstest.Replace(t, &handler, tt.handler)
  382. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  383. var wg sync.WaitGroup
  384. wg.Add(1)
  385. go func() {
  386. defer wg.Done()
  387. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  388. if err != nil {
  389. t.Errorf("client: %v", err)
  390. return
  391. }
  392. client := gossh.NewClient(c, chans, reqs)
  393. defer client.Close()
  394. session, err := client.NewSession()
  395. if err != nil {
  396. t.Errorf("client: %v", err)
  397. return
  398. }
  399. defer session.Close()
  400. t.Logf("client established session")
  401. got, err := session.CombinedOutput(tt.sshCommand)
  402. if err != nil {
  403. t.Logf("client got: %q: %v", got, err)
  404. } else {
  405. t.Errorf("client did not get kicked out: %q", got)
  406. }
  407. gotStr := string(got)
  408. if !strings.HasSuffix(gotStr, tt.wantClientOutput) {
  409. t.Errorf("client got %q, want %q", got, tt.wantClientOutput)
  410. }
  411. for _, x := range tt.clientOutputMustNotContain {
  412. if strings.Contains(gotStr, x) {
  413. t.Errorf("client output must not contain %q", x)
  414. }
  415. }
  416. }()
  417. if err := s.HandleSSHConn(dc); err != nil {
  418. t.Errorf("unexpected error: %v", err)
  419. }
  420. wg.Wait()
  421. })
  422. }
  423. }
  424. func TestMultipleRecorders(t *testing.T) {
  425. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  426. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  427. }
  428. done := make(chan struct{})
  429. recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  430. defer close(done)
  431. io.ReadAll(r.Body)
  432. w.WriteHeader(http.StatusOK)
  433. }))
  434. defer recordingServer.Close()
  435. badRecorder, err := net.Listen("tcp", ":0")
  436. if err != nil {
  437. t.Fatal(err)
  438. }
  439. badRecorderAddr := badRecorder.Addr().String()
  440. badRecorder.Close()
  441. badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  442. w.WriteHeader(500)
  443. }))
  444. defer badRecordingServer500.Close()
  445. badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  446. w.WriteHeader(200)
  447. }))
  448. defer badRecordingServer200.Close()
  449. s := &server{
  450. logf: t.Logf,
  451. lb: &localState{
  452. sshEnabled: true,
  453. matchingRule: newSSHRule(
  454. &tailcfg.SSHAction{
  455. Accept: true,
  456. Recorders: []netip.AddrPort{
  457. netip.MustParseAddrPort(badRecorderAddr),
  458. netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()),
  459. netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()),
  460. netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
  461. },
  462. OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
  463. RejectSessionWithMessage: "session rejected",
  464. TerminateSessionWithMessage: "session terminated",
  465. },
  466. },
  467. ),
  468. },
  469. }
  470. defer s.Shutdown()
  471. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  472. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  473. const sshUser = "alice"
  474. cfg := &gossh.ClientConfig{
  475. User: sshUser,
  476. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  477. }
  478. var wg sync.WaitGroup
  479. wg.Add(1)
  480. go func() {
  481. defer wg.Done()
  482. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  483. if err != nil {
  484. t.Errorf("client: %v", err)
  485. return
  486. }
  487. client := gossh.NewClient(c, chans, reqs)
  488. defer client.Close()
  489. session, err := client.NewSession()
  490. if err != nil {
  491. t.Errorf("client: %v", err)
  492. return
  493. }
  494. defer session.Close()
  495. t.Logf("client established session")
  496. out, err := session.CombinedOutput("echo Ran echo!")
  497. if err != nil {
  498. t.Errorf("client: %v", err)
  499. }
  500. if string(out) != "Ran echo!\n" {
  501. t.Errorf("client: unexpected output: %q", out)
  502. }
  503. }()
  504. if err := s.HandleSSHConn(dc); err != nil {
  505. t.Errorf("unexpected error: %v", err)
  506. }
  507. wg.Wait()
  508. select {
  509. case <-done:
  510. case <-time.After(1 * time.Second):
  511. t.Fatal("timed out waiting for recording")
  512. }
  513. }
  514. // TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
  515. // when the client is not interactive (i.e. no PTY).
  516. // It starts a local SSH server and a recording server. The recording server
  517. // records the SSH session and returns it to the test.
  518. // The test then verifies that the recording has a valid CastHeader, it does not
  519. // validate the contents of the recording.
  520. func TestSSHRecordingNonInteractive(t *testing.T) {
  521. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  522. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  523. }
  524. var recording []byte
  525. ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  526. recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  527. defer cancel()
  528. var err error
  529. recording, err = io.ReadAll(r.Body)
  530. if err != nil {
  531. t.Error(err)
  532. return
  533. }
  534. }))
  535. defer recordingServer.Close()
  536. s := &server{
  537. logf: logger.Discard,
  538. lb: &localState{
  539. sshEnabled: true,
  540. matchingRule: newSSHRule(
  541. &tailcfg.SSHAction{
  542. Accept: true,
  543. Recorders: []netip.AddrPort{
  544. must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
  545. },
  546. OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
  547. RejectSessionWithMessage: "session rejected",
  548. TerminateSessionWithMessage: "session terminated",
  549. },
  550. },
  551. ),
  552. },
  553. }
  554. defer s.Shutdown()
  555. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  556. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  557. const sshUser = "alice"
  558. cfg := &gossh.ClientConfig{
  559. User: sshUser,
  560. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  561. }
  562. var wg sync.WaitGroup
  563. wg.Add(1)
  564. go func() {
  565. defer wg.Done()
  566. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  567. if err != nil {
  568. t.Errorf("client: %v", err)
  569. return
  570. }
  571. client := gossh.NewClient(c, chans, reqs)
  572. defer client.Close()
  573. session, err := client.NewSession()
  574. if err != nil {
  575. t.Errorf("client: %v", err)
  576. return
  577. }
  578. defer session.Close()
  579. t.Logf("client established session")
  580. _, err = session.CombinedOutput("echo Ran echo!")
  581. if err != nil {
  582. t.Errorf("client: %v", err)
  583. }
  584. }()
  585. if err := s.HandleSSHConn(dc); err != nil {
  586. t.Errorf("unexpected error: %v", err)
  587. }
  588. wg.Wait()
  589. <-ctx.Done() // wait for recording to finish
  590. var ch CastHeader
  591. if err := json.NewDecoder(bytes.NewReader(recording)).Decode(&ch); err != nil {
  592. t.Fatal(err)
  593. }
  594. if ch.SSHUser != sshUser {
  595. t.Errorf("SSHUser = %q; want %q", ch.SSHUser, sshUser)
  596. }
  597. if ch.Command != "echo Ran echo!" {
  598. t.Errorf("Command = %q; want %q", ch.Command, "echo Ran echo!")
  599. }
  600. }
  601. func TestSSHAuthFlow(t *testing.T) {
  602. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  603. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  604. }
  605. acceptRule := newSSHRule(&tailcfg.SSHAction{
  606. Accept: true,
  607. Message: "Welcome to Tailscale SSH!",
  608. })
  609. rejectRule := newSSHRule(&tailcfg.SSHAction{
  610. Reject: true,
  611. Message: "Go Away!",
  612. })
  613. tests := []struct {
  614. name string
  615. sshUser string // defaults to alice
  616. state *localState
  617. wantBanners []string
  618. usesPassword bool
  619. authErr bool
  620. }{
  621. {
  622. name: "no-policy",
  623. state: &localState{
  624. sshEnabled: true,
  625. },
  626. authErr: true,
  627. },
  628. {
  629. name: "accept",
  630. state: &localState{
  631. sshEnabled: true,
  632. matchingRule: acceptRule,
  633. },
  634. wantBanners: []string{"Welcome to Tailscale SSH!"},
  635. },
  636. {
  637. name: "reject",
  638. state: &localState{
  639. sshEnabled: true,
  640. matchingRule: rejectRule,
  641. },
  642. wantBanners: []string{"Go Away!"},
  643. authErr: true,
  644. },
  645. {
  646. name: "simple-check",
  647. state: &localState{
  648. sshEnabled: true,
  649. matchingRule: newSSHRule(&tailcfg.SSHAction{
  650. HoldAndDelegate: "https://unused/ssh-action/accept",
  651. }),
  652. serverActions: map[string]*tailcfg.SSHAction{
  653. "accept": acceptRule.Action,
  654. },
  655. },
  656. wantBanners: []string{"Welcome to Tailscale SSH!"},
  657. },
  658. {
  659. name: "multi-check",
  660. state: &localState{
  661. sshEnabled: true,
  662. matchingRule: newSSHRule(&tailcfg.SSHAction{
  663. Message: "First",
  664. HoldAndDelegate: "https://unused/ssh-action/check1",
  665. }),
  666. serverActions: map[string]*tailcfg.SSHAction{
  667. "check1": {
  668. Message: "url-here",
  669. HoldAndDelegate: "https://unused/ssh-action/check2",
  670. },
  671. "check2": acceptRule.Action,
  672. },
  673. },
  674. wantBanners: []string{"First", "url-here", "Welcome to Tailscale SSH!"},
  675. },
  676. {
  677. name: "check-reject",
  678. state: &localState{
  679. sshEnabled: true,
  680. matchingRule: newSSHRule(&tailcfg.SSHAction{
  681. Message: "First",
  682. HoldAndDelegate: "https://unused/ssh-action/reject",
  683. }),
  684. serverActions: map[string]*tailcfg.SSHAction{
  685. "reject": rejectRule.Action,
  686. },
  687. },
  688. wantBanners: []string{"First", "Go Away!"},
  689. authErr: true,
  690. },
  691. {
  692. name: "force-password-auth",
  693. sshUser: "alice+password",
  694. state: &localState{
  695. sshEnabled: true,
  696. matchingRule: acceptRule,
  697. },
  698. usesPassword: true,
  699. wantBanners: []string{"Welcome to Tailscale SSH!"},
  700. },
  701. }
  702. s := &server{
  703. logf: logger.Discard,
  704. }
  705. defer s.Shutdown()
  706. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  707. for _, tc := range tests {
  708. t.Run(tc.name, func(t *testing.T) {
  709. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  710. s.lb = tc.state
  711. sshUser := "alice"
  712. if tc.sshUser != "" {
  713. sshUser = tc.sshUser
  714. }
  715. var passwordUsed atomic.Bool
  716. cfg := &gossh.ClientConfig{
  717. User: sshUser,
  718. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  719. Auth: []gossh.AuthMethod{
  720. gossh.PasswordCallback(func() (secret string, err error) {
  721. if !tc.usesPassword {
  722. t.Error("unexpected use of PasswordCallback")
  723. return "", errors.New("unexpected use of PasswordCallback")
  724. }
  725. passwordUsed.Store(true)
  726. return "any-pass", nil
  727. }),
  728. },
  729. BannerCallback: func(message string) error {
  730. if len(tc.wantBanners) == 0 {
  731. t.Errorf("unexpected banner: %q", message)
  732. } else if message != tc.wantBanners[0] {
  733. t.Errorf("banner = %q; want %q", message, tc.wantBanners[0])
  734. } else {
  735. t.Logf("banner = %q", message)
  736. tc.wantBanners = tc.wantBanners[1:]
  737. }
  738. return nil
  739. },
  740. }
  741. var wg sync.WaitGroup
  742. wg.Add(1)
  743. go func() {
  744. defer wg.Done()
  745. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  746. if err != nil {
  747. if !tc.authErr {
  748. t.Errorf("client: %v", err)
  749. }
  750. return
  751. } else if tc.authErr {
  752. c.Close()
  753. t.Errorf("client: expected error, got nil")
  754. return
  755. }
  756. client := gossh.NewClient(c, chans, reqs)
  757. defer client.Close()
  758. session, err := client.NewSession()
  759. if err != nil {
  760. t.Errorf("client: %v", err)
  761. return
  762. }
  763. defer session.Close()
  764. _, err = session.CombinedOutput("echo Ran echo!")
  765. if err != nil {
  766. t.Errorf("client: %v", err)
  767. }
  768. }()
  769. if err := s.HandleSSHConn(dc); err != nil {
  770. t.Errorf("unexpected error: %v", err)
  771. }
  772. wg.Wait()
  773. if len(tc.wantBanners) > 0 {
  774. t.Errorf("missing banners: %v", tc.wantBanners)
  775. }
  776. })
  777. }
  778. }
  779. func TestSSH(t *testing.T) {
  780. var logf logger.Logf = t.Logf
  781. sys := &tsd.System{}
  782. eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker())
  783. if err != nil {
  784. t.Fatal(err)
  785. }
  786. sys.Set(eng)
  787. sys.Set(new(mem.Store))
  788. lb, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{}, sys, 0)
  789. if err != nil {
  790. t.Fatal(err)
  791. }
  792. defer lb.Shutdown()
  793. dir := t.TempDir()
  794. lb.SetVarRoot(dir)
  795. srv := &server{
  796. lb: lb,
  797. logf: logf,
  798. }
  799. sc, err := srv.newConn()
  800. if err != nil {
  801. t.Fatal(err)
  802. }
  803. // Remove the auth checks for the test
  804. sc.insecureSkipTailscaleAuth = true
  805. u, err := user.Current()
  806. if err != nil {
  807. t.Fatal(err)
  808. }
  809. um, err := userLookup(u.Username)
  810. if err != nil {
  811. t.Fatal(err)
  812. }
  813. sc.localUser = um
  814. sc.info = &sshConnInfo{
  815. sshUser: "test",
  816. src: netip.MustParseAddrPort("1.2.3.4:32342"),
  817. dst: netip.MustParseAddrPort("1.2.3.5:22"),
  818. node: (&tailcfg.Node{}).View(),
  819. uprof: tailcfg.UserProfile{},
  820. }
  821. sc.action0 = &tailcfg.SSHAction{Accept: true}
  822. sc.finalAction = sc.action0
  823. sc.Handler = func(s ssh.Session) {
  824. sc.newSSHSession(s).run()
  825. }
  826. ln, err := net.Listen("tcp4", "127.0.0.1:0")
  827. if err != nil {
  828. t.Fatal(err)
  829. }
  830. defer ln.Close()
  831. port := ln.Addr().(*net.TCPAddr).Port
  832. go func() {
  833. for {
  834. c, err := ln.Accept()
  835. if err != nil {
  836. if !errors.Is(err, net.ErrClosed) {
  837. t.Errorf("Accept: %v", err)
  838. }
  839. return
  840. }
  841. go sc.HandleConn(c)
  842. }
  843. }()
  844. execSSH := func(args ...string) *exec.Cmd {
  845. cmd := exec.Command("ssh",
  846. "-F",
  847. "none",
  848. "-v",
  849. "-p", fmt.Sprint(port),
  850. "-o", "StrictHostKeyChecking=no",
  851. "[email protected]")
  852. cmd.Args = append(cmd.Args, args...)
  853. return cmd
  854. }
  855. t.Run("env", func(t *testing.T) {
  856. if cibuild.On() {
  857. t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
  858. }
  859. cmd := execSSH("LANG=foo env")
  860. cmd.Env = append(os.Environ(), "LOCAL_ENV=bar")
  861. got, err := cmd.CombinedOutput()
  862. if err != nil {
  863. t.Fatal(err, string(got))
  864. }
  865. m := parseEnv(got)
  866. if got := m["USER"]; got == "" || got != u.Username {
  867. t.Errorf("USER = %q; want %q", got, u.Username)
  868. }
  869. if got := m["HOME"]; got == "" || got != u.HomeDir {
  870. t.Errorf("HOME = %q; want %q", got, u.HomeDir)
  871. }
  872. if got := m["PWD"]; got == "" || got != u.HomeDir {
  873. t.Errorf("PWD = %q; want %q", got, u.HomeDir)
  874. }
  875. if got := m["SHELL"]; got == "" {
  876. t.Errorf("no SHELL")
  877. }
  878. if got, want := m["LANG"], "foo"; got != want {
  879. t.Errorf("LANG = %q; want %q", got, want)
  880. }
  881. if got := m["LOCAL_ENV"]; got != "" {
  882. t.Errorf("LOCAL_ENV leaked over ssh: %v", got)
  883. }
  884. t.Logf("got: %+v", m)
  885. })
  886. t.Run("stdout_stderr", func(t *testing.T) {
  887. cmd := execSSH("sh", "-c", "echo foo; echo bar >&2")
  888. var outBuf, errBuf bytes.Buffer
  889. cmd.Stdout = &outBuf
  890. cmd.Stderr = &errBuf
  891. if err := cmd.Run(); err != nil {
  892. t.Fatal(err)
  893. }
  894. t.Logf("Got: %q and %q", outBuf.Bytes(), errBuf.Bytes())
  895. // TODO: figure out why these aren't right. should be
  896. // "foo\n" and "bar\n", not "\n" and "bar\n".
  897. })
  898. t.Run("large_file", func(t *testing.T) {
  899. const wantSize = 1e6
  900. var outBuf bytes.Buffer
  901. cmd := execSSH("head", "-c", strconv.Itoa(wantSize), "/dev/zero")
  902. cmd.Stdout = &outBuf
  903. if err := cmd.Run(); err != nil {
  904. t.Fatal(err)
  905. }
  906. if gotSize := outBuf.Len(); gotSize != wantSize {
  907. t.Fatalf("got %d, want %d", gotSize, int(wantSize))
  908. }
  909. })
  910. t.Run("stdin", func(t *testing.T) {
  911. if cibuild.On() {
  912. t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
  913. }
  914. cmd := execSSH("cat")
  915. var outBuf bytes.Buffer
  916. cmd.Stdout = &outBuf
  917. const str = "foo\nbar\n"
  918. cmd.Stdin = strings.NewReader(str)
  919. if err := cmd.Run(); err != nil {
  920. t.Fatal(err)
  921. }
  922. if got := outBuf.String(); got != str {
  923. t.Errorf("got %q; want %q", got, str)
  924. }
  925. })
  926. }
  927. func parseEnv(out []byte) map[string]string {
  928. e := map[string]string{}
  929. lineread.Reader(bytes.NewReader(out), func(line []byte) error {
  930. i := bytes.IndexByte(line, '=')
  931. if i == -1 {
  932. return nil
  933. }
  934. e[string(line[:i])] = string(line[i+1:])
  935. return nil
  936. })
  937. return e
  938. }
  939. func TestPublicKeyFetching(t *testing.T) {
  940. var reqsTotal, reqsIfNoneMatchHit, reqsIfNoneMatchMiss int32
  941. ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  942. atomic.AddInt32((&reqsTotal), 1)
  943. etag := fmt.Sprintf("W/%q", sha256.Sum256([]byte(r.URL.Path)))
  944. w.Header().Set("Etag", etag)
  945. if v := r.Header.Get("If-None-Match"); v != "" {
  946. if v == etag {
  947. atomic.AddInt32(&reqsIfNoneMatchHit, 1)
  948. w.WriteHeader(304)
  949. return
  950. }
  951. atomic.AddInt32(&reqsIfNoneMatchMiss, 1)
  952. }
  953. io.WriteString(w, "foo\nbar\n"+string(r.URL.Path)+"\n")
  954. }))
  955. ts.StartTLS()
  956. defer ts.Close()
  957. keys := ts.URL
  958. clock := &tstest.Clock{}
  959. srv := &server{
  960. pubKeyHTTPClient: ts.Client(),
  961. timeNow: clock.Now,
  962. }
  963. for range 2 {
  964. got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
  965. if err != nil {
  966. t.Fatal(err)
  967. }
  968. if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
  969. t.Errorf("got %q; want %q", got, want)
  970. }
  971. }
  972. if got, want := atomic.LoadInt32(&reqsTotal), int32(1); got != want {
  973. t.Errorf("got %d requests; want %d", got, want)
  974. }
  975. if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(0); got != want {
  976. t.Errorf("got %d etag hits; want %d", got, want)
  977. }
  978. clock.Advance(5 * time.Minute)
  979. got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
  980. if err != nil {
  981. t.Fatal(err)
  982. }
  983. if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
  984. t.Errorf("got %q; want %q", got, want)
  985. }
  986. if got, want := atomic.LoadInt32(&reqsTotal), int32(2); got != want {
  987. t.Errorf("got %d requests; want %d", got, want)
  988. }
  989. if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(1); got != want {
  990. t.Errorf("got %d etag hits; want %d", got, want)
  991. }
  992. if got, want := atomic.LoadInt32(&reqsIfNoneMatchMiss), int32(0); got != want {
  993. t.Errorf("got %d etag misses; want %d", got, want)
  994. }
  995. }
  996. func TestExpandPublicKeyURL(t *testing.T) {
  997. c := &conn{
  998. info: &sshConnInfo{
  999. uprof: tailcfg.UserProfile{
  1000. LoginName: "[email protected]",
  1001. },
  1002. },
  1003. }
  1004. if got, want := c.expandPublicKeyURL("foo"), "foo"; got != want {
  1005. t.Errorf("basic: got %q; want %q", got, want)
  1006. }
  1007. if got, want := c.expandPublicKeyURL("https://example.com/$LOGINNAME_LOCALPART.keys"), "https://example.com/bar.keys"; got != want {
  1008. t.Errorf("localpart: got %q; want %q", got, want)
  1009. }
  1010. if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/[email protected]"; got != want {
  1011. t.Errorf("email: got %q; want %q", got, want)
  1012. }
  1013. c.info = new(sshConnInfo)
  1014. if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/keys?email="; got != want {
  1015. t.Errorf("on empty: got %q; want %q", got, want)
  1016. }
  1017. }
  1018. func TestAcceptEnvPair(t *testing.T) {
  1019. tests := []struct {
  1020. in string
  1021. want bool
  1022. }{
  1023. {"TERM=x", true},
  1024. {"term=x", false},
  1025. {"TERM", false},
  1026. {"LC_FOO=x", true},
  1027. {"LD_PRELOAD=naah", false},
  1028. {"TERM=screen-256color", true},
  1029. }
  1030. for _, tt := range tests {
  1031. if got := acceptEnvPair(tt.in); got != tt.want {
  1032. t.Errorf("for %q, got %v; want %v", tt.in, got, tt.want)
  1033. }
  1034. }
  1035. }
  1036. func TestPathFromPAMEnvLine(t *testing.T) {
  1037. u := &user.User{Username: "foo", HomeDir: "/Homes/Foo"}
  1038. tests := []struct {
  1039. line string
  1040. u *user.User
  1041. want string
  1042. }{
  1043. {"", u, ""},
  1044. {`PATH DEFAULT="/run/wrappers/bin:@{HOME}/.nix-profile/bin:/etc/profiles/per-user/@{PAM_USER}/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"`,
  1045. u, "/run/wrappers/bin:/Homes/Foo/.nix-profile/bin:/etc/profiles/per-user/foo/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"},
  1046. {`PATH DEFAULT="@{SOMETHING_ELSE}:nope:@{HOME}"`,
  1047. u, ""},
  1048. }
  1049. for i, tt := range tests {
  1050. got := pathFromPAMEnvLine([]byte(tt.line), tt.u)
  1051. if got != tt.want {
  1052. t.Errorf("%d. got %q; want %q", i, got, tt.want)
  1053. }
  1054. }
  1055. }
  1056. func TestExpandDefaultPathTmpl(t *testing.T) {
  1057. u := &user.User{Username: "foo", HomeDir: "/Homes/Foo"}
  1058. tests := []struct {
  1059. t string
  1060. u *user.User
  1061. want string
  1062. }{
  1063. {"", u, ""},
  1064. {`/run/wrappers/bin:@{HOME}/.nix-profile/bin:/etc/profiles/per-user/@{PAM_USER}/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin`,
  1065. u, "/run/wrappers/bin:/Homes/Foo/.nix-profile/bin:/etc/profiles/per-user/foo/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"},
  1066. {`@{SOMETHING_ELSE}:nope:@{HOME}`, u, ""},
  1067. }
  1068. for i, tt := range tests {
  1069. got := expandDefaultPathTmpl(tt.t, tt.u)
  1070. if got != tt.want {
  1071. t.Errorf("%d. got %q; want %q", i, got, tt.want)
  1072. }
  1073. }
  1074. }
  1075. func TestPathFromPAMEnvLineOnNixOS(t *testing.T) {
  1076. if runtime.GOOS != "linux" {
  1077. t.Skip("skipping on non-linux")
  1078. }
  1079. if distro.Get() != distro.NixOS {
  1080. t.Skip("skipping on non-NixOS")
  1081. }
  1082. u, err := user.Current()
  1083. if err != nil {
  1084. t.Fatal(err)
  1085. }
  1086. got := defaultPathForUserOnNixOS(u)
  1087. if got == "" {
  1088. x, err := os.ReadFile("/etc/pam/environment")
  1089. t.Fatalf("no result. file was: err=%v, contents=%s", err, x)
  1090. }
  1091. t.Logf("success; got=%q", got)
  1092. }
  1093. func TestStdOsUserUserAssumptions(t *testing.T) {
  1094. v := reflect.TypeFor[user.User]()
  1095. if got, want := v.NumField(), 5; got != want {
  1096. t.Errorf("os/user.User has %v fields; this package assumes %v", got, want)
  1097. }
  1098. }