tailssh_test.go 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build linux || darwin
  4. package tailssh
  5. import (
  6. "bytes"
  7. "context"
  8. "crypto/ed25519"
  9. "crypto/rand"
  10. "crypto/sha256"
  11. "encoding/json"
  12. "errors"
  13. "fmt"
  14. "io"
  15. "net"
  16. "net/http"
  17. "net/http/httptest"
  18. "net/netip"
  19. "os"
  20. "os/exec"
  21. "os/user"
  22. "reflect"
  23. "runtime"
  24. "strconv"
  25. "strings"
  26. "sync"
  27. "sync/atomic"
  28. "testing"
  29. "time"
  30. gossh "github.com/tailscale/golang-x-crypto/ssh"
  31. "tailscale.com/ipn/ipnlocal"
  32. "tailscale.com/ipn/store/mem"
  33. "tailscale.com/net/memnet"
  34. "tailscale.com/net/tsdial"
  35. "tailscale.com/tailcfg"
  36. "tailscale.com/tempfork/gliderlabs/ssh"
  37. "tailscale.com/tsd"
  38. "tailscale.com/tstest"
  39. "tailscale.com/types/key"
  40. "tailscale.com/types/logger"
  41. "tailscale.com/types/logid"
  42. "tailscale.com/types/netmap"
  43. "tailscale.com/util/cibuild"
  44. "tailscale.com/util/lineread"
  45. "tailscale.com/util/must"
  46. "tailscale.com/version/distro"
  47. "tailscale.com/wgengine"
  48. )
  49. func TestMatchRule(t *testing.T) {
  50. someAction := new(tailcfg.SSHAction)
  51. tests := []struct {
  52. name string
  53. rule *tailcfg.SSHRule
  54. ci *sshConnInfo
  55. wantErr error
  56. wantUser string
  57. }{
  58. {
  59. name: "invalid-conn",
  60. rule: &tailcfg.SSHRule{
  61. Action: someAction,
  62. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  63. SSHUsers: map[string]string{
  64. "*": "ubuntu",
  65. },
  66. },
  67. wantErr: errInvalidConn,
  68. },
  69. {
  70. name: "nil-rule",
  71. ci: &sshConnInfo{},
  72. rule: nil,
  73. wantErr: errNilRule,
  74. },
  75. {
  76. name: "nil-action",
  77. ci: &sshConnInfo{},
  78. rule: &tailcfg.SSHRule{},
  79. wantErr: errNilAction,
  80. },
  81. {
  82. name: "expired",
  83. rule: &tailcfg.SSHRule{
  84. Action: someAction,
  85. RuleExpires: timePtr(time.Unix(100, 0)),
  86. },
  87. ci: &sshConnInfo{},
  88. wantErr: errRuleExpired,
  89. },
  90. {
  91. name: "no-principal",
  92. rule: &tailcfg.SSHRule{
  93. Action: someAction,
  94. SSHUsers: map[string]string{
  95. "*": "ubuntu",
  96. }},
  97. ci: &sshConnInfo{},
  98. wantErr: errPrincipalMatch,
  99. },
  100. {
  101. name: "no-user-match",
  102. rule: &tailcfg.SSHRule{
  103. Action: someAction,
  104. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  105. },
  106. ci: &sshConnInfo{sshUser: "alice"},
  107. wantErr: errUserMatch,
  108. },
  109. {
  110. name: "ok-wildcard",
  111. rule: &tailcfg.SSHRule{
  112. Action: someAction,
  113. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  114. SSHUsers: map[string]string{
  115. "*": "ubuntu",
  116. },
  117. },
  118. ci: &sshConnInfo{sshUser: "alice"},
  119. wantUser: "ubuntu",
  120. },
  121. {
  122. name: "ok-wildcard-and-nil-principal",
  123. rule: &tailcfg.SSHRule{
  124. Action: someAction,
  125. Principals: []*tailcfg.SSHPrincipal{
  126. nil, // don't crash on this
  127. {Any: true},
  128. },
  129. SSHUsers: map[string]string{
  130. "*": "ubuntu",
  131. },
  132. },
  133. ci: &sshConnInfo{sshUser: "alice"},
  134. wantUser: "ubuntu",
  135. },
  136. {
  137. name: "ok-exact",
  138. rule: &tailcfg.SSHRule{
  139. Action: someAction,
  140. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  141. SSHUsers: map[string]string{
  142. "*": "ubuntu",
  143. "alice": "thealice",
  144. },
  145. },
  146. ci: &sshConnInfo{sshUser: "alice"},
  147. wantUser: "thealice",
  148. },
  149. {
  150. name: "no-users-for-reject",
  151. rule: &tailcfg.SSHRule{
  152. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  153. Action: &tailcfg.SSHAction{Reject: true},
  154. },
  155. ci: &sshConnInfo{sshUser: "alice"},
  156. },
  157. {
  158. name: "match-principal-node-ip",
  159. rule: &tailcfg.SSHRule{
  160. Action: someAction,
  161. Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}},
  162. SSHUsers: map[string]string{"*": "ubuntu"},
  163. },
  164. ci: &sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")},
  165. wantUser: "ubuntu",
  166. },
  167. {
  168. name: "match-principal-node-id",
  169. rule: &tailcfg.SSHRule{
  170. Action: someAction,
  171. Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}},
  172. SSHUsers: map[string]string{"*": "ubuntu"},
  173. },
  174. ci: &sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()},
  175. wantUser: "ubuntu",
  176. },
  177. {
  178. name: "match-principal-userlogin",
  179. rule: &tailcfg.SSHRule{
  180. Action: someAction,
  181. Principals: []*tailcfg.SSHPrincipal{{UserLogin: "[email protected]"}},
  182. SSHUsers: map[string]string{"*": "ubuntu"},
  183. },
  184. ci: &sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "[email protected]"}},
  185. wantUser: "ubuntu",
  186. },
  187. {
  188. name: "ssh-user-equal",
  189. rule: &tailcfg.SSHRule{
  190. Action: someAction,
  191. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  192. SSHUsers: map[string]string{
  193. "*": "=",
  194. },
  195. },
  196. ci: &sshConnInfo{sshUser: "alice"},
  197. wantUser: "alice",
  198. },
  199. }
  200. for _, tt := range tests {
  201. t.Run(tt.name, func(t *testing.T) {
  202. c := &conn{
  203. info: tt.ci,
  204. srv: &server{logf: t.Logf},
  205. }
  206. got, gotUser, err := c.matchRule(tt.rule, nil)
  207. if err != tt.wantErr {
  208. t.Errorf("err = %v; want %v", err, tt.wantErr)
  209. }
  210. if gotUser != tt.wantUser {
  211. t.Errorf("user = %q; want %q", gotUser, tt.wantUser)
  212. }
  213. if err == nil && got == nil {
  214. t.Errorf("expected non-nil action on success")
  215. }
  216. })
  217. }
  218. }
  219. func timePtr(t time.Time) *time.Time { return &t }
  220. // localState implements ipnLocalBackend for testing.
  221. type localState struct {
  222. sshEnabled bool
  223. matchingRule *tailcfg.SSHRule
  224. // serverActions is a map of the action name to the action.
  225. // It is served for paths like https://unused/ssh-action/<action-name>.
  226. // The action name is the last part of the action URL.
  227. serverActions map[string]*tailcfg.SSHAction
  228. }
  229. var (
  230. currentUser = os.Getenv("USER") // Use the current user for the test.
  231. testSigner gossh.Signer
  232. testSignerOnce sync.Once
  233. )
  234. func (ts *localState) Dialer() *tsdial.Dialer {
  235. return &tsdial.Dialer{}
  236. }
  237. func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) {
  238. testSignerOnce.Do(func() {
  239. _, priv, err := ed25519.GenerateKey(rand.Reader)
  240. if err != nil {
  241. panic(err)
  242. }
  243. s, err := gossh.NewSignerFromSigner(priv)
  244. if err != nil {
  245. panic(err)
  246. }
  247. testSigner = s
  248. })
  249. return []gossh.Signer{testSigner}, nil
  250. }
  251. func (ts *localState) ShouldRunSSH() bool {
  252. return ts.sshEnabled
  253. }
  254. func (ts *localState) NetMap() *netmap.NetworkMap {
  255. var policy *tailcfg.SSHPolicy
  256. if ts.matchingRule != nil {
  257. policy = &tailcfg.SSHPolicy{
  258. Rules: []*tailcfg.SSHRule{
  259. ts.matchingRule,
  260. },
  261. }
  262. }
  263. return &netmap.NetworkMap{
  264. SelfNode: (&tailcfg.Node{
  265. ID: 1,
  266. }).View(),
  267. SSHPolicy: policy,
  268. }
  269. }
  270. func (ts *localState) WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
  271. return (&tailcfg.Node{
  272. ID: 2,
  273. StableID: "peer-id",
  274. }).View(), tailcfg.UserProfile{
  275. LoginName: "peer",
  276. }, true
  277. }
  278. func (ts *localState) DoNoiseRequest(req *http.Request) (*http.Response, error) {
  279. rec := httptest.NewRecorder()
  280. k, ok := strings.CutPrefix(req.URL.Path, "/ssh-action/")
  281. if !ok {
  282. rec.WriteHeader(http.StatusNotFound)
  283. }
  284. a, ok := ts.serverActions[k]
  285. if !ok {
  286. rec.WriteHeader(http.StatusNotFound)
  287. return rec.Result(), nil
  288. }
  289. rec.WriteHeader(http.StatusOK)
  290. if err := json.NewEncoder(rec).Encode(a); err != nil {
  291. return nil, err
  292. }
  293. return rec.Result(), nil
  294. }
  295. func (ts *localState) TailscaleVarRoot() string {
  296. return ""
  297. }
  298. func (ts *localState) NodeKey() key.NodePublic {
  299. return key.NewNode().Public()
  300. }
  301. func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule {
  302. return &tailcfg.SSHRule{
  303. SSHUsers: map[string]string{
  304. "*": currentUser,
  305. },
  306. Action: action,
  307. Principals: []*tailcfg.SSHPrincipal{
  308. {
  309. Any: true,
  310. },
  311. },
  312. }
  313. }
  314. func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
  315. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  316. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  317. }
  318. var handler http.HandlerFunc
  319. recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  320. handler(w, r)
  321. }))
  322. defer recordingServer.Close()
  323. s := &server{
  324. logf: t.Logf,
  325. lb: &localState{
  326. sshEnabled: true,
  327. matchingRule: newSSHRule(
  328. &tailcfg.SSHAction{
  329. Accept: true,
  330. Recorders: []netip.AddrPort{
  331. netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
  332. },
  333. OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
  334. RejectSessionWithMessage: "session rejected",
  335. TerminateSessionWithMessage: "session terminated",
  336. },
  337. },
  338. ),
  339. },
  340. }
  341. defer s.Shutdown()
  342. const sshUser = "alice"
  343. cfg := &gossh.ClientConfig{
  344. User: sshUser,
  345. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  346. }
  347. tests := []struct {
  348. name string
  349. handler func(w http.ResponseWriter, r *http.Request)
  350. sshCommand string
  351. wantClientOutput string
  352. clientOutputMustNotContain []string
  353. }{
  354. {
  355. name: "upload-denied",
  356. handler: func(w http.ResponseWriter, r *http.Request) {
  357. w.WriteHeader(http.StatusForbidden)
  358. },
  359. sshCommand: "echo hello",
  360. wantClientOutput: "session rejected\r\n",
  361. clientOutputMustNotContain: []string{"hello"},
  362. },
  363. {
  364. name: "upload-fails-after-starting",
  365. handler: func(w http.ResponseWriter, r *http.Request) {
  366. r.Body.Read(make([]byte, 1))
  367. time.Sleep(100 * time.Millisecond)
  368. w.WriteHeader(http.StatusInternalServerError)
  369. },
  370. sshCommand: "echo hello && sleep 1 && echo world",
  371. wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n",
  372. clientOutputMustNotContain: []string{"world"},
  373. },
  374. }
  375. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  376. for _, tt := range tests {
  377. t.Run(tt.name, func(t *testing.T) {
  378. tstest.Replace(t, &handler, tt.handler)
  379. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  380. var wg sync.WaitGroup
  381. wg.Add(1)
  382. go func() {
  383. defer wg.Done()
  384. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  385. if err != nil {
  386. t.Errorf("client: %v", err)
  387. return
  388. }
  389. client := gossh.NewClient(c, chans, reqs)
  390. defer client.Close()
  391. session, err := client.NewSession()
  392. if err != nil {
  393. t.Errorf("client: %v", err)
  394. return
  395. }
  396. defer session.Close()
  397. t.Logf("client established session")
  398. got, err := session.CombinedOutput(tt.sshCommand)
  399. if err != nil {
  400. t.Logf("client got: %q: %v", got, err)
  401. } else {
  402. t.Errorf("client did not get kicked out: %q", got)
  403. }
  404. gotStr := string(got)
  405. if !strings.HasSuffix(gotStr, tt.wantClientOutput) {
  406. t.Errorf("client got %q, want %q", got, tt.wantClientOutput)
  407. }
  408. for _, x := range tt.clientOutputMustNotContain {
  409. if strings.Contains(gotStr, x) {
  410. t.Errorf("client output must not contain %q", x)
  411. }
  412. }
  413. }()
  414. if err := s.HandleSSHConn(dc); err != nil {
  415. t.Errorf("unexpected error: %v", err)
  416. }
  417. wg.Wait()
  418. })
  419. }
  420. }
  421. func TestMultipleRecorders(t *testing.T) {
  422. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  423. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  424. }
  425. done := make(chan struct{})
  426. recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  427. defer close(done)
  428. io.ReadAll(r.Body)
  429. w.WriteHeader(http.StatusOK)
  430. }))
  431. defer recordingServer.Close()
  432. badRecorder, err := net.Listen("tcp", ":0")
  433. if err != nil {
  434. t.Fatal(err)
  435. }
  436. badRecorderAddr := badRecorder.Addr().String()
  437. badRecorder.Close()
  438. badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  439. w.WriteHeader(500)
  440. }))
  441. defer badRecordingServer500.Close()
  442. badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  443. w.WriteHeader(200)
  444. }))
  445. defer badRecordingServer200.Close()
  446. s := &server{
  447. logf: t.Logf,
  448. lb: &localState{
  449. sshEnabled: true,
  450. matchingRule: newSSHRule(
  451. &tailcfg.SSHAction{
  452. Accept: true,
  453. Recorders: []netip.AddrPort{
  454. netip.MustParseAddrPort(badRecorderAddr),
  455. netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()),
  456. netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()),
  457. netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
  458. },
  459. OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
  460. RejectSessionWithMessage: "session rejected",
  461. TerminateSessionWithMessage: "session terminated",
  462. },
  463. },
  464. ),
  465. },
  466. }
  467. defer s.Shutdown()
  468. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  469. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  470. const sshUser = "alice"
  471. cfg := &gossh.ClientConfig{
  472. User: sshUser,
  473. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  474. }
  475. var wg sync.WaitGroup
  476. wg.Add(1)
  477. go func() {
  478. defer wg.Done()
  479. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  480. if err != nil {
  481. t.Errorf("client: %v", err)
  482. return
  483. }
  484. client := gossh.NewClient(c, chans, reqs)
  485. defer client.Close()
  486. session, err := client.NewSession()
  487. if err != nil {
  488. t.Errorf("client: %v", err)
  489. return
  490. }
  491. defer session.Close()
  492. t.Logf("client established session")
  493. out, err := session.CombinedOutput("echo Ran echo!")
  494. if err != nil {
  495. t.Errorf("client: %v", err)
  496. }
  497. if string(out) != "Ran echo!\n" {
  498. t.Errorf("client: unexpected output: %q", out)
  499. }
  500. }()
  501. if err := s.HandleSSHConn(dc); err != nil {
  502. t.Errorf("unexpected error: %v", err)
  503. }
  504. wg.Wait()
  505. select {
  506. case <-done:
  507. case <-time.After(1 * time.Second):
  508. t.Fatal("timed out waiting for recording")
  509. }
  510. }
  511. // TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
  512. // when the client is not interactive (i.e. no PTY).
  513. // It starts a local SSH server and a recording server. The recording server
  514. // records the SSH session and returns it to the test.
  515. // The test then verifies that the recording has a valid CastHeader, it does not
  516. // validate the contents of the recording.
  517. func TestSSHRecordingNonInteractive(t *testing.T) {
  518. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  519. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  520. }
  521. var recording []byte
  522. ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  523. recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  524. defer cancel()
  525. var err error
  526. recording, err = io.ReadAll(r.Body)
  527. if err != nil {
  528. t.Error(err)
  529. return
  530. }
  531. }))
  532. defer recordingServer.Close()
  533. s := &server{
  534. logf: logger.Discard,
  535. lb: &localState{
  536. sshEnabled: true,
  537. matchingRule: newSSHRule(
  538. &tailcfg.SSHAction{
  539. Accept: true,
  540. Recorders: []netip.AddrPort{
  541. must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
  542. },
  543. OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
  544. RejectSessionWithMessage: "session rejected",
  545. TerminateSessionWithMessage: "session terminated",
  546. },
  547. },
  548. ),
  549. },
  550. }
  551. defer s.Shutdown()
  552. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  553. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  554. const sshUser = "alice"
  555. cfg := &gossh.ClientConfig{
  556. User: sshUser,
  557. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  558. }
  559. var wg sync.WaitGroup
  560. wg.Add(1)
  561. go func() {
  562. defer wg.Done()
  563. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  564. if err != nil {
  565. t.Errorf("client: %v", err)
  566. return
  567. }
  568. client := gossh.NewClient(c, chans, reqs)
  569. defer client.Close()
  570. session, err := client.NewSession()
  571. if err != nil {
  572. t.Errorf("client: %v", err)
  573. return
  574. }
  575. defer session.Close()
  576. t.Logf("client established session")
  577. _, err = session.CombinedOutput("echo Ran echo!")
  578. if err != nil {
  579. t.Errorf("client: %v", err)
  580. }
  581. }()
  582. if err := s.HandleSSHConn(dc); err != nil {
  583. t.Errorf("unexpected error: %v", err)
  584. }
  585. wg.Wait()
  586. <-ctx.Done() // wait for recording to finish
  587. var ch CastHeader
  588. if err := json.NewDecoder(bytes.NewReader(recording)).Decode(&ch); err != nil {
  589. t.Fatal(err)
  590. }
  591. if ch.SSHUser != sshUser {
  592. t.Errorf("SSHUser = %q; want %q", ch.SSHUser, sshUser)
  593. }
  594. if ch.Command != "echo Ran echo!" {
  595. t.Errorf("Command = %q; want %q", ch.Command, "echo Ran echo!")
  596. }
  597. }
  598. func TestSSHAuthFlow(t *testing.T) {
  599. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  600. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  601. }
  602. acceptRule := newSSHRule(&tailcfg.SSHAction{
  603. Accept: true,
  604. Message: "Welcome to Tailscale SSH!",
  605. })
  606. rejectRule := newSSHRule(&tailcfg.SSHAction{
  607. Reject: true,
  608. Message: "Go Away!",
  609. })
  610. tests := []struct {
  611. name string
  612. sshUser string // defaults to alice
  613. state *localState
  614. wantBanners []string
  615. usesPassword bool
  616. authErr bool
  617. }{
  618. {
  619. name: "no-policy",
  620. state: &localState{
  621. sshEnabled: true,
  622. },
  623. authErr: true,
  624. },
  625. {
  626. name: "accept",
  627. state: &localState{
  628. sshEnabled: true,
  629. matchingRule: acceptRule,
  630. },
  631. wantBanners: []string{"Welcome to Tailscale SSH!"},
  632. },
  633. {
  634. name: "reject",
  635. state: &localState{
  636. sshEnabled: true,
  637. matchingRule: rejectRule,
  638. },
  639. wantBanners: []string{"Go Away!"},
  640. authErr: true,
  641. },
  642. {
  643. name: "simple-check",
  644. state: &localState{
  645. sshEnabled: true,
  646. matchingRule: newSSHRule(&tailcfg.SSHAction{
  647. HoldAndDelegate: "https://unused/ssh-action/accept",
  648. }),
  649. serverActions: map[string]*tailcfg.SSHAction{
  650. "accept": acceptRule.Action,
  651. },
  652. },
  653. wantBanners: []string{"Welcome to Tailscale SSH!"},
  654. },
  655. {
  656. name: "multi-check",
  657. state: &localState{
  658. sshEnabled: true,
  659. matchingRule: newSSHRule(&tailcfg.SSHAction{
  660. Message: "First",
  661. HoldAndDelegate: "https://unused/ssh-action/check1",
  662. }),
  663. serverActions: map[string]*tailcfg.SSHAction{
  664. "check1": {
  665. Message: "url-here",
  666. HoldAndDelegate: "https://unused/ssh-action/check2",
  667. },
  668. "check2": acceptRule.Action,
  669. },
  670. },
  671. wantBanners: []string{"First", "url-here", "Welcome to Tailscale SSH!"},
  672. },
  673. {
  674. name: "check-reject",
  675. state: &localState{
  676. sshEnabled: true,
  677. matchingRule: newSSHRule(&tailcfg.SSHAction{
  678. Message: "First",
  679. HoldAndDelegate: "https://unused/ssh-action/reject",
  680. }),
  681. serverActions: map[string]*tailcfg.SSHAction{
  682. "reject": rejectRule.Action,
  683. },
  684. },
  685. wantBanners: []string{"First", "Go Away!"},
  686. authErr: true,
  687. },
  688. {
  689. name: "force-password-auth",
  690. sshUser: "alice+password",
  691. state: &localState{
  692. sshEnabled: true,
  693. matchingRule: acceptRule,
  694. },
  695. usesPassword: true,
  696. wantBanners: []string{"Welcome to Tailscale SSH!"},
  697. },
  698. }
  699. s := &server{
  700. logf: logger.Discard,
  701. }
  702. defer s.Shutdown()
  703. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  704. for _, tc := range tests {
  705. t.Run(tc.name, func(t *testing.T) {
  706. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  707. s.lb = tc.state
  708. sshUser := "alice"
  709. if tc.sshUser != "" {
  710. sshUser = tc.sshUser
  711. }
  712. var passwordUsed atomic.Bool
  713. cfg := &gossh.ClientConfig{
  714. User: sshUser,
  715. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  716. Auth: []gossh.AuthMethod{
  717. gossh.PasswordCallback(func() (secret string, err error) {
  718. if !tc.usesPassword {
  719. t.Error("unexpected use of PasswordCallback")
  720. return "", errors.New("unexpected use of PasswordCallback")
  721. }
  722. passwordUsed.Store(true)
  723. return "any-pass", nil
  724. }),
  725. },
  726. BannerCallback: func(message string) error {
  727. if len(tc.wantBanners) == 0 {
  728. t.Errorf("unexpected banner: %q", message)
  729. } else if message != tc.wantBanners[0] {
  730. t.Errorf("banner = %q; want %q", message, tc.wantBanners[0])
  731. } else {
  732. t.Logf("banner = %q", message)
  733. tc.wantBanners = tc.wantBanners[1:]
  734. }
  735. return nil
  736. },
  737. }
  738. var wg sync.WaitGroup
  739. wg.Add(1)
  740. go func() {
  741. defer wg.Done()
  742. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  743. if err != nil {
  744. if !tc.authErr {
  745. t.Errorf("client: %v", err)
  746. }
  747. return
  748. } else if tc.authErr {
  749. c.Close()
  750. t.Errorf("client: expected error, got nil")
  751. return
  752. }
  753. client := gossh.NewClient(c, chans, reqs)
  754. defer client.Close()
  755. session, err := client.NewSession()
  756. if err != nil {
  757. t.Errorf("client: %v", err)
  758. return
  759. }
  760. defer session.Close()
  761. _, err = session.CombinedOutput("echo Ran echo!")
  762. if err != nil {
  763. t.Errorf("client: %v", err)
  764. }
  765. }()
  766. if err := s.HandleSSHConn(dc); err != nil {
  767. t.Errorf("unexpected error: %v", err)
  768. }
  769. wg.Wait()
  770. if len(tc.wantBanners) > 0 {
  771. t.Errorf("missing banners: %v", tc.wantBanners)
  772. }
  773. })
  774. }
  775. }
  776. func TestSSH(t *testing.T) {
  777. var logf logger.Logf = t.Logf
  778. sys := &tsd.System{}
  779. eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set)
  780. if err != nil {
  781. t.Fatal(err)
  782. }
  783. sys.Set(eng)
  784. sys.Set(new(mem.Store))
  785. lb, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{}, sys, 0)
  786. if err != nil {
  787. t.Fatal(err)
  788. }
  789. defer lb.Shutdown()
  790. dir := t.TempDir()
  791. lb.SetVarRoot(dir)
  792. srv := &server{
  793. lb: lb,
  794. logf: logf,
  795. }
  796. sc, err := srv.newConn()
  797. if err != nil {
  798. t.Fatal(err)
  799. }
  800. // Remove the auth checks for the test
  801. sc.insecureSkipTailscaleAuth = true
  802. u, err := user.Current()
  803. if err != nil {
  804. t.Fatal(err)
  805. }
  806. um, err := userLookup(u.Username)
  807. if err != nil {
  808. t.Fatal(err)
  809. }
  810. sc.localUser = um
  811. sc.info = &sshConnInfo{
  812. sshUser: "test",
  813. src: netip.MustParseAddrPort("1.2.3.4:32342"),
  814. dst: netip.MustParseAddrPort("1.2.3.5:22"),
  815. node: (&tailcfg.Node{}).View(),
  816. uprof: tailcfg.UserProfile{},
  817. }
  818. sc.action0 = &tailcfg.SSHAction{Accept: true}
  819. sc.finalAction = sc.action0
  820. sc.Handler = func(s ssh.Session) {
  821. sc.newSSHSession(s).run()
  822. }
  823. ln, err := net.Listen("tcp4", "127.0.0.1:0")
  824. if err != nil {
  825. t.Fatal(err)
  826. }
  827. defer ln.Close()
  828. port := ln.Addr().(*net.TCPAddr).Port
  829. go func() {
  830. for {
  831. c, err := ln.Accept()
  832. if err != nil {
  833. if !errors.Is(err, net.ErrClosed) {
  834. t.Errorf("Accept: %v", err)
  835. }
  836. return
  837. }
  838. go sc.HandleConn(c)
  839. }
  840. }()
  841. execSSH := func(args ...string) *exec.Cmd {
  842. cmd := exec.Command("ssh",
  843. "-F",
  844. "none",
  845. "-v",
  846. "-p", fmt.Sprint(port),
  847. "-o", "StrictHostKeyChecking=no",
  848. "[email protected]")
  849. cmd.Args = append(cmd.Args, args...)
  850. return cmd
  851. }
  852. t.Run("env", func(t *testing.T) {
  853. if cibuild.On() {
  854. t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
  855. }
  856. cmd := execSSH("LANG=foo env")
  857. cmd.Env = append(os.Environ(), "LOCAL_ENV=bar")
  858. got, err := cmd.CombinedOutput()
  859. if err != nil {
  860. t.Fatal(err, string(got))
  861. }
  862. m := parseEnv(got)
  863. if got := m["USER"]; got == "" || got != u.Username {
  864. t.Errorf("USER = %q; want %q", got, u.Username)
  865. }
  866. if got := m["HOME"]; got == "" || got != u.HomeDir {
  867. t.Errorf("HOME = %q; want %q", got, u.HomeDir)
  868. }
  869. if got := m["PWD"]; got == "" || got != u.HomeDir {
  870. t.Errorf("PWD = %q; want %q", got, u.HomeDir)
  871. }
  872. if got := m["SHELL"]; got == "" {
  873. t.Errorf("no SHELL")
  874. }
  875. if got, want := m["LANG"], "foo"; got != want {
  876. t.Errorf("LANG = %q; want %q", got, want)
  877. }
  878. if got := m["LOCAL_ENV"]; got != "" {
  879. t.Errorf("LOCAL_ENV leaked over ssh: %v", got)
  880. }
  881. t.Logf("got: %+v", m)
  882. })
  883. t.Run("stdout_stderr", func(t *testing.T) {
  884. cmd := execSSH("sh", "-c", "echo foo; echo bar >&2")
  885. var outBuf, errBuf bytes.Buffer
  886. cmd.Stdout = &outBuf
  887. cmd.Stderr = &errBuf
  888. if err := cmd.Run(); err != nil {
  889. t.Fatal(err)
  890. }
  891. t.Logf("Got: %q and %q", outBuf.Bytes(), errBuf.Bytes())
  892. // TODO: figure out why these aren't right. should be
  893. // "foo\n" and "bar\n", not "\n" and "bar\n".
  894. })
  895. t.Run("large_file", func(t *testing.T) {
  896. const wantSize = 1e6
  897. var outBuf bytes.Buffer
  898. cmd := execSSH("head", "-c", strconv.Itoa(wantSize), "/dev/zero")
  899. cmd.Stdout = &outBuf
  900. if err := cmd.Run(); err != nil {
  901. t.Fatal(err)
  902. }
  903. if gotSize := outBuf.Len(); gotSize != wantSize {
  904. t.Fatalf("got %d, want %d", gotSize, int(wantSize))
  905. }
  906. })
  907. t.Run("stdin", func(t *testing.T) {
  908. if cibuild.On() {
  909. t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
  910. }
  911. cmd := execSSH("cat")
  912. var outBuf bytes.Buffer
  913. cmd.Stdout = &outBuf
  914. const str = "foo\nbar\n"
  915. cmd.Stdin = strings.NewReader(str)
  916. if err := cmd.Run(); err != nil {
  917. t.Fatal(err)
  918. }
  919. if got := outBuf.String(); got != str {
  920. t.Errorf("got %q; want %q", got, str)
  921. }
  922. })
  923. }
  924. func parseEnv(out []byte) map[string]string {
  925. e := map[string]string{}
  926. lineread.Reader(bytes.NewReader(out), func(line []byte) error {
  927. i := bytes.IndexByte(line, '=')
  928. if i == -1 {
  929. return nil
  930. }
  931. e[string(line[:i])] = string(line[i+1:])
  932. return nil
  933. })
  934. return e
  935. }
  936. func TestPublicKeyFetching(t *testing.T) {
  937. var reqsTotal, reqsIfNoneMatchHit, reqsIfNoneMatchMiss int32
  938. ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  939. atomic.AddInt32((&reqsTotal), 1)
  940. etag := fmt.Sprintf("W/%q", sha256.Sum256([]byte(r.URL.Path)))
  941. w.Header().Set("Etag", etag)
  942. if v := r.Header.Get("If-None-Match"); v != "" {
  943. if v == etag {
  944. atomic.AddInt32(&reqsIfNoneMatchHit, 1)
  945. w.WriteHeader(304)
  946. return
  947. }
  948. atomic.AddInt32(&reqsIfNoneMatchMiss, 1)
  949. }
  950. io.WriteString(w, "foo\nbar\n"+string(r.URL.Path)+"\n")
  951. }))
  952. ts.StartTLS()
  953. defer ts.Close()
  954. keys := ts.URL
  955. clock := &tstest.Clock{}
  956. srv := &server{
  957. pubKeyHTTPClient: ts.Client(),
  958. timeNow: clock.Now,
  959. }
  960. for i := 0; i < 2; i++ {
  961. got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
  962. if err != nil {
  963. t.Fatal(err)
  964. }
  965. if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
  966. t.Errorf("got %q; want %q", got, want)
  967. }
  968. }
  969. if got, want := atomic.LoadInt32(&reqsTotal), int32(1); got != want {
  970. t.Errorf("got %d requests; want %d", got, want)
  971. }
  972. if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(0); got != want {
  973. t.Errorf("got %d etag hits; want %d", got, want)
  974. }
  975. clock.Advance(5 * time.Minute)
  976. got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
  977. if err != nil {
  978. t.Fatal(err)
  979. }
  980. if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
  981. t.Errorf("got %q; want %q", got, want)
  982. }
  983. if got, want := atomic.LoadInt32(&reqsTotal), int32(2); got != want {
  984. t.Errorf("got %d requests; want %d", got, want)
  985. }
  986. if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(1); got != want {
  987. t.Errorf("got %d etag hits; want %d", got, want)
  988. }
  989. if got, want := atomic.LoadInt32(&reqsIfNoneMatchMiss), int32(0); got != want {
  990. t.Errorf("got %d etag misses; want %d", got, want)
  991. }
  992. }
  993. func TestExpandPublicKeyURL(t *testing.T) {
  994. c := &conn{
  995. info: &sshConnInfo{
  996. uprof: tailcfg.UserProfile{
  997. LoginName: "[email protected]",
  998. },
  999. },
  1000. }
  1001. if got, want := c.expandPublicKeyURL("foo"), "foo"; got != want {
  1002. t.Errorf("basic: got %q; want %q", got, want)
  1003. }
  1004. if got, want := c.expandPublicKeyURL("https://example.com/$LOGINNAME_LOCALPART.keys"), "https://example.com/bar.keys"; got != want {
  1005. t.Errorf("localpart: got %q; want %q", got, want)
  1006. }
  1007. if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/[email protected]"; got != want {
  1008. t.Errorf("email: got %q; want %q", got, want)
  1009. }
  1010. c.info = new(sshConnInfo)
  1011. if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/keys?email="; got != want {
  1012. t.Errorf("on empty: got %q; want %q", got, want)
  1013. }
  1014. }
  1015. func TestAcceptEnvPair(t *testing.T) {
  1016. tests := []struct {
  1017. in string
  1018. want bool
  1019. }{
  1020. {"TERM=x", true},
  1021. {"term=x", false},
  1022. {"TERM", false},
  1023. {"LC_FOO=x", true},
  1024. {"LD_PRELOAD=naah", false},
  1025. {"TERM=screen-256color", true},
  1026. }
  1027. for _, tt := range tests {
  1028. if got := acceptEnvPair(tt.in); got != tt.want {
  1029. t.Errorf("for %q, got %v; want %v", tt.in, got, tt.want)
  1030. }
  1031. }
  1032. }
  1033. func TestPathFromPAMEnvLine(t *testing.T) {
  1034. u := &user.User{Username: "foo", HomeDir: "/Homes/Foo"}
  1035. tests := []struct {
  1036. line string
  1037. u *user.User
  1038. want string
  1039. }{
  1040. {"", u, ""},
  1041. {`PATH DEFAULT="/run/wrappers/bin:@{HOME}/.nix-profile/bin:/etc/profiles/per-user/@{PAM_USER}/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"`,
  1042. u, "/run/wrappers/bin:/Homes/Foo/.nix-profile/bin:/etc/profiles/per-user/foo/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"},
  1043. {`PATH DEFAULT="@{SOMETHING_ELSE}:nope:@{HOME}"`,
  1044. u, ""},
  1045. }
  1046. for i, tt := range tests {
  1047. got := pathFromPAMEnvLine([]byte(tt.line), tt.u)
  1048. if got != tt.want {
  1049. t.Errorf("%d. got %q; want %q", i, got, tt.want)
  1050. }
  1051. }
  1052. }
  1053. func TestExpandDefaultPathTmpl(t *testing.T) {
  1054. u := &user.User{Username: "foo", HomeDir: "/Homes/Foo"}
  1055. tests := []struct {
  1056. t string
  1057. u *user.User
  1058. want string
  1059. }{
  1060. {"", u, ""},
  1061. {`/run/wrappers/bin:@{HOME}/.nix-profile/bin:/etc/profiles/per-user/@{PAM_USER}/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin`,
  1062. u, "/run/wrappers/bin:/Homes/Foo/.nix-profile/bin:/etc/profiles/per-user/foo/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"},
  1063. {`@{SOMETHING_ELSE}:nope:@{HOME}`, u, ""},
  1064. }
  1065. for i, tt := range tests {
  1066. got := expandDefaultPathTmpl(tt.t, tt.u)
  1067. if got != tt.want {
  1068. t.Errorf("%d. got %q; want %q", i, got, tt.want)
  1069. }
  1070. }
  1071. }
  1072. func TestPathFromPAMEnvLineOnNixOS(t *testing.T) {
  1073. if runtime.GOOS != "linux" {
  1074. t.Skip("skipping on non-linux")
  1075. }
  1076. if distro.Get() != distro.NixOS {
  1077. t.Skip("skipping on non-NixOS")
  1078. }
  1079. u, err := user.Current()
  1080. if err != nil {
  1081. t.Fatal(err)
  1082. }
  1083. got := defaultPathForUserOnNixOS(u)
  1084. if got == "" {
  1085. x, err := os.ReadFile("/etc/pam/environment")
  1086. t.Fatalf("no result. file was: err=%v, contents=%s", err, x)
  1087. }
  1088. t.Logf("success; got=%q", got)
  1089. }
  // TestStdOsUserUserAssumptions guards against os/user.User gaining or
// losing fields. The package under test assumes the struct has exactly
// wantFields fields, so a change in the standard library must fail loudly
// here.
func TestStdOsUserUserAssumptions(t *testing.T) {
	const wantFields = 5
	if n := reflect.TypeOf(user.User{}).NumField(); n != wantFields {
		t.Errorf("os/user.User has %v fields; this package assumes %v", n, wantFields)
	}
}